diff --git a/src/operator/tensor/ravel.cc b/src/operator/tensor/ravel.cc index 0a66ea80fca9..94d79c7d07a6 100644 --- a/src/operator/tensor/ravel.cc +++ b/src/operator/tensor/ravel.cc @@ -31,12 +31,13 @@ DMLC_REGISTER_PARAMETER(RavelParam); NNVM_REGISTER_OP(_ravel_multi_index) .add_alias("ravel_multi_index") -.describe(R"code(Converts a batch of index arrays into an array of flat indices. The operator follows numpy conventions so a single multi index is given by a column of the input matrix. +.describe(R"code(Converts a batch of index arrays into an array of flat indices. The operator follows numpy conventions so a single multi index is given by a column of the input matrix. The leading dimension may be left unspecified by using -1 as placeholder. Examples:: A = [[3,6,6],[4,5,1]] ravel(A, shape=(7,6)) = [22,41,37] + ravel(A, shape=(-1,6)) = [22,41,37] )code" ADD_FILELINE) .set_num_inputs(1) @@ -55,12 +56,13 @@ Examples:: NNVM_REGISTER_OP(_unravel_index) .add_alias("unravel_index") -.describe(R"code(Converts an array of flat indices into a batch of index arrays. The operator follows numpy conventions so a single multi index is given by a column of the output matrix. +.describe(R"code(Converts an array of flat indices into a batch of index arrays. The operator follows numpy conventions so a single multi index is given by a column of the output matrix. The leading dimension may be left unspecified by using -1 as placeholder. Examples:: A = [22,41,37] unravel(A, shape=(7,6)) = [[3,6,6],[4,5,1]] + unravel(A, shape=(-1,6)) = [[3,6,6],[4,5,1]] )code" ADD_FILELINE) .set_num_inputs(1) diff --git a/src/operator/tensor/ravel.h b/src/operator/tensor/ravel.h index 6d337dcef701..256fe334e971 100644 --- a/src/operator/tensor/ravel.h +++ b/src/operator/tensor/ravel.h @@ -110,11 +110,12 @@ struct unravel_index { DType *unravelled, DType *ravelled) { index_t idx(ravelled[i]); #pragma unroll - for (int j = ndim; j--; ) { + for (int j = ndim-1; j > 0; --j) { index_t tmp = idx / shape[j]; unravelled[i+j*N] = idx - tmp*shape[j]; idx = tmp; } + unravelled[i] = idx; } }; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 6bb815066c80..7169395205e0 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -7106,6 +7106,13 @@ def test_ravel(): check_symbolic_forward(b, location={'a': data}, expected=[ravel_npy]) c = mx.sym.unravel_index(a, shape=shape) check_symbolic_forward(c, location={'a': ravel_npy}, expected=[data]) + # Test with leading dimension set to -1. + shape2 = shape + shape2 = (-1,)+shape[1:] + b = mx.sym.ravel_multi_index(a, shape=shape2) + check_symbolic_forward(b, location={'a': data}, expected=[ravel_npy]) + c = mx.sym.unravel_index(a, shape=shape2) + check_symbolic_forward(c, location={'a': ravel_npy}, expected=[data]) def test_context_num_gpus(): try: