Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

split_v2 operator #13687

Merged
merged 1 commit into from
Jan 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
"maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
"power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
"to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]
"split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]

_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
Expand Down Expand Up @@ -1133,6 +1133,14 @@ def split(self, *args, **kwargs):
"""
return op.split(self, *args, **kwargs)

def split_v2(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`split_v2`.

The arguments are the same as for :py:func:`split_v2`, with
this array as data.
"""
return split_v2(self, *args, **kwargs)
HyperZealot marked this conversation as resolved.
Show resolved Hide resolved

def slice(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`slice`.

Expand Down Expand Up @@ -3901,6 +3909,12 @@ def histogram(a, bins=10, range=None):
Values outside the range are ignored. The first element of the range must be less than or
equal to the second. range affects the automatic bin computation as well, the range will
be equally divided by the number of bins.

Returns
-------
NDArray
A created array.

"""

# pylint: disable= no-member, protected-access
Expand All @@ -3916,6 +3930,51 @@ def histogram(a, bins=10, range=None):
raise ValueError("bins argument should be either an integer or an NDArray")
# pylint: enable= no-member, protected-access, redefined-builtin

def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):
"""Split an array into multiple sub-arrays.

Parameters
----------
ary : NDArray
Array to be divided into sub-arrays.
indices_or_sections : int or tuple of ints
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
squeeze_axis: boolean, optional
Whether to squeeze the axis of sub-arrays or not, only useful when size
of the sub-arrays are 1 on the `axis`. Default is False.

Returns
-------
NDArray
A created array.

"""
indices = []
axis_size = ary.shape[axis]
if isinstance(indices_or_sections, int):
sections = indices_or_sections
if axis_size % sections:
raise ValueError('array split does not result in an equal division')
section_size = int(axis_size / sections)
indices = [i * section_size for i in range(sections)]
elif isinstance(indices_or_sections, tuple):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
return _internal._split_v2(ary, indices, axis, squeeze_axis)

PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')
Expand Down
55 changes: 54 additions & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
"histogram"]
"histogram", "split_v2"]


class Symbol(SymbolBase):
Expand Down Expand Up @@ -1855,6 +1855,14 @@ def split(self, *args, **kwargs):
"""
return op.split(self, *args, **kwargs)

def split_v2(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`split_v2`.
The arguments are the same as for :py:func:`split_v2`, with
this array as data.
"""
return split_v2(self, *args, **kwargs)

def slice(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`slice`.
Expand Down Expand Up @@ -2958,6 +2966,11 @@ def histogram(a, bins=10, range=None, **kwargs):
Values outside the range are ignored. The first element of the range must be less than or
equal to the second. range affects the automatic bin computation as well, the range will
be equally divided by the number of bins.
Returns
-------
out : Symbol
The created Symbol
"""
if isinstance(bins, Symbol):
return _internal._histogram(data=a, bins=bins, **kwargs)
Expand All @@ -2967,4 +2980,44 @@ def histogram(a, bins=10, range=None, **kwargs):
return _internal._histogram(data=a, bin_cnt=bins, range=range, **kwargs)
raise ValueError("bins argument should be either an integer or an NDArray")

def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : NDArray
Array to be divided into sub-arrays.
indices_or_sections : int or tuple of ints
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
squeeze_axis: boolean, optional
Whether to squeeze the axis of sub-arrays or not, only useful when size
of the sub-arrays are 1 on the `axis`. Default is False.
Returns
-------
out : Symbol
The created Symbol
"""
indices = []
sections = 0
if isinstance(indices_or_sections, int):
sections = indices_or_sections
elif isinstance(indices_or_sections, tuple):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
return _internal._split_v2(ary, indices, axis, squeeze_axis, sections)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the symbol implementation and ndarray implementation are different? the ndarray impl always computes indices and here it may pass the number of sections directly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You do not know the shape of a symbol so you cannot pre-compute the indices in symbol mode.


_set_symbol_class(Symbol)
Loading