Add new API cholesky_solve #38167
@@ -0,0 +1,172 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/cholesky_solve_op.h"
#include "paddle/fluid/operators/solve_op.h"

namespace paddle {
namespace operators {

class CholeskySolveOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
AddComment(R"DOC(Solves a linear system of equations with a positive " | ||
"semidefinite matrix to be inverted given its Cholesky factor matrix uu." | ||
")DOC"); | ||
AddInput("X", "(Tensor) The input tensor, shape of (*,m,k)"); | ||
AddInput("Y", | ||
"(Tensor) The input tensor, shape of (*,m,m) composed of upper or " | ||
"lower triangular Cholesky factor"); | ||
AddOutput("Out", "(Tensor) The output tensor, shape same to X"); | ||
AddAttr<bool>("upper", | ||
"whether to consider the Cholesky factor " | ||
"as a lower or upper triangular matrix") | ||
.SetDefault(false); | ||
} | ||
}; | ||

class CholeskySolveOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *context) const override {
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "CholeskySolve");
    OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "CholeskySolve");
    OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "CholeskySolve");
    auto u_dims = context->GetInputDim("Y");
    auto b_dims = context->GetInputDim("X");
    int u_rank = u_dims.size();
    int b_rank = b_dims.size();
    PADDLE_ENFORCE_GE(u_rank, 2,
                      platform::errors::InvalidArgument(
                          "the rank of input Y must be greater than or "
                          "equal to 2"));
    PADDLE_ENFORCE_GE(b_rank, 2,
                      platform::errors::InvalidArgument(
                          "the rank of input X must be greater than or "
                          "equal to 2"));
    PADDLE_ENFORCE_EQ(u_dims[u_rank - 1], u_dims[u_rank - 2],
                      platform::errors::InvalidArgument(
                          "input Y should be a square matrix, "
                          "but got last two dims %ld x %ld",
                          u_dims[u_rank - 1], u_dims[u_rank - 2]));
    PADDLE_ENFORCE_EQ(
        b_dims[b_rank - 2], u_dims[u_rank - 2],
        platform::errors::InvalidArgument(
            "the penultimate dim of input X must equal the penultimate dim "
            "of input Y, but got %ld and %ld",
            b_dims[b_rank - 2], u_dims[u_rank - 2]));

    std::vector<int64_t> u_dims_vec = paddle::framework::vectorize(u_dims);
    std::vector<int64_t> b_dims_vec = paddle::framework::vectorize(b_dims);

    std::vector<int64_t> u_dims_vec_cut(u_dims_vec.begin(),
                                        u_dims_vec.end() - 2);
    std::vector<int64_t> b_dims_vec_cut(b_dims_vec.begin(),
                                        b_dims_vec.end() - 2);

    std::vector<int64_t> expand_batch_portion =
        get_broadcast_batch_portion(u_dims_vec_cut, b_dims_vec_cut);

    std::vector<int64_t> b_broadcast_dims({expand_batch_portion});
    b_broadcast_dims.insert(b_broadcast_dims.end(),
                            {b_dims_vec[b_rank - 2], b_dims_vec[b_rank - 1]});

    // dim of 'Out' is the same as 'X' after broadcasting the batch dims
    context->SetOutputDim("Out", framework::make_ddim(b_broadcast_dims));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace());
  }
};

class CholeskySolveOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto var_type = ctx->GetInputType("Y", 0);
    auto data_type = ctx->GetInputDataType("Y", 0);

    ctx->SetOutputType("Out", var_type, framework::ALL_ELEMENTS);
    ctx->SetOutputDataType("Out", data_type, framework::ALL_ELEMENTS);
  }
};

template <typename T>
class CholeskySolveOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> retv) const override {
    retv->SetType("cholesky_solve_grad");
    retv->SetInput("X", this->Input("X"));
    retv->SetInput("Y", this->Input("Y"));
    retv->SetInput("Out", this->Output("Out"));
    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
    retv->SetAttrMap(this->Attrs());
  }
};

class CholeskySolveGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "cholesky_solve");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "cholesky_solve");
    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "cholesky_solve");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "cholesky_solve");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");

    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");

    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->SetOutputDim(y_grad_name, y_dims);
    }
  }
};

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(cholesky_solve, ops::CholeskySolveOp,
                  ops::CholeskySolveOpMaker,
                  ops::CholeskySolveOpVarTypeInference,
                  ops::CholeskySolveOpGradMaker<paddle::framework::OpDesc>,
                  ops::CholeskySolveOpGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(cholesky_solve_grad, ops::CholeskySolveGradOp);

REGISTER_OP_CPU_KERNEL(
    cholesky_solve,
    ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext, float>,
    ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
    cholesky_solve_grad,
    ops::CholeskySolveGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::CholeskySolveGradKernel<paddle::platform::CPUDeviceContext, double>);
// Complex<> is not supported because TensorExpand, which is used to
// broadcast the input Tensor, does not support it yet.

Comment on lines +169 to +170:

Reviewer: register kernel of platform::complex<float> and platform::complex<double>?

Author: The TensorExpand operation used to broadcast the input Tensor does not yet support complex numbers, so the complex types are not registered for now. The underlying implementation already accounts for the complex case.
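For concreteness, the extra registrations discussed above would presumably look like the following once TensorExpand gains complex support (a hypothetical sketch, not code from this PR):

// Hypothetical future registration -- not part of this PR.
REGISTER_OP_CPU_KERNEL(
    cholesky_solve,
    ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext, float>,
    ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext, double>,
    ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext,
                             paddle::platform::complex<float>>,
    ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext,
                             paddle::platform::complex<double>>);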
@@ -0,0 +1,136 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_WITH_HIP
// HIP does not support cusolver

#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/cholesky_solve_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;

template <typename T>
void cusolver_potrs(const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo,
                    int n, int nrhs, T *Adata, int lda, T *Bdata, int ldb,
                    int *devInfo);

template <>
void cusolver_potrs<float>(const cusolverDnHandle_t &cusolverH,
                           cublasFillMode_t uplo, int n, int nrhs, float *Adata,
                           int lda, float *Bdata, int ldb, int *devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSpotrs(
      cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo));
}

template <>
void cusolver_potrs<double>(const cusolverDnHandle_t &cusolverH,
                            cublasFillMode_t uplo, int n, int nrhs,
                            double *Adata, int lda, double *Bdata, int ldb,
                            int *devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDpotrs(
      cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo));
}

template <>
void cusolver_potrs<platform::complex<float>>(
    const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs,
    platform::complex<float> *Adata, int lda, platform::complex<float> *Bdata,
    int ldb, int *devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnCpotrs(
      cusolverH, uplo, n, nrhs, reinterpret_cast<const cuComplex *>(Adata), lda,
      reinterpret_cast<cuComplex *>(Bdata), ldb, devInfo));
}

template <>
void cusolver_potrs<platform::complex<double>>(
    const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs,
    platform::complex<double> *Adata, int lda, platform::complex<double> *Bdata,
    int ldb, int *devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnZpotrs(
      cusolverH, uplo, n, nrhs,
      reinterpret_cast<const cuDoubleComplex *>(Adata), lda,
      reinterpret_cast<cuDoubleComplex *>(Bdata), ldb, devInfo));
}

Review thread on the cusolver_potrs<float> specialization:

Reviewer: shall we also support cusolver_potrs<platform::complex<float>> and cusolver_potrs<platform::complex<double>>?

Author: The TensorExpand operation used to broadcast the input Tensor does not yet support complex numbers, so the complex types are not registered for now. The underlying implementation already accounts for the complex case.
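For orientation, here is a standalone sketch of the standard cuSOLVER call sequence these wrappers rely on: the triangular factor consumed by potrs is produced by a prior potrf (in this op, the factor arrives precomputed as input Y). Assumptions: plain cusolverDn API, column-major device buffers, error handling elided; this is not code from the PR.

#include <cuda_runtime.h>
#include <cusolverDn.h>

// Factor A = L * L^T in place, then solve (L * L^T) X = B in place.
void potrf_then_potrs(double *dA, double *dB, int n, int nrhs) {
  cusolverDnHandle_t handle;
  cusolverDnCreate(&handle);

  int lwork = 0;
  cusolverDnDpotrf_bufferSize(handle, CUBLAS_FILL_MODE_LOWER, n, dA, n, &lwork);

  double *work = nullptr;
  int *devInfo = nullptr;
  cudaMalloc(&work, sizeof(double) * lwork);
  cudaMalloc(&devInfo, sizeof(int));

  cusolverDnDpotrf(handle, CUBLAS_FILL_MODE_LOWER, n, dA, n, work, lwork,
                   devInfo);
  cusolverDnDpotrs(handle, CUBLAS_FILL_MODE_LOWER, n, nrhs, dA, n, dB, n,
                   devInfo);

  cudaFree(work);
  cudaFree(devInfo);
  cusolverDnDestroy(handle);
}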

template <typename T>
class CholeskySolveFunctor<paddle::platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext &dev_ctx, bool upper, int n,
                  int nrhs, T *Adata, int lda, T *Bdata, int *devInfo) {
    cublasFillMode_t uplo =
        upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;

    /* step 1: get cusolver handle */
    auto cusolverH = dev_ctx.cusolver_dn_handle();

    /* step 2: solve A0*X0 = B0 */
    cusolver_potrs<T>(cusolverH, uplo, n, nrhs, Adata, lda, Bdata, lda,
                      devInfo);

    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
  }
};

template <typename T>
class MatrixReduceSumFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const Tensor &in, Tensor *out,
                  const framework::ExecutionContext &ctx) {
    // For example: in's dim = [5, 3, 2, 7, 3]; out's dim = [3, 1, 7, 3]
    // out_reduce_dims should be [0, 2]
    const std::vector<std::int64_t> in_dims = framework::vectorize(in.dims());
    auto in_size = in_dims.size();
    const std::vector<std::int64_t> out_dims =
        framework::vectorize(out->dims());
    auto out_size = out_dims.size();

    std::vector<std::int64_t> out_bst_dims(in_size);

    std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1);
    std::copy(out_dims.data(), out_dims.data() + out_size,
              out_bst_dims.data() + in_size - out_size);

    std::vector<int> out_reduce_dims;
    for (size_t idx = 0; idx <= in_size - 3; idx++) {
      if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
        out_reduce_dims.push_back(idx);
      }
    }
    gpuStream_t stream = ctx.cuda_device_context().stream();
    TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
        in, out, kps::IdentityFunctor<T>(), out_reduce_dims, stream);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    cholesky_solve,
    ops::CholeskySolveKernel<paddle::platform::CUDADeviceContext, float>,
    ops::CholeskySolveKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
    cholesky_solve_grad,
    ops::CholeskySolveGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::CholeskySolveGradKernel<paddle::platform::CUDADeviceContext, double>);

#endif  // not PADDLE_WITH_HIP
Reviewer: register kernel of platform::complex<float> and platform::complex<double>?

Author: The TensorExpand operation used to broadcast the input Tensor does not yet support complex numbers, so the complex types are not registered for now. The underlying implementation already accounts for the complex case.