From 1d05575c18958a7c7c1210981378dffa80d9d031 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 27 Feb 2019 09:12:09 +0000 Subject: [PATCH] nan comparison --- src/operator/mshadow_op.h | 38 ++++++++++++++++++++------- tests/python/unittest/test_ndarray.py | 21 +++++++++++++++ 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index f56436b8fa0c..c4ebf0d30069 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -127,10 +127,6 @@ MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a))); MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a))); -MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0)); - -MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0)); - MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) * (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a)))); @@ -317,12 +313,6 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a)); MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b)); -/*! \brief used for generate element of maximum */ -MXNET_BINARY_MATH_OP(maximum, a > b ? a : b); - -/*! \brief used for generate element of minimum */ -MXNET_BINARY_MATH_OP_NC(minimum, a < b ? a : b); - MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1)); MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0)); @@ -788,6 +778,34 @@ namespace isnan_typed { } }; // namespace isnan_typed +MXNET_UNARY_MATH_OP_NC(relu, isnan_typed::IsNan(a) || (a > DType(0)) ? a : DType(0)); + +MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0)); + +/*! \brief used for binary operator maximum */ +struct maximum : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (isnan_typed::IsNan(a)) { + return a; + } else { + return (a > b ? a : b); + } + } +}; + +/*! \brief used for binary operator minimum */ +struct minimum : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (isnan_typed::IsNan(a)) { + return a; + } else { + return DType(a < b ? a : b); + } + } +}; + /*! \brief sum reducer that ignores NaN values in the input */ struct nansum { /*! \brief do reduction into dst */ diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 7176b1888607..edf61c12672b 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1550,6 +1550,27 @@ def test_ndarray_is_nan(): np.testing.assert_equal(output.asnumpy(), expected_output.astype(int)) # astype since numpy functions default return type is boolean array instead of int +@with_seed() +def test_ndarray_nan_comparison(): + random_dimensions = np.random.randint(2, 5) + random_shape = [np.random.randint(2, 5) for i in range(random_dimensions)] + data1 = mxnet.test_utils.rand_ndarray(random_shape,'default') + data2 = mxnet.test_utils.rand_ndarray(random_shape,'default') + data1[1][0] = np.NaN + data2[0][0] = np.NaN + + nd_max = mx.nd.maximum(data1, data2) + np_max = np.maximum(data1.asnumpy(), data2.asnumpy()) + np.testing.assert_equal(nd_max.asnumpy(), np_max) + + nd_min = mx.nd.minimum(data1, data2) + np_min = np.minimum(data1.asnumpy(), data2.asnumpy()) + np.testing.assert_equal(nd_min.asnumpy(), np_min) + + nd_relu = mx.nd.relu(data1) + np_relu = np.maximum(data1.asnumpy(), 0) + np.testing.assert_equal(nd_relu.asnumpy(), np_relu) + if __name__ == '__main__': import nose nose.runmodule()