Skip to content

Commit

Permalink
[Numpy] FFI for diag/diagonal/diag_indices_from (apache#17789)
Browse files Browse the repository at this point in the history
* ffi_diag/diagonal/diag_indices_from

* sanity && benchmark
  • Loading branch information
Tommliu authored and MoisesHer committed Apr 10, 2020
1 parent 93dc550 commit 4bc950b
Show file tree
Hide file tree
Showing 8 changed files with 349 additions and 34 deletions.
12 changes: 12 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def prepare_workloads():
OpArgMngr.add_workload("fmin", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("fmod", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1])
OpArgMngr.add_workload("diag", pool['2x2'], k=1)
OpArgMngr.add_workload("diagonal", pool['2x2x2'], offset=-1, axis1=0, axis2=1)
OpArgMngr.add_workload("diag_indices_from", pool['2x2'])
OpArgMngr.add_workload("bincount", dnp.arange(3, dtype=int), pool['3'], minlength=4)
OpArgMngr.add_workload("percentile", pool['2x2x2'], 80, axis=0, out=pool['2x2'],\
interpolation='midpoint')
OpArgMngr.add_workload("quantile", pool['2x2x2'], 0.8, axis=0, out=pool['2x2'],\
interpolation='midpoint')
OpArgMngr.add_workload("all", pool['2x2x2'], axis=(0, 1),\
out=dnp.array([False, False], dtype=bool), keepdims=False)
OpArgMngr.add_workload("any", pool['2x2x2'], axis=(0, 1),\
out=dnp.array([False, False], dtype=bool), keepdims=False)
OpArgMngr.add_workload("roll", pool["2x2"], 1, axis=0)
OpArgMngr.add_workload("rot90", pool["2x2"], 2)

Expand Down
106 changes: 99 additions & 7 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']
'where', 'bincount', 'pad', 'cumsum', 'diag', 'diagonal']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5130,6 +5130,7 @@ def ravel(x, order='C'):
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.ndarray.numpy')
def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer-name
"""
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Expand Down Expand Up @@ -5159,11 +5160,7 @@ def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer-
if order == 'C':
if isinstance(indices, numeric_types):
return _np.unravel_index(indices, shape)
ret = _npi.unravel_index_fallback(indices, shape=shape)
ret_list = []
for item in ret:
ret_list += [item]
return tuple(ret_list)
return tuple(_npi.unravel_index_fallback(indices, shape=shape))
else:
raise NotImplementedError('Do not support column-major (Fortran-style) order at this moment')

Expand Down Expand Up @@ -5207,6 +5204,7 @@ def flatnonzero(a):
return nonzero(ravel(a))[0]


@set_module('mxnet.ndarray.numpy')
def diag_indices_from(arr):
"""
This returns a tuple of indices that can be used to access the main diagonal of an array
Expand Down Expand Up @@ -5243,7 +5241,7 @@ def diag_indices_from(arr):
[ 8, 9, 100, 11],
[ 12, 13, 14, 100]])
"""
return tuple(_npi.diag_indices_from(arr))
return tuple(_api_internal.diag_indices_from(arr))


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -7941,3 +7939,97 @@ def cumsum(a, axis=None, dtype=None, out=None):
[ 4, 9, 15]])
"""
return _api_internal.cumsum(a, axis, dtype, out)


@set_module('mxnet.ndarray.numpy')
def diag(v, k=0):
"""
Extracts a diagonal or constructs a diagonal array.
- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero.
- 2-D arrays: extracts the k-th Diagonal
Parameters
----------
array : ndarray
The array to apply diag method.
k : offset
extracts or constructs kth diagonal given input array
Returns
----------
out : ndarray
The extracted diagonal or constructed diagonal array.
Examples
--------
>>> x = np.arange(9).reshape((3,3))
>>> x
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> np.diag(x)
array([0, 4, 8])
>>> np.diag(x, k=1)
array([1, 5])
>>> np.diag(x, k=-1)
array([3, 7])
>>> np.diag(np.diag(x))
array([[0, 0, 0],
[0, 4, 0],
[0, 0, 8]])
"""
return _api_internal.diag(v, k)


@set_module('mxnet.ndarray.numpy')
def diagonal(a, offset=0, axis1=0, axis2=1):
"""
If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of
the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and
axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the
resulting array can be determined by removing axis1 and axis2 and appending an index to the
right equal to the size of the resulting diagonals.
Parameters
----------
a : ndarray
Input data from which diagonal are taken.
offset: int, Optional
Offset of the diagonal from the main diagonal
axis1: int, Optional
Axis to be used as the first axis of the 2-D sub-arrays
axis2: int, Optional
Axis to be used as the second axis of the 2-D sub-arrays
Returns
-------
out : ndarray
Output result
Raises
-------
ValueError: If the dimension of a is less than 2.
Examples
--------
>>> a = np.arange(4).reshape(2,2)
>>> a
array([[0, 1],
[2, 3]])
>>> np.diagonal(a)
array([0, 3])
>>> np.diagonal(a, 1)
array([1])
>>> a = np.arange(8).reshape(2,2,2)
>>>a
array([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
>>> np.diagonal(a, 0, 0, 1)
array([[0, 6],
[1, 7]])
"""
return _api_internal.diagonal(a, offset, axis1, axis2)
96 changes: 95 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'pad', 'cumsum']
'pad', 'cumsum', 'diag', 'diagonal']

