Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix relu grad
Browse files Browse the repository at this point in the history
  • Loading branch information
Anirudh Acharya committed Mar 7, 2019
1 parent 279ed8b commit 7dbf2bb
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,17 @@ namespace isnan_typed {

MXNET_UNARY_MATH_OP_NC(relu, isnan_typed::IsNan(a) || (a > DType(0)) ? a : DType(0));

MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
/*! \brief used for computing gradient of relu operator */
struct relu_grad : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
if (isnan_typed::IsNan(a)) {
return a;
} else {
return a > DType(0) ? DType(1) : DType(0);
}
}
};

/*! \brief used for computing binary operator maximum */
struct maximum : public mxnet_op::tunable {
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,14 @@ def test_ndarray_nan_comparison():
np_relu = np.maximum(data1.asnumpy(), 0)
np.testing.assert_equal(nd_relu.asnumpy(), np_relu)

data1.attach_grad()
with mx.autograd.record():
y = mx.nd.relu(data1)
y.backward()
data1_grad = data1.grad.asnumpy()
for i in (np.isnan(data1_grad))[1][0].flatten():
assert i == True

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 7dbf2bb

Please sign in to comment.