diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 538d5202942d..c111a95a707a 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -32,9 +32,9 @@ 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', - 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', - 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'histogram', 'eye', - 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', + 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', + 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace', + 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', @@ -3029,6 +3029,71 @@ def split(ary, indices_or_sections, axis=0): # pylint: enable=redefined-outer-name +# pylint: disable=redefined-outer-name +@set_module('mxnet.ndarray.numpy') +def array_split(ary, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays. + + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an array of length l that should be split into n sections, it returns + l % n sub-arrays of size l//n + 1 and the rest of size l//n. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + - ary[:2] + - ary[2:3] + - ary[3:] + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D Python tuple, list or set. + Param used to determine the number and size of the subarray. + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + + Examples + -------- + >>> x = np.arange(9.0) + >>> np.array_split(x, 3) + [array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])] + + >>> np.array_split(x, [3, 5, 6, 8]) + [array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.]), array([])] + + >>> x = np.arange(8.0) + >>> np.array_split(x, 3) + [array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.])] + + >>> x = np.arange(7.0) + >>> np.array_split(x, 3) + [array([0., 1., 2.]), array([3., 4.]), array([5., 6.])] + """ + indices = [] + sections = 0 + if isinstance(indices_or_sections, integer_types): + sections = indices_or_sections + elif isinstance(indices_or_sections, (list, set, tuple)): + indices = [0] + list(indices_or_sections) + else: + raise ValueError('indices_or_sections must be either int, or tuple / list / set of ints') + ret = _npi.split(ary, indices, axis, False, sections) + if not isinstance(ret, list): + return [ret] + return ret +# pylint: enable=redefined-outer-name + + # pylint: disable=redefined-outer-name @set_module('mxnet.ndarray.numpy') def hsplit(ary, indices_or_sections): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index aa0762bf0e3f..5795c62942df 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -50,9 +50,9 @@ 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', - 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', + 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort', - 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', + 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', @@ -4841,6 +4841,58 @@ def split(ary, indices_or_sections, axis=0): return _mx_nd_np.split(ary, indices_or_sections, axis=axis) +@set_module('mxnet.numpy') +def array_split(ary, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays. + + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an array of length l that should be split into n sections, it returns + l % n sub-arrays of size l//n + 1 and the rest of size l//n. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + - ary[:2] + - ary[2:3] + - ary[3:] + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D Python tuple, list or set. + Param used to determine the number and size of the subarray. + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + + Examples + -------- + >>> x = np.arange(9.0) + >>> np.array_split(x, 3) + [array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])] + + >>> np.array_split(x, [3, 5, 6, 8]) + [array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.]), array([])] + + >>> x = np.arange(8.0) + >>> np.array_split(x, 3) + [array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.])] + + >>> x = np.arange(7.0) + >>> np.array_split(x, 3) + [array([0., 1., 2.]), array([3., 4.]), array([5., 6.])] + """ + return _mx_nd_np.array_split(ary, indices_or_sections, axis=axis) + + @set_module('mxnet.numpy') def vsplit(ary, indices_or_sections): r""" diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 1a238ec2c7c7..23593a47e6ba 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -113,6 +113,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'reshape', 'roll', 'split', + 'array_split', 'squeeze', 'stack', 'std', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 4b06bbec7cae..c61d5b2d393d 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -40,9 +40,9 @@ 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', - 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', - 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'histogram', 'eye', - 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', + 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', + 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace', + 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', @@ -3116,6 +3116,54 @@ def split(ary, indices_or_sections, axis=0): # pylint: enable=redefined-outer-name +# pylint: disable=redefined-outer-name +@set_module('mxnet.symbol.numpy') +def array_split(ary, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays. + + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an array of length l that should be split into n sections, it returns + l % n sub-arrays of size l//n + 1 and the rest of size l//n. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + - ary[:2] + - ary[2:3] + - ary[3:] + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + + Parameters + ---------- + ary : _Symbol + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D Python tuple, list or set. + Param used to determine the number and size of the subarray. + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + """ + indices = [] + sections = 0 + if isinstance(indices_or_sections, int): + sections = indices_or_sections + elif isinstance(indices_or_sections, (list, set, tuple)): + indices = [0] + list(indices_or_sections) + else: + raise ValueError('indices_or_sections must either int or tuple / list / set of ints') + ret = _npi.split(ary, indices, axis, False, sections) + if not isinstance(ret, list): + return [ret] + return ret +# pylint: enable=redefined-outer-name + + # pylint: disable=redefined-outer-name @set_module('mxnet.symbol.numpy') def hsplit(ary, indices_or_sections): diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 4bd059ae81df..0c501808a6c0 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -2727,9 +2727,15 @@ struct SplitParam : public dmlc::Parameter { inline mxnet::TShape GetSplitIndices(const mxnet::TShape& ishape, int axis, int sections) { mxnet::TShape indices(sections+1, -1); indices[0] = 0; - int64_t section_size = ishape[axis] / sections; + int64_t section_size_b = (int64_t) (ishape[axis] / sections); + int64_t section_size_a = section_size_b + 1; + int section_a = ishape[axis] % sections; for (int i = 0; i < sections; ++i) { - indices[i+1] = section_size * (i + 1); + if ( i < section_a ) { + indices[i+1] = section_size_a * (i + 1); + } else { + indices[i+1] = section_size_b + indices[i]; + } } return indices; } diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 4c4e8b90eca9..0e875825a699 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -181,6 +181,18 @@ def _add_workload_split(): assertRaises(ValueError, np.split, np.arange(10), 3) +def _add_workload_array_split(): + a = np.arange(10) + b = np.array([np.arange(10), np.arange(10)]) + + for i in range(1, 12): + OpArgMngr.add_workload('array_split', a, i) + OpArgMngr.add_workload('array_split', b, 3, axis=0) + OpArgMngr.add_workload('array_split', b, [0, 1, 2], axis=0) + OpArgMngr.add_workload('array_split', b, 3, axis=-1) + OpArgMngr.add_workload('array_split', b, 3) + + def _add_workload_squeeze(): OpArgMngr.add_workload('squeeze', np.random.uniform(size=(4, 1))) OpArgMngr.add_workload('squeeze', np.random.uniform(size=(20, 10, 10, 1, 1))) @@ -1398,6 +1410,7 @@ def _prepare_workloads(): _add_workload_rint(array_pool) _add_workload_roll() _add_workload_split() + _add_workload_array_split() _add_workload_squeeze() _add_workload_stack(array_pool) _add_workload_std() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 078e37fc4146..54bce0aaae84 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2237,6 +2237,63 @@ def get_indices(axis_size): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_array_split(): + class TestArray_split(HybridBlock): + def __init__(self, indices_or_sections, axis=None): + super(TestArray_split, self).__init__() + self._axis = axis + self._indices_or_sections = indices_or_sections + + def hybrid_forward(self, F, a, *args, **kwargs): + return F.np.array_split(a, indices_or_sections=self._indices_or_sections, + axis=self._axis) + + def get_indices(axis_size): + if axis_size is 0: + axis_size = random.randint(3, 6) + samples = random.randint(1, axis_size - 1) + indices = sorted(random.sample([i for i in range(0, axis_size + 1)], samples)) + indices = tuple(indices) + return indices + + shapes = [(), (5, ), (10, ), + (2, 5), (5, 5), (10, 10), + (4, 4, 4), (4, 6, 9), (6, 6, 6), + (7, 8, 9, 10)] + dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64] + + combinations = itertools.product([False, True], shapes, dtypes) + for hybridize, shape, dtype in combinations: + rtol = 1e-2 if dtype == np.float16 else 1e-3 + atol = 1e-4 if dtype == np.float16 else 1e-5 + for axis in range(len(shape)): + x = np.random.uniform(-5.0, 5.0, size=shape).astype(dtype) + indices = get_indices(shape[axis]) + sections = 7 if x.shape[axis] is 0 else random.randint(1,x.shape[axis]) + for indices_or_sections in [indices, sections]: + # test gluon + test_array_split = TestArray_split(axis=axis, indices_or_sections=indices_or_sections) + if hybridize: + test_array_split.hybridize() + x.attach_grad() + expected_ret = _np.array_split(x.asnumpy(), indices_or_sections=indices_or_sections, axis=axis) + with mx.autograd.record(): + y = test_array_split(x) + assert len(y) == len(expected_ret) + for mx_out, np_out in zip(y, expected_ret): + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + mx.autograd.backward(y) + assert_almost_equal(x.grad.asnumpy(), _np.ones(x.shape), rtol=rtol, atol=atol) + + # test imperative + mx_outs = np.array_split(x, indices_or_sections=indices_or_sections, axis=axis) + np_outs = _np.array_split(x.asnumpy(), indices_or_sections=indices_or_sections, axis=axis) + for mx_out, np_out in zip(mx_outs, np_outs): + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + @with_seed() @use_np def test_np_vsplit():