From 3f22504dc84cb5af4492a23cdd6778e072e02eb1 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Tue, 5 Mar 2019 15:26:32 -0800 Subject: [PATCH] Relaxing type requirements for reshape_like op (#14325) * Relax type requirements in reshape_like * Add test * Fix lint * Retrigger CI --- src/operator/tensor/elemwise_unary_op_basic.cc | 11 ++++++++++- tests/python/unittest/test_operator.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 4aaf4dfd33c4..19a9ac8359eb 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -481,7 +481,16 @@ Negative indices are supported, and `None` can be used for either `lhs_end` or ` [](const NodeAttrs& attrs) { return std::vector(1, 1); }) .set_attr("FCompute", UnaryOp::IdentityCompute) .set_attr("FInferShape", ReshapeLikeShapeCompute) -.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferType", [](const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name; + std::vector checked_in_attrs = { (*in_attrs)[0] }; + bool ret = !type_is_none((*in_attrs)[1]) && + ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs); + (*in_attrs)[0] = checked_in_attrs[0]; + return ret; + }) .set_attr( "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 500d2f99f4d9..0ac530c23d11 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2529,6 +2529,16 @@ def test_slice_like_different_types(): z = mx.nd.slice_like(x, y) assert_allclose(z.asnumpy(), [[1,2,3],[5,6,7]]) +@with_seed() +def test_reshape_like_different_types(): + x = mx.nd.zeros((2, 3)) + + y = mx.nd.array([[1, 2], [3, 4], [5, 6]]) + + y = mx.nd.array(y).astype('int32') + z = mx.nd.reshape_like(x, y) + assert_allclose(z.asnumpy(), [[0,0],[0,0],[0,0]]) + @with_seed() def test_flip(): for ndim in range(1, 6):