__all__ += fallback.__all__

Expand Down Expand Up @@ -10102,3 +10102,97 @@ def cumsum(a, axis=None, dtype=None, out=None):
"""
return _mx_nd_np.cumsum(a, axis=axis, dtype=dtype, out=out)
# pylint: enable=redefined-outer-name


@set_module('mxnet.numpy')
def diag(v, k=0):
"""
Extracts a diagonal or constructs a diagonal array.
- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero.
- 2-D arrays: extracts the k-th Diagonal
Parameters
----------
array : ndarray
The array to apply diag method.
k : offset
extracts or constructs kth diagonal given input array
Returns
----------
out : ndarray
The extracted diagonal or constructed diagonal array.
Examples
--------
>>> x = np.arange(9).reshape((3,3))
>>> x
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> np.diag(x)
array([0, 4, 8])
>>> np.diag(x, k=1)
array([1, 5])
>>> np.diag(x, k=-1)
array([3, 7])
>>> np.diag(np.diag(x))
array([[0, 0, 0],
[0, 4, 0],
[0, 0, 8]])
"""
return _mx_nd_np.diag(v, k=k)


@set_module('mxnet.numpy')
def diagonal(a, offset=0, axis1=0, axis2=1):
"""
If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of
the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and
axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the
resulting array can be determined by removing axis1 and axis2 and appending an index to the
right equal to the size of the resulting diagonals.
Parameters
----------
a : ndarray
Input data from which diagonal are taken.
offset: int, Optional
Offset of the diagonal from the main diagonal
axis1: int, Optional
Axis to be used as the first axis of the 2-D sub-arrays
axis2: int, Optional
Axis to be used as the second axis of the 2-D sub-arrays
Returns
-------
out : ndarray
Output result
Raises
-------
ValueError: If the dimension of a is less than 2.
Examples
--------
>>> a = np.arange(4).reshape(2,2)
>>> a
array([[0, 1],
[2, 3]])
>>> np.diagonal(a)
array([0, 3])
>>> np.diagonal(a, 1)
array([1])
>>> a = np.arange(8).reshape(2,2,2)
>>>a
array([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
>>> np.diagonal(a, 0, 0, 1)
array([[0, 6],
[1, 7]])
"""
return _mx_nd_np.diagonal(a, offset=offset, axis1=axis1, axis2=axis2)
56 changes: 55 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']
'where', 'bincount', 'pad', 'cumsum', 'diag', 'diagonal']


@set_module('mxnet.symbol.numpy')
Expand Down Expand Up @@ -6968,4 +6968,58 @@ def cumsum(a, axis=None, dtype=None, out=None):
return _npi.cumsum(a, axis=axis, dtype=dtype, out=out)


@set_module('mxnet.symbol.numpy')
def diag(v, k=0):
"""
Extracts a diagonal or constructs a diagonal array.
- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero.
- 2-D arrays: extracts the k-th Diagonal
Parameters
----------
array : _Symbol
The array to apply diag method.
k : offset
extracts or constructs kth diagonal given input array
Returns
----------
out : _Symbol
The extracted diagonal or constructed diagonal array.
"""
return _npi.diag(v, k=k)


@set_module('mxnet.symbol.numpy')
def diagonal(a, offset=0, axis1=0, axis2=1):
"""
If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of
the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and
axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the
resulting array can be determined by removing axis1 and axis2 and appending an index to the
right equal to the size of the resulting diagonals.
Parameters
----------
a : _Symbol
Input data from which diagonal are taken.
offset: int, Optional
Offset of the diagonal from the main diagonal
axis1: int, Optional
Axis to be used as the first axis of the 2-D sub-arrays
axis2: int, Optional
Axis to be used as the second axis of the 2-D sub-arrays
Returns
-------
out : _Symbol
Output result
Raises
-------
ValueError: If the dimension of a is less than 2.
"""
return _npi.diagonal(a, offset=offset, axis1=axis1, axis2=axis2)


_set_np_symbol_class(_Symbol)
49 changes: 49 additions & 0 deletions src/api/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,53 @@ MXNET_REGISTER_API("_npi.rot90")
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.diag")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_diag");
nnvm::NodeAttrs attrs;
op::NumpyDiagParam param;
param.k = args[1].operator int();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyDiagParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.diagonal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_diagonal");
nnvm::NodeAttrs attrs;
op::NumpyDiagonalParam param;
param.offset = args[1].operator int();
param.axis1 = args[2].operator int();
param.axis2 = args[3].operator int();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyDiagonalParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.diag_indices_from")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_diag_indices_from");
nnvm::NodeAttrs attrs;
attrs.op = op;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
Loading

0 comments on commit 4bc950b

Please sign in to comment.