Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
use identity_with_cast (#16913)
Browse files Browse the repository at this point in the history
change the doc
move shape and dtype inference func to .cc file
fix format
fix bug in test
fix bug in MXNET_LAPACK_FSIG_GESV
fix format
fix undefined #gesv
  • Loading branch information
DwwWxx authored and haojin2 committed Dec 9, 2019
1 parent 7736bfd commit 71b6272
Show file tree
Hide file tree
Showing 11 changed files with 980 additions and 5 deletions.
56 changes: 55 additions & 1 deletion python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import _op as _mx_nd_np
from . import _internal as _npi

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -352,3 +352,57 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _npi.slogdet(a)


def solve(a, b):
r"""
Solve a linear matrix equation, or system of linear scalar equations.
Computes the "exact" solution, `x`, of the well-determined, i.e., full
rank, linear matrix equation `ax = b`.
Parameters
----------
a : (..., M, M) ndarray
Coefficient matrix.
b : {(..., M,), (..., M, K)}, ndarray
Ordinate or "dependent variable" values.
Returns
-------
x : {(..., M,), (..., M, K)} ndarray
Solution to the system a x = b. Returned shape is identical to `b`.
Raises
------
MXNetError
If `a` is singular or not square.
Notes
-----
Broadcasting rules apply, see the `numpy.linalg` documentation for
details.
The solutions are computed using LAPACK routine ``_gesv``.
`a` must be square and of full-rank, i.e., all rows (or, equivalently,
columns) must be linearly independent; if either is not true, use
`lstsq` for the least-squares best "solution" of the
system/equation.
Examples
--------
Solve the system of equations ``3 * x0 + x1 = 9`` and ``x0 + 2 * x1 = 8``:
>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([2., 3.])
Check that the solution is correct:
>>> np.allclose(np.dot(a, x), b)
True
"""
return _npi.solve(a, b)
56 changes: 55 additions & 1 deletion python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import absolute_import
from ..ndarray import numpy as _mx_nd_np

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -370,3 +370,57 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _mx_nd_np.linalg.slogdet(a)


def solve(a, b):
r"""
Solve a linear matrix equation, or system of linear scalar equations.
Computes the "exact" solution, `x`, of the well-determined, i.e., full
rank, linear matrix equation `ax = b`.
Parameters
----------
a : (..., M, M) ndarray
Coefficient matrix.
b : {(..., M,), (..., M, K)}, ndarray
Ordinate or "dependent variable" values.
Returns
-------
x : {(..., M,), (..., M, K)} ndarray
Solution to the system a x = b. Returned shape is identical to `b`.
Raises
------
MXNetError
If `a` is singular or not square.
Notes
-----
Broadcasting rules apply, see the `numpy.linalg` documentation for
details.
The solutions are computed using LAPACK routine ``_gesv``.
`a` must be square and of full-rank, i.e., all rows (or, equivalently,
columns) must be linearly independent; if either is not true, use
`lstsq` for the least-squares best "solution" of the
system/equation.
Examples
--------
Solve the system of equations ``3 * x0 + x1 = 9`` and ``x0 + 2 * x1 = 8``:
>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([2., 3.])
Check that the solution is correct:
>>> np.allclose(np.dot(a, x), b)
True
"""
return _mx_nd_np.linalg.solve(a, b)
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'linalg.norm',
'linalg.cholesky',
'linalg.inv',
'linalg.solve',
'shape',
'trace',
'tril',
Expand Down
55 changes: 54 additions & 1 deletion python/mxnet/symbol/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from . import _op as _mx_sym_np
from . import _internal as _npi

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet']
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve']


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -339,3 +339,56 @@ def slogdet(a):
(1., -1151.2925464970228)
"""
return _npi.slogdet(a)

def solve(a, b):
r"""
Solve a linear matrix equation, or system of linear scalar equations.
Computes the "exact" solution, `x`, of the well-determined, i.e., full
rank, linear matrix equation `ax = b`.
Parameters
----------
a : (..., M, M) ndarray
Coefficient matrix.
b : {(..., M,), (..., M, K)}, ndarray
Ordinate or "dependent variable" values.
Returns
-------
x : {(..., M,), (..., M, K)} ndarray
Solution to the system a x = b. Returned shape is identical to `b`.
Raises
------
MXNetError
If `a` is singular or not square.
Notes
-----
Broadcasting rules apply, see the `numpy.linalg` documentation for
details.
The solutions are computed using LAPACK routine ``_gesv``.
`a` must be square and of full-rank, i.e., all rows (or, equivalently,
columns) must be linearly independent; if either is not true, use
`lstsq` for the least-squares best "solution" of the
system/equation.
Examples
--------
Solve the system of equations ``3 * x0 + x1 = 9`` and ``x0 + 2 * x1 = 8``:
>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([2., 3.])
Check that the solution is correct:
>>> np.allclose(np.dot(a, x), b)
True
"""
return _npi.solve(a, b)
10 changes: 10 additions & 0 deletions src/operator/c_lapack_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@
return 1; \
}

#define MXNET_LAPACK_CWRAPPER7(func, dtype) \
int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \
int lda, int *ipiv, dtype *b, int ldb) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
Expand Down Expand Up @@ -101,4 +108,7 @@
MXNET_LAPACK_CWRAPPER6(sgesvd, float)
MXNET_LAPACK_CWRAPPER6(dgesvd, double)

MXNET_LAPACK_CWRAPPER7(sgesv, float)
MXNET_LAPACK_CWRAPPER7(dgesv, double)

#endif // MSHADOW_USE_MKL == 0
39 changes: 37 additions & 2 deletions src/operator/c_lapack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,19 @@ extern "C" {

MXNET_LAPACK_FSIG_GETRI(sgetri, float)
MXNET_LAPACK_FSIG_GETRI(dgetri, double)

#ifdef __ANDROID__
#define MXNET_LAPACK_FSIG_GESV(func, dtype) \
int func##_(int *n, int *nrhs, dtype *a, int *lda, \
int *ipiv, dtype *b, int *ldb, int *info);
#else
#define MXNET_LAPACK_FSIG_GESV(func, dtype) \
void func##_(int *n, int *nrhs, dtype *a, int *lda, \
int *ipiv, dtype *b, int *ldb, int *info);
#endif

MXNET_LAPACK_FSIG_GESV(sgesv, float)
MXNET_LAPACK_FSIG_GESV(dgesv, double)
}

#endif // MSHADOW_USE_MKL == 0
Expand Down Expand Up @@ -197,6 +210,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
#define MXNET_LAPACK_dpotri LAPACKE_dpotri
#define mxnet_lapack_sposv LAPACKE_sposv
#define mxnet_lapack_dposv LAPACKE_dposv
#define MXNET_LAPACK_dgesv LAPACKE_dgesv
#define MXNET_LAPACK_sgesv LAPACKE_sgesv

// The following functions differ in signature from the
// MXNET_LAPACK-signature and have to be wrapped.
Expand Down Expand Up @@ -440,9 +455,23 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_GETRI(s, float)
MXNET_LAPACK_CWRAP_GETRI(d, double)

#else

#define MXNET_LAPACK_CWRAP_GESV(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gesv(int matrix_layout, \
int n, int nrhs, dtype *a, int lda, \
int *ipiv, dtype *b, int ldb) { \
if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
CHECK(false) << "MXNET_LAPACK_" << #prefix << "gesv implemented for col-major layout only"; \
return 1; \
} else { \
int info(0); \
prefix##gesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info); \
return info; \
} \
}
MXNET_LAPACK_CWRAP_GESV(s, float)
MXNET_LAPACK_CWRAP_GESV(d, double)

#else

#define MXNET_LAPACK_ROW_MAJOR 101
#define MXNET_LAPACK_COL_MAJOR 102
Expand Down Expand Up @@ -473,6 +502,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
int ldut, dtype* s, dtype* v, int ldv, \
dtype* work, int lwork);

#define MXNET_LAPACK_CWRAPPER7(func, dtype) \
int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \
int lda, int *ipiv, dtype *b, int ldb); \

#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...);
Expand Down Expand Up @@ -501,6 +533,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAPPER6(sgesvd, float)
MXNET_LAPACK_CWRAPPER6(dgesvd, double)

MXNET_LAPACK_CWRAPPER7(sgesv, float)
MXNET_LAPACK_CWRAPPER7(dgesv, double)

#undef MXNET_LAPACK_CWRAPPER1
#undef MXNET_LAPACK_CWRAPPER2
#undef MXNET_LAPACK_CWRAPPER3
Expand Down
Loading

0 comments on commit 71b6272

Please sign in to comment.