Skip to content

Commit

Permalink
Revert the change broadcast_to param shape (apache#14998)
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce authored and haohuw committed Jun 23, 2019
1 parent 70c2771 commit 4e09157
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ inline bool BroadcastAxesShape(const nnvm::NodeAttrs& attrs,

inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& ishape = (*in_attrs)[0];
Expand All @@ -389,7 +389,7 @@ inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs,
<< "Operand of shape " << ishape << " cannot be broadcasted to " << param.shape;
mxnet::TShape oshape = param.shape;
for (int i = 0; i < ishape.ndim(); ++i) {
if (oshape[i] != -1) {
if (oshape[i] != 0) {
CHECK(ishape[i] == oshape[i] || ishape[i] == 1)
<< "Array cannot be broadcasted from " << ishape << " to " << param.shape;
} else {
Expand Down
13 changes: 11 additions & 2 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2535,26 +2535,35 @@ def test_broadcast():
size = tuple([shape[ele] for ele in axis])
for ele in axis:
shape[ele] = 1
target_shape_with_zero = list(target_shape)
for idx in range(len(target_shape_with_zero)):
if idx not in axis:
target_shape_with_zero[idx] = 0
break

a = mx.symbol.Variable('a')
sym_bcast_axis = mx.symbol.broadcast_axis(a, axis=axis, size=size)
sym_bcast_to = mx.symbol.broadcast_to(a, shape=tuple(target_shape))
sym_bcast_to_with_zero = mx.symbol.broadcast_to(a, shape=tuple(target_shape_with_zero))
sym_bcast_like = mx.symbol.broadcast_like(a, sym_bcast_to)

def test_broadcasting_ele(sym_bcast):
dat_npy = np.random.rand(*shape)
groundtruth = dat_npy
grad_nd = mx.nd.empty(shape)
outgrad_npy = np.random.rand(*target_shape)
grad_groundtruth = np_reduce(outgrad_npy, axis=axis, keepdims=True,
numpy_reduce_func=np.sum)
numpy_reduce_func=np.sum)
net = sym_bcast.bind(default_context(), args={'a': mx.nd.array(dat_npy)},
args_grad={'a': grad_nd})
args_grad={'a': grad_nd})
net.forward(is_train=True)
assert (net.outputs[0].shape == target_shape).all()
assert_almost_equal(net.outputs[0].asnumpy(), groundtruth, rtol=1e-4)
net.backward(out_grads=mx.nd.array(outgrad_npy))
assert_almost_equal(grad_nd.asnumpy(), grad_groundtruth, rtol=1e-4)
test_broadcasting_ele(sym_bcast_axis)
test_broadcasting_ele(sym_bcast_to)
test_broadcasting_ele(sym_bcast_to_with_zero)
test_broadcasting_ele(sym_bcast_like)


Expand Down

0 comments on commit 4e09157

Please sign in to comment.