Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

OP Array_split [Numpy] #17032

Merged
merged 1 commit into from
Dec 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 68 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'histogram', 'eye',
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip',
'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
Expand Down Expand Up @@ -3029,6 +3029,71 @@ def split(ary, indices_or_sections, axis=0):
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.ndarray.numpy')
def array_split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.

If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an array of length l that should be split into n sections, it returns
l % n sub-arrays of size l//n + 1 and the rest of size l//n.

If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.

Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D Python tuple, list or set.
Param used to determine the number and size of the subarray.
axis : int, optional
The axis along which to split, default is 0.

Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.

Examples
--------
>>> x = np.arange(9.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]

>>> np.array_split(x, [3, 5, 6, 8])
[array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.]), array([])]

>>> x = np.arange(8.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.])]

>>> x = np.arange(7.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4.]), array([5., 6.])]
"""
indices = []
sections = 0
if isinstance(indices_or_sections, integer_types):
sections = indices_or_sections
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must be either int, or tuple / list / set of ints')
ret = _npi.split(ary, indices, axis, False, sections)
if not isinstance(ret, list):
return [ret]
return ret
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.ndarray.numpy')
def hsplit(ary, indices_or_sections):
Expand Down
56 changes: 54 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10',
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign',
'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
Expand Down Expand Up @@ -4841,6 +4841,58 @@ def split(ary, indices_or_sections, axis=0):
return _mx_nd_np.split(ary, indices_or_sections, axis=axis)


@set_module('mxnet.numpy')
def array_split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.

If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an array of length l that should be split into n sections, it returns
l % n sub-arrays of size l//n + 1 and the rest of size l//n.

If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.

Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D Python tuple, list or set.
Param used to determine the number and size of the subarray.
axis : int, optional
The axis along which to split, default is 0.

Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.

Examples
--------
>>> x = np.arange(9.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]

>>> np.array_split(x, [3, 5, 6, 8])
[array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.]), array([])]

>>> x = np.arange(8.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.])]

>>> x = np.arange(7.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4.]), array([5., 6.])]
"""
return _mx_nd_np.array_split(ary, indices_or_sections, axis=axis)


@set_module('mxnet.numpy')
def vsplit(ary, indices_or_sections):
r"""
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'reshape',
'roll',
'split',
'array_split',
'squeeze',
'stack',
'std',
Expand Down
54 changes: 51 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'histogram', 'eye',
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip',
'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
Expand Down Expand Up @@ -3116,6 +3116,54 @@ def split(ary, indices_or_sections, axis=0):
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.symbol.numpy')
def array_split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.

If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an array of length l that should be split into n sections, it returns
l % n sub-arrays of size l//n + 1 and the rest of size l//n.

If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.

Parameters
----------
ary : _Symbol
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D Python tuple, list or set.
Param used to determine the number and size of the subarray.
axis : int, optional
The axis along which to split, default is 0.

Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
"""
indices = []
sections = 0
if isinstance(indices_or_sections, int):
sections = indices_or_sections
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple / list / set of ints')
ret = _npi.split(ary, indices, axis, False, sections)
if not isinstance(ret, list):
return [ret]
return ret
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.symbol.numpy')
def hsplit(ary, indices_or_sections):
Expand Down
10 changes: 8 additions & 2 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2727,9 +2727,15 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
inline mxnet::TShape GetSplitIndices(const mxnet::TShape& ishape, int axis, int sections) {
mxnet::TShape indices(sections+1, -1);
indices[0] = 0;
int64_t section_size = ishape[axis] / sections;
int64_t section_size_b = (int64_t) (ishape[axis] / sections);
int64_t section_size_a = section_size_b + 1;
int section_a = ishape[axis] % sections;
for (int i = 0; i < sections; ++i) {
indices[i+1] = section_size * (i + 1);
if ( i < section_a ) {
indices[i+1] = section_size_a * (i + 1);
} else {
indices[i+1] = section_size_b + indices[i];
}
}
return indices;
}
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,18 @@ def _add_workload_split():
assertRaises(ValueError, np.split, np.arange(10), 3)


def _add_workload_array_split():
a = np.arange(10)
b = np.array([np.arange(10), np.arange(10)])

for i in range(1, 12):
OpArgMngr.add_workload('array_split', a, i)
OpArgMngr.add_workload('array_split', b, 3, axis=0)
OpArgMngr.add_workload('array_split', b, [0, 1, 2], axis=0)
OpArgMngr.add_workload('array_split', b, 3, axis=-1)
OpArgMngr.add_workload('array_split', b, 3)


def _add_workload_squeeze():
OpArgMngr.add_workload('squeeze', np.random.uniform(size=(4, 1)))
OpArgMngr.add_workload('squeeze', np.random.uniform(size=(20, 10, 10, 1, 1)))
Expand Down Expand Up @@ -1398,6 +1410,7 @@ def _prepare_workloads():
_add_workload_rint(array_pool)
_add_workload_roll()
_add_workload_split()
_add_workload_array_split()
_add_workload_squeeze()
_add_workload_stack(array_pool)
_add_workload_std()
Expand Down
57 changes: 57 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2237,6 +2237,63 @@ def get_indices(axis_size):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_array_split():
class TestArray_split(HybridBlock):
def __init__(self, indices_or_sections, axis=None):
super(TestArray_split, self).__init__()
self._axis = axis
self._indices_or_sections = indices_or_sections

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.array_split(a, indices_or_sections=self._indices_or_sections,
axis=self._axis)

def get_indices(axis_size):
if axis_size is 0:
axis_size = random.randint(3, 6)
samples = random.randint(1, axis_size - 1)
indices = sorted(random.sample([i for i in range(0, axis_size + 1)], samples))
indices = tuple(indices)
return indices

shapes = [(), (5, ), (10, ),
(2, 5), (5, 5), (10, 10),
(4, 4, 4), (4, 6, 9), (6, 6, 6),
(7, 8, 9, 10)]
dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]

combinations = itertools.product([False, True], shapes, dtypes)
for hybridize, shape, dtype in combinations:
rtol = 1e-2 if dtype == np.float16 else 1e-3
atol = 1e-4 if dtype == np.float16 else 1e-5
for axis in range(len(shape)):
x = np.random.uniform(-5.0, 5.0, size=shape).astype(dtype)
indices = get_indices(shape[axis])
sections = 7 if x.shape[axis] is 0 else random.randint(1,x.shape[axis])
for indices_or_sections in [indices, sections]:
# test gluon
test_array_split = TestArray_split(axis=axis, indices_or_sections=indices_or_sections)
if hybridize:
test_array_split.hybridize()
x.attach_grad()
expected_ret = _np.array_split(x.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
with mx.autograd.record():
y = test_array_split(x)
assert len(y) == len(expected_ret)
for mx_out, np_out in zip(y, expected_ret):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
mx.autograd.backward(y)
assert_almost_equal(x.grad.asnumpy(), _np.ones(x.shape), rtol=rtol, atol=atol)

# test imperative
mx_outs = np.array_split(x, indices_or_sections=indices_or_sections, axis=axis)
np_outs = _np.array_split(x.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
for mx_out, np_out in zip(mx_outs, np_outs):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_vsplit():
Expand Down