diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index d96782a27ebf..fc1b351c649a 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -780,7 +780,17 @@ namespace isnan_typed {
 MXNET_UNARY_MATH_OP_NC(relu, isnan_typed::IsNan(a) || (a > DType(0)) ? a : DType(0));
 
-MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
+/*! \brief used for computing gradient of relu operator */
+struct relu_grad : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    if (isnan_typed::IsNan(a)) {
+      return a;
+    } else {
+      return a > DType(0) ? DType(1) : DType(0);
+    }
+  }
+};
 
 /*! \brief used for computing binary operator maximum */
 struct maximum : public mxnet_op::tunable {
   template<typename DType>
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index edf61c12672b..3a17a1e89461 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1571,6 +1571,14 @@ def test_ndarray_nan_comparison():
     np_relu = np.maximum(data1.asnumpy(), 0)
     np.testing.assert_equal(nd_relu.asnumpy(), np_relu)
 
+    data1.attach_grad()
+    with mx.autograd.record():
+        y = mx.nd.relu(data1)
+    y.backward()
+    data1_grad = data1.grad.asnumpy()
+    for i in (np.isnan(data1_grad))[1][0].flatten():
+        assert i == True
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
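
For context, a minimal NumPy sketch of the gradient semantics the new relu_grad struct implements: NaN inputs propagate through to the gradient instead of being mapped to 0 or 1. This is an illustration only, not MXNet code; the helper name relu_grad_reference is hypothetical.

import numpy as np

def relu_grad_reference(a):
    # Sketch of the patched relu_grad semantics:
    # a > 0 -> 1, a <= 0 -> 0, NaN -> NaN (propagated).
    out = np.where(a > 0, 1.0, 0.0)
    out[np.isnan(a)] = np.nan
    return out

print(relu_grad_reference(np.array([-2.0, 0.0, 3.0, np.nan])))  # [ 0.  0.  1. nan]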