-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-798] Fix the dtype cast from non float32 in Gradient computation #12290
Changes from 6 commits
f90ffaa
f375af2
e0122e1
a7da018
1f15b4d
1677bd0
0f5639a
e239f15
dcc5f78
9c09a2e
324390f
b979ce5
b853440
5da3117
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# pylint: skip-file | ||
import mxnet as mx | ||
import numpy as np | ||
from common import models | ||
from mxnet import autograd | ||
from nose.tools import * | ||
|
||
def test_infer_multiout_op(): | ||
data = mx.nd.arange(16, dtype=np.float64).reshape((4, 4)) | ||
data.attach_grad() | ||
|
||
with autograd.record(): | ||
y = mx.nd.split(data, axis=0, num_outputs=2) | ||
y[0].backward() | ||
assert data.grad.dtype == np.float64 | ||
|
||
def test_infer_multiout_op2(): | ||
def test_func(a): | ||
q, l = mx.nd.linalg.gelqf(a) | ||
return mx.nd.sum(l) | ||
|
||
data32 = mx.nd.random.normal(shape=(2, 3), ctx=mx.cpu(), dtype=np.float32) | ||
data32.attach_grad() | ||
with autograd.record(): | ||
test32 = test_func(data32) | ||
test32.backward() | ||
|
||
data64 = mx.nd.Cast(data32, dtype=np.float64) | ||
data64.attach_grad() | ||
with autograd.record(): | ||
test64 = test_func(data64) | ||
test64.backward() | ||
assert_almost_equal(data64.grad.asnumpy().all(), data32.grad.asnumpy().all()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you set rtol and atol to some bigger value than default here ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why increase the rtol and atol if the unit test can pass with the default one? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be flaky. you are comparing a float32 numpy to a float64 numpy and the atol and rtol defaults are small. |
||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should go to something like test_operator.py instead of creating a separate file for it? And, please see /~https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L7017-L7018 for how to use nosetests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not to test the functionality of the operator but a general type casting issue for all multioutput operators. I inclined to add it in the infer type tests but would like to hear more suggestions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed test to run nose runmodule |
||
if __name__ == '__main__': | ||
import nose | ||
nose.runmodule() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure about this? This affects all _zero ops, not just for the case you mentioned.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right, this is breaking some unit test (however, due to unittest of master branch is broken in MacOS, I wan't able to verify before checkin). I have changed the PR to WIP.