diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 351c013d9c04..acb7b283aa76 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2513,10 +2513,10 @@ def moveaxis(tensor, source, destination): ---------- tensor : mx.nd.array The array which axes should be reordered - source : int - Original position of the axes to move. - destination : int - Destination position for each of the original axes. + source : int or sequence of int + Original position of the axes to move. Can be negative but must be unique. + destination : int or sequence of int + Destination position for each of the original axes. Can be negative but must be unique. Returns ------- @@ -2528,19 +2528,32 @@ def moveaxis(tensor, source, destination): >>> X = mx.nd.array([[1, 2, 3], [4, 5, 6]]) >>> mx.nd.moveaxis(X, 0, 1).shape (3L, 2L) + + >>> X = mx.nd.zeros((3, 4, 5)) + >>> mx.nd.moveaxis(X, [0, 1], [-1, -2]).shape + (5, 4, 3) """ - axes = list(range(tensor.ndim)) try: - axes.pop(source) + source = np.core.numeric.normalize_axis_tuple(source, tensor.ndim) except IndexError: raise ValueError('Source should verify 0 <= source < tensor.ndim' 'Got %d' % source) try: - axes.insert(destination, source) + destination = np.core.numeric.normalize_axis_tuple(destination, tensor.ndim) except IndexError: - raise ValueError('Destination should verify 0 <= destination < tensor.ndim' - 'Got %d' % destination) - return op.transpose(tensor, axes) + raise ValueError('Destination should verify 0 <= destination < tensor.ndim (%d).' + % tensor.ndim, 'Got %d' % destination) + + if len(source) != len(destination): + raise ValueError('`source` and `destination` arguments must have ' + 'the same number of elements') + + order = [n for n in range(tensor.ndim) if n not in source] + + for dest, src in sorted(zip(destination, source)): + order.insert(dest, src) + + return op.transpose(tensor, order) # pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 3a17a1e89461..2446107ad466 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -32,6 +32,7 @@ from numpy.testing import assert_allclose import mxnet.autograd + def check_with_uniform(uf, arg_shapes, dim=None, npuf=None, rmin=-10, type_list=[np.float32]): """check function consistency with uniform random numbers""" if isinstance(arg_shapes, int): @@ -60,6 +61,7 @@ def check_with_uniform(uf, arg_shapes, dim=None, npuf=None, rmin=-10, type_list= else: assert_almost_equal(out1, out2, atol=1e-5) + def random_ndarray(dim): shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim)) data = mx.nd.array(np.random.uniform(-10, 10, shape)) @@ -144,12 +146,14 @@ def test_ndarray_elementwise(): check_with_uniform(mx.nd.square, 1, dim, np.square, rmin=0) check_with_uniform(lambda x: mx.nd.norm(x).asscalar(), 1, dim, np.linalg.norm) + @with_seed() def test_ndarray_elementwisesum(): ones = mx.nd.ones((10,), dtype=np.int32) res = mx.nd.ElementWiseSum(ones, ones*2, ones*4, ones*8) assert same(res.asnumpy(), ones.asnumpy()*15) + @with_seed() def test_ndarray_negate(): npy = np.random.uniform(-10, 10, (2,3,4)) @@ -162,6 +166,7 @@ def test_ndarray_negate(): # we compute (-arr) assert_almost_equal(npy, arr.asnumpy()) + @with_seed() def test_ndarray_reshape(): tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5) @@ -360,6 +365,7 @@ def test_buffer_load(): # test garbage values assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_single_ndarray[:-10]) + @with_seed() def test_ndarray_slice(): shape = (10,) @@ -391,6 +397,7 @@ def test_ndarray_slice(): assert same(A[:, i].asnumpy(), A2[:, i]) assert same(A[i, :].asnumpy(), A2[i, :]) + @with_seed() def test_ndarray_crop(): # get crop @@ -524,6 +531,7 @@ def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes): keepdims:np_reduce(np.float32(data), axis, keepdims, np.argmin), mx.nd.argmin, False) + @with_seed() def test_broadcast(): sample_num = 1000 @@ -626,7 +634,7 @@ def check_broadcast_binary(fn): def test_moveaxis(): X = mx.nd.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) - res = mx.nd.moveaxis(X, 0, 3).asnumpy() + res = mx.nd.moveaxis(X, 0, 2).asnumpy() true_res = mx.nd.array([[[ 1., 7.], [ 2., 8.], [ 3., 9.]], @@ -636,6 +644,66 @@ def test_moveaxis(): assert same(res, true_res.asnumpy()) assert mx.nd.moveaxis(X, 2, 0).shape == (3, 2, 2) + def test_move_to_end(): + x = mx.nd.random.normal(0, 1, (5, 6, 7)) + for source, expected in [(0, (6, 7, 5)), + (1, (5, 7, 6)), + (2, (5, 6, 7)), + (-1, (5, 6, 7))]: + actual = mx.nd.moveaxis(x, source, -1).shape + assert actual == expected + + def test_move_new_position(): + x = mx.nd.random.normal(0, 1, (1, 2, 3, 4)) + for source, destination, expected in [ + (0, 1, (2, 1, 3, 4)), + (1, 2, (1, 3, 2, 4)), + (1, -1, (1, 3, 4, 2)), + ]: + actual = mx.nd.moveaxis(x, source, destination).shape + assert actual == expected + + def test_preserve_order(): + x = mx.nd.zeros((1, 2, 3, 4)) + for source, destination in [ + (0, 0), + (3, -1), + (-1, 3), + ([0, -1], [0, -1]), + ([2, 0], [2, 0]), + (range(4), range(4)), + ]: + actual = mx.nd.moveaxis(x, source, destination).shape + assert actual == (1, 2, 3, 4) + + def test_move_multiples(): + x = mx.nd.zeros((4, 1, 2, 3)) + for source, destination, expected in [ + ([0, 1], [2, 3], (2, 3, 4, 1)), + ([2, 3], [0, 1], (2, 3, 4, 1)), + ([0, 1, 2], [2, 3, 0], (2, 3, 4, 1)), + ([3, 0], [1, 0], (4, 3, 1, 2)), + ([0, 3], [0, 1], (4, 3, 1, 2)), + ]: + actual = mx.nd.moveaxis(x, source, destination).shape + assert actual == expected + + def test_errors(): + x = mx.nd.random.normal(0, 1, (1, 2, 3)) + assert_exception(mx.nd.moveaxis, ValueError, x, 3, 0) + assert_exception(mx.nd.moveaxis, ValueError, x, -4, 0) + assert_exception(mx.nd.moveaxis, ValueError, x, 0, 5) + assert_exception(mx.nd.moveaxis, ValueError, x, [0, 0], [0, 1]) + assert_exception(mx.nd.moveaxis, ValueError, x, [0, 1], [1, 1]) + assert_exception(mx.nd.moveaxis, ValueError, x, 0, [0, 1]) + assert_exception(mx.nd.moveaxis, ValueError, x, [0, 1], [0]) + + test_move_to_end() + test_move_new_position() + test_preserve_order() + test_move_multiples() + test_errors() + @with_seed() def test_arange(): @@ -653,6 +721,7 @@ def test_arange(): dtype="int32").asnumpy() assert_almost_equal(pred, gt) + @with_seed() def test_order(): ctx = default_context() @@ -885,6 +954,7 @@ def get_large_matrix(): k=dat_size*dat_size*dat_size*dat_size, is_ascend=False) assert_almost_equal(nd_ret_sort, gt) + @with_seed() def test_ndarray_equal(): x = mx.nd.zeros((2, 3))