From 71b62725a0af8ffa74a76c47589b2473d278db30 Mon Sep 17 00:00:00 2001 From: dw_sjtu <46704444+sjtuWangDing@users.noreply.github.com> Date: Mon, 9 Dec 2019 15:54:21 +0800 Subject: [PATCH] use identity_with_cast (#16913) change the doc move shape and dtype inference func to .cc file fix format fix bug in test fix bug in MXNET_LAPACK_FSIG_GESV fix format fix undefined #gesv --- python/mxnet/ndarray/numpy/linalg.py | 56 +- python/mxnet/numpy/linalg.py | 56 +- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/linalg.py | 55 +- src/operator/c_lapack_api.cc | 10 + src/operator/c_lapack_api.h | 39 +- src/operator/numpy/linalg/np_solve-inl.h | 496 ++++++++++++++++++ src/operator/numpy/linalg/np_solve.cc | 116 ++++ src/operator/numpy/linalg/np_solve.cu | 43 ++ .../unittest/test_numpy_interoperability.py | 22 + tests/python/unittest/test_numpy_op.py | 91 ++++ 11 files changed, 980 insertions(+), 5 deletions(-) create mode 100644 src/operator/numpy/linalg/np_solve-inl.h create mode 100644 src/operator/numpy/linalg/np_solve.cc create mode 100644 src/operator/numpy/linalg/np_solve.cu diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 74ba41f22979..a85c6324f685 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -21,7 +21,7 @@ from . import _op as _mx_nd_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve'] def norm(x, ord=None, axis=None, keepdims=False): @@ -352,3 +352,57 @@ def slogdet(a): (1., -1151.2925464970228) """ return _npi.slogdet(a) + + +def solve(a, b): + r""" + Solve a linear matrix equation, or system of linear scalar equations. + + Computes the "exact" solution, `x`, of the well-determined, i.e., full + rank, linear matrix equation `ax = b`. + + Parameters + ---------- + a : (..., M, M) ndarray + Coefficient matrix. + b : {(..., M,), (..., M, K)}, ndarray + Ordinate or "dependent variable" values. + + Returns + ------- + x : {(..., M,), (..., M, K)} ndarray + Solution to the system a x = b. Returned shape is identical to `b`. + + Raises + ------ + MXNetError + If `a` is singular or not square. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + The solutions are computed using LAPACK routine ``_gesv``. + + `a` must be square and of full-rank, i.e., all rows (or, equivalently, + columns) must be linearly independent; if either is not true, use + `lstsq` for the least-squares best "solution" of the + system/equation. + + Examples + -------- + Solve the system of equations ``3 * x0 + x1 = 9`` and ``x0 + 2 * x1 = 8``: + + >>> a = np.array([[3,1], [1,2]]) + >>> b = np.array([9,8]) + >>> x = np.linalg.solve(a, b) + >>> x + array([2., 3.]) + + Check that the solution is correct: + + >>> np.allclose(np.dot(a, x), b) + True + """ + return _npi.solve(a, b) diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index fbe3631eb6e6..33d636b7044c 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -20,7 +20,7 @@ from __future__ import absolute_import from ..ndarray import numpy as _mx_nd_np -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve'] def norm(x, ord=None, axis=None, keepdims=False): @@ -370,3 +370,57 @@ def slogdet(a): (1., -1151.2925464970228) """ return _mx_nd_np.linalg.slogdet(a) + + +def solve(a, b): + r""" + Solve a linear matrix equation, or system of linear scalar equations. + + Computes the "exact" solution, `x`, of the well-determined, i.e., full + rank, linear matrix equation `ax = b`. + + Parameters + ---------- + a : (..., M, M) ndarray + Coefficient matrix. + b : {(..., M,), (..., M, K)}, ndarray + Ordinate or "dependent variable" values. + + Returns + ------- + x : {(..., M,), (..., M, K)} ndarray + Solution to the system a x = b. Returned shape is identical to `b`. + + Raises + ------ + MXNetError + If `a` is singular or not square. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + The solutions are computed using LAPACK routine ``_gesv``. + + `a` must be square and of full-rank, i.e., all rows (or, equivalently, + columns) must be linearly independent; if either is not true, use + `lstsq` for the least-squares best "solution" of the + system/equation. + + Examples + -------- + Solve the system of equations ``3 * x0 + x1 = 9`` and ``x0 + 2 * x1 = 8``: + + >>> a = np.array([[3,1], [1,2]]) + >>> b = np.array([9,8]) + >>> x = np.linalg.solve(a, b) + >>> x + array([2., 3.]) + + Check that the solution is correct: + + >>> np.allclose(np.dot(a, x), b) + True + """ + return _mx_nd_np.linalg.solve(a, b) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index a6bceb51cd01..c8b11d85b000 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -131,6 +131,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'linalg.norm', 'linalg.cholesky', 'linalg.inv', + 'linalg.solve', 'shape', 'trace', 'tril', diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index cf33777b2637..1aaf4b990e31 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -22,7 +22,7 @@ from . import _op as _mx_sym_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve'] def norm(x, ord=None, axis=None, keepdims=False): @@ -339,3 +339,56 @@ def slogdet(a): (1., -1151.2925464970228) """ return _npi.slogdet(a) + +def solve(a, b): + r""" + Solve a linear matrix equation, or system of linear scalar equations. + + Computes the "exact" solution, `x`, of the well-determined, i.e., full + rank, linear matrix equation `ax = b`. + + Parameters + ---------- + a : (..., M, M) ndarray + Coefficient matrix. + b : {(..., M,), (..., M, K)}, ndarray + Ordinate or "dependent variable" values. + + Returns + ------- + x : {(..., M,), (..., M, K)} ndarray + Solution to the system a x = b. Returned shape is identical to `b`. + + Raises + ------ + MXNetError + If `a` is singular or not square. + + Notes + ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + The solutions are computed using LAPACK routine ``_gesv``. + + `a` must be square and of full-rank, i.e., all rows (or, equivalently, + columns) must be linearly independent; if either is not true, use + `lstsq` for the least-squares best "solution" of the + system/equation. + + Examples + -------- + Solve the system of equations ``3 * x0 + x1 = 9`` and ``x0 + 2 * x1 = 8``: + + >>> a = np.array([[3,1], [1,2]]) + >>> b = np.array([9,8]) + >>> x = np.linalg.solve(a, b) + >>> x + array([2., 3.]) + + Check that the solution is correct: + + >>> np.allclose(np.dot(a, x), b) + True + """ + return _npi.solve(a, b) diff --git a/src/operator/c_lapack_api.cc b/src/operator/c_lapack_api.cc index e7a97848700d..73b6138df5ea 100644 --- a/src/operator/c_lapack_api.cc +++ b/src/operator/c_lapack_api.cc @@ -71,6 +71,13 @@ return 1; \ } + #define MXNET_LAPACK_CWRAPPER7(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \ + int lda, int *ipiv, dtype *b, int ldb) { \ + LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ + return 1; \ + } + #define MXNET_LAPACK_UNAVAILABLE(func) \ int mxnet_lapack_##func(...) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ @@ -101,4 +108,7 @@ MXNET_LAPACK_CWRAPPER6(sgesvd, float) MXNET_LAPACK_CWRAPPER6(dgesvd, double) + MXNET_LAPACK_CWRAPPER7(sgesv, float) + MXNET_LAPACK_CWRAPPER7(dgesv, double) + #endif // MSHADOW_USE_MKL == 0 diff --git a/src/operator/c_lapack_api.h b/src/operator/c_lapack_api.h index a47bbd0b5857..8a7cbc067feb 100644 --- a/src/operator/c_lapack_api.h +++ b/src/operator/c_lapack_api.h @@ -150,6 +150,19 @@ extern "C" { MXNET_LAPACK_FSIG_GETRI(sgetri, float) MXNET_LAPACK_FSIG_GETRI(dgetri, double) + + #ifdef __ANDROID__ + #define MXNET_LAPACK_FSIG_GESV(func, dtype) \ + int func##_(int *n, int *nrhs, dtype *a, int *lda, \ + int *ipiv, dtype *b, int *ldb, int *info); + #else + #define MXNET_LAPACK_FSIG_GESV(func, dtype) \ + void func##_(int *n, int *nrhs, dtype *a, int *lda, \ + int *ipiv, dtype *b, int *ldb, int *info); + #endif + + MXNET_LAPACK_FSIG_GESV(sgesv, float) + MXNET_LAPACK_FSIG_GESV(dgesv, double) } #endif // MSHADOW_USE_MKL == 0 @@ -197,6 +210,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { #define MXNET_LAPACK_dpotri LAPACKE_dpotri #define mxnet_lapack_sposv LAPACKE_sposv #define mxnet_lapack_dposv LAPACKE_dposv + #define MXNET_LAPACK_dgesv LAPACKE_dgesv + #define MXNET_LAPACK_sgesv LAPACKE_sgesv // The following functions differ in signature from the // MXNET_LAPACK-signature and have to be wrapped. @@ -440,9 +455,23 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { MXNET_LAPACK_CWRAP_GETRI(s, float) MXNET_LAPACK_CWRAP_GETRI(d, double) -#else - + #define MXNET_LAPACK_CWRAP_GESV(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gesv(int matrix_layout, \ + int n, int nrhs, dtype *a, int lda, \ + int *ipiv, dtype *b, int ldb) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ + CHECK(false) << "MXNET_LAPACK_" << #prefix << "gesv implemented for col-major layout only"; \ + return 1; \ + } else { \ + int info(0); \ + prefix##gesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info); \ + return info; \ + } \ + } + MXNET_LAPACK_CWRAP_GESV(s, float) + MXNET_LAPACK_CWRAP_GESV(d, double) +#else #define MXNET_LAPACK_ROW_MAJOR 101 #define MXNET_LAPACK_COL_MAJOR 102 @@ -473,6 +502,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { int ldut, dtype* s, dtype* v, int ldv, \ dtype* work, int lwork); + #define MXNET_LAPACK_CWRAPPER7(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \ + int lda, int *ipiv, dtype *b, int ldb); \ #define MXNET_LAPACK_UNAVAILABLE(func) \ int mxnet_lapack_##func(...); @@ -501,6 +533,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { MXNET_LAPACK_CWRAPPER6(sgesvd, float) MXNET_LAPACK_CWRAPPER6(dgesvd, double) + MXNET_LAPACK_CWRAPPER7(sgesv, float) + MXNET_LAPACK_CWRAPPER7(dgesv, double) + #undef MXNET_LAPACK_CWRAPPER1 #undef MXNET_LAPACK_CWRAPPER2 #undef MXNET_LAPACK_CWRAPPER3 diff --git a/src/operator/numpy/linalg/np_solve-inl.h b/src/operator/numpy/linalg/np_solve-inl.h new file mode 100644 index 000000000000..03134f8b5688 --- /dev/null +++ b/src/operator/numpy/linalg/np_solve-inl.h @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_solve-inl.h + * \brief Placeholder for solve linear equation + */ +#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_SOLVE_INL_H_ +#define MXNET_OPERATOR_NUMPY_LINALG_NP_SOLVE_INL_H_ + +#include +#include +#include "../../tensor/la_op.h" +#include "../../tensor/la_op-inl.h" +#include "../../linalg.h" +#include "../../operator_common.h" +#include "../../mshadow_op.h" + +namespace mxnet { +namespace op { + +using namespace mshadow; + +template +void linalg_solve(const Tensor& A, + const Tensor& X, + const Tensor& ipiv, + Stream *s); + +template +void linalg_batch_solve(const Tensor& A, + const Tensor& X, + const Tensor& ipiv, + const mxnet::OpContext& ctx); + +template inline +int linalg_dn_getrf_workspace_query(const Tensor& A, + Stream *s); + +template inline +void linalg_dn_getrf(const Tensor& A, + const Tensor& ipiv, + Stream *s); + +template inline +void linalg_dn_getrs(const Tensor& A, + const Tensor& X, + const Tensor& ipiv, + Stream *s); + +// kernel for transpose +struct SolveTypeTransposeHelper { + template + MSHADOW_XINLINE static void Map(int i, const InDType *in_data, OutDType *out_data, + const int ncol1, const int ncol2, const int step) { + int idx = i / step, row = (i % step) / ncol1, col = (i % step) % ncol1; + out_data[idx * step + row + col * ncol2] = static_cast(in_data[i]); + } +}; + +template +inline void check_solve(const Tensor& A, + const Tensor& B) { + CHECK_EQ(A.size(0), A.size(1)) << "A must bu square matrix"; + CHECK_EQ(A.size(1), B.size(1)) << "A, B have incompatible sizes"; +} + +#define LINALG_CPU_SOLVE(fname, DType) \ +template<> inline \ +void linalg_solve(const Tensor& A, \ + const Tensor& X, \ + const Tensor& ipiv, \ + Stream *s) { \ + check_solve(A, X); \ + const int N = X.size(1), nrhs = X.size(0); \ + const int lda = (N == 0 ? 1 : N), ldx = (N == 0 ? 1 : N); \ + int res(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, N, nrhs, \ + A.dptr_, lda, ipiv.dptr_, X.dptr_, ldx)); \ + CHECK_LE(res, 0) << #fname << ": U(" << res << ", " << res \ + << ") is exactly zero. The factorization has been completed," \ + << "but the factor U is exactly singular, so the solution could not be computed."; \ + CHECK_GE(res, 0) << #fname << ": the " << -res \ + << "-th argument had an illegal value"; \ +} +LINALG_CPU_SOLVE(sgesv, float) +LINALG_CPU_SOLVE(dgesv, double) + +#ifdef __CUDACC__ + +#if CUDA_VERSION >= 8000 + +#define LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(fname, DType) \ +template<> inline \ +int linalg_dn_getrf_workspace_query(const Tensor& A, \ + Stream *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + int lwork(0); \ + CUSOLVER_CALL(cusolver##fname##_bufferSize(Stream::GetSolverHandle(s), \ + A.size(1), A.size(1), A.dptr_, \ + (A.size(1) == 0 ? 1 : A.size(1)), &lwork)); \ + return lwork; \ +} + +#define LINALG_GPU_DN_GETRF(fname, DType) \ +template<> inline \ +void linalg_dn_getrf(const Tensor& A, \ + const Tensor& ipiv, \ + Stream *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \ + const int lwork = linalg_dn_getrf_workspace_query(A, s); \ + Storage::Handle workspace = Storage::Get()->Alloc(sizeof(DType) * lwork, Context::GPU()); \ + CUSOLVER_CALL(cusolver##fname(Stream::GetSolverHandle(s), \ + A.size(1), A.size(1), A.dptr_, (A.size(1) == 0 ? 1 : A.size(1)), \ + static_cast(workspace.dptr), ipiv.dptr_, \ + static_cast(info.dptr))); \ + Storage::Get()->Free(info); \ + Storage::Get()->Free(workspace); \ +} + +#define LINALG_GPU_DN_GETRS(fname, DType) \ +template<> inline \ +void linalg_dn_getrs(const Tensor& A, \ + const Tensor& X, \ + const Tensor& ipiv, \ + Stream *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + const int N = A.size(0), nrhs = X.size(0); \ + const int lda = (A.size(1) == 0 ? 1 : A.size(1)), ldx = (X.size(1) == 0 ? 1 : X.size(1)); \ + Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \ + CUSOLVER_CALL(cusolver##fname(Stream::GetSolverHandle(s), \ + CUBLAS_OP_N, N, nrhs, \ + A.dptr_, lda, ipiv.dptr_, X.dptr_, ldx, \ + static_cast(info.dptr))); \ + Storage::Get()->Free(info); \ +} + +#define LINALG_GPU_SOLVE(DType) \ +template<> inline \ +void linalg_solve(const Tensor& A, \ + const Tensor& X, \ + const Tensor& ipiv, \ + Stream *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + check_solve(A, X); \ + linalg_dn_getrf(A, ipiv, s); \ + linalg_dn_getrs(A, X, ipiv, s); \ +} + +#else // CUDA_VERSION >= 8000 + +#define LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(fname, DType) \ +template<> inline \ +int linalg_dn_getrf_workspace_query(const Tensor& A, \ + Stream *s) { \ + LOG(FATAL) << "Dn_getrf_workspace_query requires CUDA version >= 8.0!"; \ +} + +#define LINALG_GPU_DN_GETRF(fname, DType) \ +template<> inline \ +void linalg_dn_getrf(const Tensor& A, \ + const Tensor& ipiv, \ + Stream *s) { \ + LOG(FATAL) << "Dn_getrf requires CUDA version >= 8.0!"; \ +} + +#define LINALG_GPU_DN_GETRS(fname, DType) \ +template<> inline \ +void linalg_dn_getrs(const Tensor& A, \ + const Tensor& X, \ + const Tensor& ipiv, \ + Stream *s) { \ + LOG(FATAL) << "Dn_getrs requires CUDA version >= 8.0!"; \ +} + +#define LINALG_GPU_SOLVE(DType) \ +template<> inline \ +void linalg_solve(const Tensor& A, \ + const Tensor& X, \ + const Tensor& ipiv, \ + Stream *s) { \ + LOG(FATAL) << "gpu solve requires CUDA version >= 8.0!"; \ +} + +#endif // CUDA_VERSION >= 8000 + +LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(DnSgetrf, float) +LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(DnDgetrf, double) + +LINALG_GPU_DN_GETRF(DnSgetrf, float) +LINALG_GPU_DN_GETRF(DnDgetrf, double) + +LINALG_GPU_DN_GETRS(DnSgetrs, float) +LINALG_GPU_DN_GETRS(DnDgetrs, double) + +LINALG_GPU_SOLVE(float) +LINALG_GPU_SOLVE(double) + +#endif // __CUDACC__ + +#define LINALG_XPU_BATCH_SOLVE(xpu, DType) \ +template<> inline \ +void linalg_batch_solve(const Tensor& A, \ + const Tensor& X, \ + const Tensor& ipiv, \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + for (index_t i = 0; i < A.size(0); ++i) { \ + linalg_solve(A[i], X[i], ipiv[i], s); \ + } \ +} +LINALG_XPU_BATCH_SOLVE(cpu, float) +LINALG_XPU_BATCH_SOLVE(cpu, double) + +#ifdef __CUDACC__ + +LINALG_XPU_BATCH_SOLVE(gpu, float) +LINALG_XPU_BATCH_SOLVE(gpu, double) + +#endif // __CUDACC__ + +struct solve { + template + static void op(const Tensor& A, + const Tensor& X, + const Tensor& ipiv, + const OpContext& ctx, + const nnvm::NodeAttrs& attrs) { + linalg_batch_solve(A, X, ipiv, ctx); // ipiv for work_space in Lapacke_#gesv + } +}; + +template +void LaOpForwardSolve(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), inum); + CHECK_EQ(outputs.size(), onum); + CHECK_EQ(req.size(), onum); + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + mshadow::Stream *s = ctx.get_stream(); + const mxnet::TBlob& a_tblob = inputs[0]; + const mxnet::TBlob& b_tblob = inputs[1]; + const mxnet::TBlob& x_tblob = outputs[0]; + const mxnet::TShape& a_shape = a_tblob.shape_; + mxnet::TShape b_shape(a_shape.ndim(), 1); + for (int i = 0; i < a_shape.ndim() - 1; ++i) { b_shape[i] = b_tblob.shape_[i]; } + if (b_tblob.shape_.ndim() == a_shape.ndim()) { + b_shape[a_shape.ndim() - 1] = b_tblob.shape_[a_shape.ndim() - 1]; + } + const int ndim = a_shape.ndim(); + mxnet::TShape ipiv_shape(a_shape); + ipiv_shape[ndim - 1] = 1; + if (0 == a_shape[ndim - 1] || 0 == a_shape[ndim - 2] || + 0 == b_shape[ndim - 1] || 0 == b_shape[ndim - 2]) { return; } + + const int work_space_size = + sizeof(OType) * (a_shape.Size() + b_shape.Size()) + sizeof(int) * ipiv_shape.Size(); + Tensor work_buffer = + ctx.requested[0].get_space_typed(Shape1(work_space_size), s); + MSHADOW_TYPE_SWITCH(a_tblob.type_flag_, AType, { + // cast type and transpose + mxnet_op::Kernel::Launch( + s, a_shape.Size(), + a_tblob.dptr(), + reinterpret_cast(work_buffer.dptr_), + a_shape[ndim - 1], a_shape[ndim - 2], a_shape[ndim - 1] * a_shape[ndim - 2]); + }); + MSHADOW_TYPE_SWITCH(b_tblob.type_flag_, BType, { + // cast type and transpose + mxnet_op::Kernel::Launch( + s, b_shape.Size(), + b_tblob.dptr(), + reinterpret_cast(work_buffer.dptr_) + a_shape.Size(), + b_shape[ndim - 1], b_shape[ndim - 2], b_shape[ndim - 1] * b_shape[ndim - 2]); + }); + // transpose shape + int temp = b_shape[ndim - 1]; + b_shape[ndim - 1] = b_shape[ndim - 2]; + b_shape[ndim - 2] = temp; + mxnet::TBlob a_transpose_tblob(reinterpret_cast(work_buffer.dptr_), + a_shape, a_tblob.dev_mask(), a_tblob.dev_id()); + mxnet::TBlob b_transpose_tblob(reinterpret_cast(work_buffer.dptr_) + a_shape.Size(), + b_shape, b_tblob.dev_mask(), b_tblob.dev_id()); + mxnet::TBlob ipiv_tblob(reinterpret_cast( + reinterpret_cast(work_buffer.dptr_) + a_shape.Size() + b_shape.Size()), + ipiv_shape, b_tblob.dev_mask(), b_tblob.dev_id()); + + laop::op(a_transpose_tblob.FlatToKD(s), + b_transpose_tblob.FlatToKD(s), + ipiv_tblob.FlatToKD(s), + ctx, + attrs); + // X = transpose(B) + mxnet_op::Kernel::Launch( + s, b_shape.Size(), + b_transpose_tblob.dptr(), + x_tblob.dptr(), + b_shape[ndim - 1], b_shape[ndim - 2], b_shape[ndim - 1] * b_shape[ndim - 2]); + }); +} + +// X = (inv_A) * B +struct solve_backward { + template + static void op(const Tensor& dX, + const Tensor& inv_A, + const Tensor& B, + const Tensor& X, + const Tensor& dA, + const Tensor& dB, + const OpContext& ctx, + const nnvm::NodeAttrs& attrs) { + // (1) calcualte dB = trans(inv(A)) * dX + // (2) calcualte dA = dB * trans(X) + Stream *s = ctx.get_stream(); + gemm2::op(inv_A, dX, dB, DType(1), true, false, s); + gemm2::op(dB, X, dA, DType(-1), false, true, s); + } +}; + +template +inline void batch_inverse(const Tensor& inv_A, + const Tensor& LU, + const Tensor& pivot, + const mxnet::OpContext& ctx); + +#define CPU_BATCH_INVERSE(xpu, DType) \ +template<> inline \ +void batch_inverse(const Tensor& inv_A, \ + const Tensor& LU, \ + const Tensor& pivot, \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + for (index_t i = 0; i < inv_A.size(0); ++i) { \ + linalg_getrf(inv_A[i], pivot[i], true, s); \ + const Tensor work( \ + LU[i].dptr_, Shape1(LU.size(1) * LU.size(2))); \ + linalg_getri(inv_A[i], pivot[i], work, s); \ + } \ +} +CPU_BATCH_INVERSE(cpu, float) +CPU_BATCH_INVERSE(cpu, double) + +#ifdef __CUDACC__ + +// GETRF and GETRI only available with cuda8 or higher. +#if CUDA_VERSION >= 8000 + +#define GPU_BATCH_INVERSE(xpu, DType) \ +template<> inline \ +void batch_inverse(const Tensor& inv_A, \ + const Tensor& LU, \ + const Tensor& pivot, \ + const mxnet::OpContext& ctx) { \ + Stream *s = ctx.get_stream(); \ + if (LU.dptr_ != inv_A.dptr_) Copy(LU, inv_A, s); \ + linalg_batch_getrf(LU, pivot, true, s); \ + linalg_batch_getri(inv_A, LU, pivot, s); \ +} + +#else // CUDA_VERSION >= 8000 + +#define GPU_BATCH_INVERSE(xpu, DType) \ +template<> inline \ +void batch_inverse(const Tensor& inv_A, \ + const Tensor& LU, \ + const Tensor& pivot, \ + const mxnet::OpContext& ctx) { \ + LOG(FATAL) << "gpu matrix inverse requires CUDA version >= 8.0!"; \ +} + +#endif // CUDA_VERSION >= 8000 + +GPU_BATCH_INVERSE(gpu, float) +GPU_BATCH_INVERSE(gpu, double) + +#endif // __CUDACC__ + +template +void LaOpBackwardSolve(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), inum); + CHECK_EQ(outputs.size(), onum); + CHECK_EQ(req.size(), onum); + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + mshadow::Stream *s = ctx.get_stream(); + const mxnet::TBlob& a_tblob = inputs[1]; + const mxnet::TBlob& b_tblob = inputs[2]; + const mxnet::TBlob& x_tblob = inputs[3]; + + const mxnet::TShape& a_shape = a_tblob.shape_; + mxnet::TShape b_shape(a_shape.ndim(), 1); + for (int i = 0; i < a_shape.ndim() - 1; ++i) { b_shape[i] = b_tblob.shape_[i]; } + if (b_tblob.shape_.ndim() == a_shape.ndim()) { + b_shape[a_shape.ndim() - 1] = b_tblob.shape_[a_shape.ndim() - 1]; + } + const int ndim = a_shape.ndim(); + const int N = a_shape[ndim - 1]; + if (0 == a_shape[ndim - 1] || 0 == a_shape[ndim - 2] || + 0 == b_shape[ndim - 1] || 0 == b_shape[ndim - 2]) { return; } + + const Tensor A = a_tblob.FlatToKD(s); + int work_space_size = sizeof(OType) * a_shape.Size(); // for inverse(A) + work_space_size += sizeof(OType) * a_shape.Size(); // for getri work space + work_space_size += 2 * sizeof(OType) * b_shape.Size(); // for B and X + work_space_size += sizeof(int) * A.size(0) * N; // for pivot work space + Tensor work_buffer = + ctx.requested[0].get_space_typed(Shape1(work_space_size), s); + + MSHADOW_TYPE_SWITCH(a_tblob.type_flag_, AType, { + mxnet_op::Kernel::Launch( + s, a_shape.Size(), + reinterpret_cast(work_buffer.dptr_), + a_tblob.dptr()); + }); + mxnet::TBlob a_inverse_tblob(reinterpret_cast(work_buffer.dptr_), + a_shape, a_tblob.dev_mask(), a_tblob.dev_id()); + const Tensor inv_A = a_inverse_tblob.FlatToKD(s); + + mxnet::TBlob lu_tblob(reinterpret_cast(work_buffer.dptr_) + a_shape.Size(), + inv_A.shape_, a_tblob.dev_mask(), a_tblob.dev_id()); + const Tensor LU = lu_tblob.FlatToKD(s); + + MSHADOW_TYPE_SWITCH(b_tblob.type_flag_, BType, { + mxnet_op::Kernel::Launch( + s, b_shape.Size(), + reinterpret_cast(work_buffer.dptr_) + 2 * a_shape.Size(), + b_tblob.dptr()); + }); + mxnet::TBlob b_cp_tblob(reinterpret_cast(work_buffer.dptr_) + 2 * a_shape.Size(), + b_shape, b_tblob.dev_mask(), b_tblob.dev_id()); + const Tensor B = b_cp_tblob.FlatToKD(s); + + MSHADOW_TYPE_SWITCH(x_tblob.type_flag_, XType, { + mxnet_op::Kernel::Launch( + s, b_shape.Size(), + reinterpret_cast(work_buffer.dptr_) + 2 * a_shape.Size() + b_shape.Size(), + x_tblob.dptr()); + }); + mxnet::TBlob x_cp_tblob( + reinterpret_cast(work_buffer.dptr_) + 2 * a_shape.Size() + b_shape.Size(), + b_shape, b_tblob.dev_mask(), b_tblob.dev_id()); + const Tensor X = x_cp_tblob.FlatToKD(s); + + mxnet::TBlob pivot_tblob(reinterpret_cast( + reinterpret_cast(work_buffer.dptr_) + 2 * a_shape.Size() + 2 * b_shape.Size()), + Shape2(A.size(0), N), a_tblob.dev_mask(), a_tblob.dev_id()); + const Tensor pivot = pivot_tblob.FlatToKD(s); + + // calculate inverse(A) on CPU or GPU + batch_inverse(inv_A, LU, pivot, ctx); + laop::op(inputs[0].FlatToKD(s), + inv_A, + B, + X, + outputs[0].FlatToKD(s), + outputs[1].FlatToKD(s), + ctx, + attrs); + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_LINALG_NP_SOLVE_INL_H_ diff --git a/src/operator/numpy/linalg/np_solve.cc b/src/operator/numpy/linalg/np_solve.cc new file mode 100644 index 000000000000..55d02f18d4dc --- /dev/null +++ b/src/operator/numpy/linalg/np_solve.cc @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_solve.cc + * \brief CPU implementation placeholder of Solve Operator + */ +#include +#include +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../elemwise_op_common.h" + +#include "./np_solve-inl.h" + +namespace mxnet { +namespace op { + +inline bool SolveOpShape(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const mxnet::TShape& in_a_shape = (*in_attrs)[0]; + const mxnet::TShape& in_b_shape = (*in_attrs)[1]; + if (!ndim_is_known(in_a_shape)) { return false; } + int in_a_ndim = in_a_shape.ndim(), in_b_ndim = in_b_shape.ndim(); + + CHECK_GE(in_a_ndim, 2) + << "Array must be at least two-dimensional"; + CHECK_EQ(in_a_shape[in_a_ndim - 2], in_a_shape[in_a_ndim - 1]) + << "Input A's last two dimension must be equal"; + + if (in_a_ndim == in_b_ndim + 1) { + CHECK_EQ(in_a_shape[in_a_ndim - 1], in_b_shape[in_b_ndim - 1]) + << "Input A's and B's last dimension must be equal"; + } else if (in_a_ndim == in_b_ndim) { + CHECK_EQ(in_a_shape[in_a_ndim - 1], in_b_shape[in_b_ndim - 2]) + << "Input A's and B's last second dimension must be equal"; + } else { + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "A's and B's dimensions don't match"; + } + for (int i = 0; i < in_a_ndim - 2; ++i) { + CHECK_EQ(in_a_shape[i], in_b_shape[i]) << "A's and B's dimensions don't match"; + } + + SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_b_shape); + return !mxnet::op::shape_is_none(in_b_shape) && !mxnet::op::shape_is_none(out_attrs->at(0)); +} + +inline bool SolveOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + int a_type = in_attrs->at(0); + int b_type = in_attrs->at(1); + // unsupport float16 + CHECK_NE(a_type, mshadow::kFloat16) + << "array type float16 is unsupported in linalg"; + CHECK_NE(b_type, mshadow::kFloat16) + << "array type float16 is unsupported in linalg"; + if (mshadow::kFloat32 == a_type && mshadow::kFloat32 == b_type) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64); + } + return out_attrs->at(0) != -1; +} + +NNVM_REGISTER_OP(_npi_solve) +.describe(R"code()code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", [](const NodeAttrs& attrs){ + return std::vector{"A", "B"}; +}) +.set_attr("FInferShape", SolveOpShape) +.set_attr("FInferType", SolveOpType) +.set_attr("FResourceRequest", [](const NodeAttrs& attrs){ + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("THasDeterministicOutput", true) +.set_attr("FCompute", LaOpForwardSolve) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_npi_solve"}) +.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix") +.add_argument("B", "NDArray-or-Symbol", "Tensor of right side vector"); + +NNVM_REGISTER_OP(_backward_npi_solve) +.set_num_inputs(4) +.set_num_outputs(2) +.set_attr("FResourceRequest", [](const NodeAttrs& ){ + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("TIsBackward", true) +.set_attr("FCompute", LaOpBackwardSolve); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/linalg/np_solve.cu b/src/operator/numpy/linalg/np_solve.cu new file mode 100644 index 000000000000..b849cf55540e --- /dev/null +++ b/src/operator/numpy/linalg/np_solve.cu @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_solve.cu + * \brief GPU implementation of the Solve Operator + */ + +#include +#include +#include "./np_solve-inl.h" + +namespace mxnet { +namespace op { + +#if MXNET_USE_CUSOLVER == 1 + +NNVM_REGISTER_OP(_npi_solve) +.set_attr("FCompute", LaOpForwardSolve); + +NNVM_REGISTER_OP(_backward_npi_solve) +.set_attr("FCompute", LaOpBackwardSolve); + +#endif + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 53bd8f4d9235..6b5efa0c96b0 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -319,6 +319,27 @@ def _add_workload_linalg_inv(): OpArgMngr.add_workload('linalg.inv', np.array(_np.ones((0, 1, 1)), dtype=np.float64)) +def _add_workload_linalg_solve(): + shapes = [(0,0), (1,1), (5,5), (20,20), (3,5,5), (3,0,0), (2,20,20), (0,20,20), (2,3,20,20)] + nrhs = (0, 1, 2, 10) + dtypes = (np.float32, np.float64) + for dtype, shape in itertools.product(dtypes, shapes): + a = _np.random.rand(*shape) + shape_b = list(shape) + shape_b[-1] = 1 + x = _np.random.rand(*shape_b) + b = _np.matmul(a, x) + shape_b.pop() + b = b.reshape(shape_b) + OpArgMngr.add_workload('linalg.solve', np.array(a, dtype=dtype), np.array(b, dtype=dtype)) + for nrh in nrhs: + shape_b = list(shape) + shape_b[-1] = nrh + x = _np.random.rand(*shape_b) + b = _np.matmul(a, x) + OpArgMngr.add_workload('linalg.solve', np.array(a, dtype=dtype), np.array(b, dtype=dtype)) + + def _add_workload_linalg_det(): OpArgMngr.add_workload('linalg.det', np.array(_np.ones((2, 2)), dtype=np.float32)) OpArgMngr.add_workload('linalg.det', np.array(_np.ones((0, 1, 1)), dtype=np.float64)) @@ -1374,6 +1395,7 @@ def _prepare_workloads(): _add_workload_linalg_norm() _add_workload_linalg_cholesky() _add_workload_linalg_inv() + _add_workload_linalg_solve() _add_workload_linalg_det() _add_workload_linalg_slogdet() _add_workload_trace() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 5c14f0d6c701..6b62a4386524 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3506,6 +3506,97 @@ def check_inv(A_inv, data_np): check_inv(A_inv, data_np) +@with_seed() +@use_np +def test_np_linalg_solve(): + class TestSolve(HybridBlock): + def __init__(self): + super(TestSolve, self).__init__() + + def hybrid_forward(self, F, a, b): + return F.np.linalg.solve(a, b) + + def check_solve(x, a_np, b_np): + try: + x_expected = _np.linalg.solve(a_np, b_np) + except Exception as e: + print("a:", a_np) + print("a shape:", a_np.shape) + print("b", b_np) + print("b shape:", b_np.shape) + print(e) + else: + assert x.shape == x_expected.shape + assert_almost_equal(x.asnumpy(), x_expected, rtol=rtol, atol=atol) + + def get_grad_b(A, X): + dX = _np.ones_like(X) + A_inv = _np.linalg.inv(A) + A_inv_trans = _np.swapaxes(A_inv, -1, -2) + return _np.matmul(A_inv_trans, dX) + + shapes = [ + (0, 0), + (1, 1), + (3, 3), + (20, 20), + (3, 20, 20), + (1, 0, 0), + (0, 1, 1), + (0, 5, 3, 3), + (5, 0, 0, 0), + (2, 3, 10, 10) + ] + nrhs = (-1, 0, 1, 2, 5) + dtypes = ['float32', 'float64'] + for hybridize, shape, dtype, nrh in itertools.product([False, True], shapes, dtypes, nrhs): + rtol = 1e-3 + atol = 1e-5 + test_solve = TestSolve() + if hybridize: + test_solve.hybridize() + + if 0 in shape: + a = _np.ones(shape) + b = _np.ones(shape) + else: + shape_a = shape + a = _np.random.rand(*shape_a) + shape_b = list(shape_a) + if nrh == -1: + shape_b[-1] = 1 + x = _np.random.rand(*shape_b) + b = _np.matmul(a, x) + shape_b.pop() + b = b.reshape(shape_b) + else : + shape_b[-1] = nrh + x = _np.random.rand(*shape_b) + b = _np.matmul(a, x) + a = np.array(a, dtype=dtype) + b = np.array(b, dtype=dtype) + a.attach_grad() + b.attach_grad() + with mx.autograd.record(): + mx_out = test_solve(a, b) + # check solve validity + assert mx_out.shape == b.shape + check_solve(mx_out, a, b) + + # check backward. backward does not support empty input + if 0 not in mx_out.shape: + if nrh != -1: + mx.autograd.backward(mx_out) + b_backward_expected = get_grad_b(a.asnumpy(), mx_out.asnumpy()) + a_backward_expected = -_np.matmul(b_backward_expected, _np.swapaxes(mx_out, -1, -2).asnumpy()) + assert_almost_equal(a.grad.asnumpy(), a_backward_expected, rtol=rtol, atol=atol) + assert_almost_equal(b.grad.asnumpy(), b_backward_expected, rtol=rtol, atol=atol) + + # check imperative once again + mx_out = np.linalg.solve(a, b) + check_solve(mx_out, a, b) + + @with_seed() @use_np def test_np_linalg_det():