From 10f3c806f532f706eb6aa0efddca096fd95d6ae1 Mon Sep 17 00:00:00 2001 From: reminisce Date: Sun, 19 May 2019 20:45:17 -0700 Subject: [PATCH] Revert the change broadcast_to param shape --- src/operator/tensor/broadcast_reduce_op.h | 4 ++-- tests/python/unittest/test_operator.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index f7d9f13fd869..b511619e3bb2 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -379,7 +379,7 @@ inline bool BroadcastAxesShape(const nnvm::NodeAttrs& attrs, inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& ishape = (*in_attrs)[0]; @@ -389,7 +389,7 @@ inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs, << "Operand of shape " << ishape << " cannot be broadcasted to " << param.shape; mxnet::TShape oshape = param.shape; for (int i = 0; i < ishape.ndim(); ++i) { - if (oshape[i] != -1) { + if (oshape[i] != 0) { CHECK(ishape[i] == oshape[i] || ishape[i] == 1) << "Array cannot be broadcasted from " << ishape << " to " << param.shape; } else { diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1768da237daf..23ff11f314be 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2535,19 +2535,27 @@ def test_broadcast(): size = tuple([shape[ele] for ele in axis]) for ele in axis: shape[ele] = 1 + target_shape_with_zero = list(target_shape) + for idx in range(len(target_shape_with_zero)): + if idx not in axis: + target_shape_with_zero[idx] = 0 + break + a = mx.symbol.Variable('a') sym_bcast_axis = mx.symbol.broadcast_axis(a, axis=axis, size=size) sym_bcast_to = mx.symbol.broadcast_to(a, shape=tuple(target_shape)) + sym_bcast_to_with_zero = mx.symbol.broadcast_to(a, shape=tuple(target_shape_with_zero)) sym_bcast_like = mx.symbol.broadcast_like(a, sym_bcast_to) + def test_broadcasting_ele(sym_bcast): dat_npy = np.random.rand(*shape) groundtruth = dat_npy grad_nd = mx.nd.empty(shape) outgrad_npy = np.random.rand(*target_shape) grad_groundtruth = np_reduce(outgrad_npy, axis=axis, keepdims=True, - numpy_reduce_func=np.sum) + numpy_reduce_func=np.sum) net = sym_bcast.bind(default_context(), args={'a': mx.nd.array(dat_npy)}, - args_grad={'a': grad_nd}) + args_grad={'a': grad_nd}) net.forward(is_train=True) assert (net.outputs[0].shape == target_shape).all() assert_almost_equal(net.outputs[0].asnumpy(), groundtruth, rtol=1e-4) @@ -2555,6 +2563,7 @@ def test_broadcasting_ele(sym_bcast): assert_almost_equal(grad_nd.asnumpy(), grad_groundtruth, rtol=1e-4) test_broadcasting_ele(sym_bcast_axis) test_broadcasting_ele(sym_bcast_to) + test_broadcasting_ele(sym_bcast_to_with_zero) test_broadcasting_ele(sym_bcast_like)