-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-798] Fix the dtype cast from non float32 in Gradient computation #12290
Conversation
@eric-haibin-lin @piiswrong @haojin2 I will appreciate your review. |
|
||
|
||
if __name__ == "__main__": | ||
test_infer_multiout_op() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should go to something like test_operator.py instead of creating a separate file for it? And, please see /~https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L7017-L7018 for how to use nosetests.
test64.backward() | ||
assert_almost_equal(data64.grad.asnumpy().all(), data32.grad.asnumpy().all()) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should go to something like test_operator.py instead of creating a separate file for it? And, please see /~https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L7017-L7018 for how to use nosetests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not to test the functionality of the operator but a general type casting issue for all multioutput operators. I inclined to add it in the infer type tests but would like to hear more suggestions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed test to run nose runmodule
Change to [WIP] to fix some platform dependent unit test failure. |
@@ -254,7 +254,8 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, | |||
dispatch_mode = &dispatch_modes[nid]; | |||
if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false; | |||
} | |||
auto finfer = finfer_shape.get(inode.source->op(), fdefault); | |||
auto finfer = (inode.source->op() == Op::Get("_zeros")) ? fdefault : | |||
finfer_shape.get(inode.source->op(), fdefault); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure about this? This affects all _zero ops, not just for the case you mentioned.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right, this is breaking some unit test (however, due to unittest of master branch is broken in MacOS, I wan't able to verify before checkin). I have changed the PR to WIP.
@eric-haibin-lin Please review this new implementation. Thanks for your suggestion! |
What's up with the build? |
@eric-haibin-lin Not sure exactly. An earlier build passed dcc5f78). After I renamed some variables the build on ARM7 failed. I can submit an empty change to trigger the build again. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
with autograd.record(): | ||
test64 = test_func(data64) | ||
test64.backward() | ||
assert_almost_equal(data64.grad.asnumpy().all(), data32.grad.asnumpy().all()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you set rtol and atol to some bigger value than default here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why increase the rtol and atol if the unit test can pass with the default one?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be flaky. you are comparing a float32 numpy to a float64 numpy and the atol and rtol defaults are small.
Also,maybe we should add zeros to APIs that may be good to break for 2.0 #9686 |
@anirudh2290 The _zeros_without_dtype operator is a private operator used only in building nnvm graph. It is not meant to be exposed to users. |
@apeforest what i meant is we can change the dtype default to -1 for zeros operator for 2.0. |
@anirudh2290 Thanks for the clarification. I have increased atol and rtol values in unit test. As to changing the dtype default to -1 for zeros, I think it is not related to this PR and may cause a backward compatibility issue with old models. Therefore, I would prefer doing that in a separate PR. Please let me know what you think. Thanks. |
Not suggesting to do it in this PR. Just wanted to document it in the APIs to break for 2.0 and we can do it before 2.0 release. |
…on (apache#12290) * Fix the dtype mismatch in derived _zeros node * Add unittest for infer dtype * Add one more unit test * Add nose runmodule * Add a zero operator with no default dtype * Rename variables * fix a bug: rename operator for gpu * Increase atol and rtol to avoid flakiness
Description
This PR fixes the issues #9067 and #8799 where gradient computation for operators with multiple output fails in ndarray if the dtype is not float32.
The root cause of the issue is that a _zeros operator was added for the other don't care output. The _zeros operator uses float32 dtype by default and it will cause conflict if the dtype in ndarray is not float32. My solution is to create a new _zeros_without_dtype operator that does not take any default dtype and use it to replace the _zeros operator in the computation graph. This change solves the dtype conflict problem and should be backward compatible.
A unit test is added to test this fix.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments