Add new API cholesky_solve #38167

Merged · 1 commit · Dec 24, 2021
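
In brief: given a right-hand side X and the Cholesky factor Y = u of a
positive semidefinite matrix A, the op computes Out = A^{-1} X. Assuming the
usual convention for the factor orientation, A = u u^T when upper = false and
A = u^T u when upper = true.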
1 change: 1 addition & 0 deletions cmake/operators.cmake
@@ -183,6 +183,7 @@ function(op_library TARGET)
list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc")
list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc")
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
list(REMOVE_ITEM hip_srcs "cholesky_solve_op.cu")
list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu")
list(REMOVE_ITEM hip_srcs "svd_op.cu")
list(REMOVE_ITEM hip_srcs "eigvalsh_op.cu")
172 changes: 172 additions & 0 deletions paddle/fluid/operators/cholesky_solve_op.cc
@@ -0,0 +1,172 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/cholesky_solve_op.h"
#include "paddle/fluid/operators/solve_op.h"

namespace paddle {
namespace operators {

class CholeskySolveOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(Solves a linear system of equations with a positive
semidefinite matrix to be inverted, given its Cholesky factor matrix u.)DOC");
AddInput("X", "(Tensor) The input tensor, shape of (*,m,k)");
AddInput("Y",
"(Tensor) The input tensor, shape of (*,m,m) composed of upper or "
"lower triangular Cholesky factor");
AddOutput("Out", "(Tensor) The output tensor, shape same to X");
AddAttr<bool>("upper",
"whether to consider the Cholesky factor "
"as a lower or upper triangular matrix")
.SetDefault(false);
}
};

class CholeskySolveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "CholeskySolve");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "CholeskySolve");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "CholeskySolve");
auto u_dims = context->GetInputDim("Y");
auto b_dims = context->GetInputDim("X");
int u_rank = u_dims.size();
int b_rank = b_dims.size();
PADDLE_ENFORCE_GE(u_rank, 2,
platform::errors::InvalidArgument(
"the rank of input Y must be greater than or equal to 2"));
PADDLE_ENFORCE_GE(b_rank, 2,
platform::errors::InvalidArgument(
"the rank of input X must be greater than or equal to 2"));
PADDLE_ENFORCE_EQ(u_dims[u_rank - 1], u_dims[u_rank - 2],
platform::errors::InvalidArgument(
"input Y should be a square matrix, "
"but got last two dimensions of %ld x %ld",
u_dims[u_rank - 1], u_dims[u_rank - 2]));
PADDLE_ENFORCE_EQ(
b_dims[b_rank - 2], u_dims[u_rank - 2],
platform::errors::InvalidArgument(
"the second-to-last dimension of input X must equal the order of "
"input Y, but got %ld and %ld",
b_dims[b_rank - 2], u_dims[u_rank - 2]));

std::vector<int64_t> u_dims_vec = paddle::framework::vectorize(u_dims);
std::vector<int64_t> b_dims_vec = paddle::framework::vectorize(b_dims);

std::vector<int64_t> u_dims_vec_cut(u_dims_vec.begin(),
u_dims_vec.end() - 2);
std::vector<int64_t> b_dims_vec_cut(b_dims_vec.begin(),
b_dims_vec.end() - 2);

std::vector<int64_t> expand_batch_portion =
get_broadcast_batch_portion(u_dims_vec_cut, b_dims_vec_cut);

std::vector<int64_t> b_broadcast_dims({expand_batch_portion});
b_broadcast_dims.insert(b_broadcast_dims.end(),
{b_dims_vec[b_rank - 2], b_dims_vec[b_rank - 1]});
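// Illustrative shapes: Y (5, 3, m, m) with X (3, m, k) broadcasts the batch
// portions (5, 3) and (3) to (5, 3), giving Out of shape (5, 3, m, k).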

// dim of 'Out' is the same as 'X' after broadcasting the batch dims
context->SetOutputDim("Out", framework::make_ddim(b_broadcast_dims));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace());
}
};

class CholeskySolveOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_type = ctx->GetInputType("Y", 0);
auto data_type = ctx->GetInputDataType("Y", 0);

ctx->SetOutputType("Out", var_type, framework::ALL_ELEMENTS);
ctx->SetOutputDataType("Out", data_type, framework::ALL_ELEMENTS);
}
};

template <typename T>
class CholeskySolveOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("cholesky_solve_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput("Out", this->Output("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
}
};

class CholeskySolveGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "cholesky_solve_grad");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "cholesky_solve_grad");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "cholesky_solve_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "cholesky_solve_grad");

auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");

auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");

if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};

} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(cholesky_solve, ops::CholeskySolveOp,
ops::CholeskySolveOpMaker,
ops::CholeskySolveOpVarTypeInference,
ops::CholeskySolveOpGradMaker<paddle::framework::OpDesc>,
ops::CholeskySolveOpGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(cholesky_solve_grad, ops::CholeskySolveGradOp);

REGISTER_OP_CPU_KERNEL(
cholesky_solve,
ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext, float>,
ops::CholeskySolveKernel<paddle::platform::CPUDeviceContext, double>);
Comment on lines +164 to +165

jeff41404 (Contributor) · Dec 21, 2021
register kernel of platform::complex<float> and platform::complex<double>?

Author (Contributor)
The TensorExpand operation used to broadcast the input Tensor does not yet
support complex numbers, so the complex types are not registered for now. The
underlying implementation logic already accounts for the complex case.


REGISTER_OP_CPU_KERNEL(
cholesky_solve_grad,
ops::CholeskySolveGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::CholeskySolveGradKernel<paddle::platform::CPUDeviceContext, double>);
Comment on lines +169 to +170

jeff41404 (Contributor) · Dec 21, 2021
register kernel of platform::complex<float> and platform::complex<double>?

Author (Contributor)
The TensorExpand operation used to broadcast the input Tensor does not yet
support complex numbers, so the complex types are not registered for now. The
underlying implementation logic already accounts for the complex case.

// Complex<> is not supported because of TensorExpand, which is used to
// broadcast the input Tensor
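
For reference, the batch-dimension broadcasting that InferShape performs via
get_broadcast_batch_portion (from solve_op.h) follows NumPy-style
right-aligned rules. Below is a minimal standalone sketch under that
assumption; BroadcastBatchPortion is a hypothetical stand-in, not the actual
helper.

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Right-aligned broadcast of two batch shapes: each pair of trailing dims
// must be equal, or one of them must be 1 (which stretches to the other).
std::vector<int64_t> BroadcastBatchPortion(const std::vector<int64_t> &a,
                                           const std::vector<int64_t> &b) {
  size_t n = std::max(a.size(), b.size());
  std::vector<int64_t> out(n, 1);
  for (size_t i = 0; i < n; ++i) {
    int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1)
      throw std::invalid_argument("incompatible batch dimensions");
    out[n - 1 - i] = std::max(da, db);
  }
  return out;
}

// e.g. BroadcastBatchPortion({5, 3}, {3}) == {5, 3};
//      Y (5, 3, m, m) with X (3, m, k) therefore yields Out (5, 3, m, k).
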
136 changes: 136 additions & 0 deletions paddle/fluid/operators/cholesky_solve_op.cu
@@ -0,0 +1,136 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_WITH_HIP
// HIP does not support cuSOLVER

#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/cholesky_solve_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;

template <typename T>
void cusolver_potrs(const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo,
int n, int nrhs, T *Adata, int lda, T *Bdata, int ldb,
int *devInfo);

template <>
void cusolver_potrs<float>(const cusolverDnHandle_t &cusolverH,
cublasFillMode_t uplo, int n, int nrhs, float *Adata,
int lda, float *Bdata, int ldb, int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSpotrs(
cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo));
}

Comment on cusolver_potrs<float>

jeff41404 (Contributor) · Dec 21, 2021
shall we also support cusolver_potrs<platform::complex<float>> and
cusolver_potrs<platform::complex<double>>?

Author (Contributor)
The TensorExpand operation used to broadcast the input Tensor does not yet
support complex numbers, so the complex types are not registered for now. The
underlying implementation logic already accounts for the complex case.

template <>
void cusolver_potrs<double>(const cusolverDnHandle_t &cusolverH,
cublasFillMode_t uplo, int n, int nrhs,
double *Adata, int lda, double *Bdata, int ldb,
int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDpotrs(
cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo));
}

template <>
void cusolver_potrs<platform::complex<float>>(
const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs,
platform::complex<float> *Adata, int lda, platform::complex<float> *Bdata,
int ldb, int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnCpotrs(
cusolverH, uplo, n, nrhs, reinterpret_cast<const cuComplex *>(Adata), lda,
reinterpret_cast<cuComplex *>(Bdata), ldb, devInfo));
}

template <>
void cusolver_potrs<platform::complex<double>>(
const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs,
platform::complex<double> *Adata, int lda, platform::complex<double> *Bdata,
int ldb, int *devInfo) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnZpotrs(
cusolverH, uplo, n, nrhs,
reinterpret_cast<const cuDoubleComplex *>(Adata), lda,
reinterpret_cast<cuDoubleComplex *>(Bdata), ldb, devInfo));
}

template <typename T>
class CholeskySolveFunctor<paddle::platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext &dev_ctx, bool upper, int n,
int nrhs, T *Adata, int lda, T *Bdata, int *devInfo) {
cublasFillMode_t uplo =
upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;

/* step 1: get cusolver handle*/
auto cusolverH = dev_ctx.cusolver_dn_handle();

/* step 2: solve A0*X0 = B0 */
cusolver_potrs<T>(cusolverH, uplo, n, nrhs, Adata, lda, Bdata, lda,
devInfo);

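// Block until the solve completes so any asynchronous error from the potrs
// call surfaces here.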
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
}
};

template <typename T>
class MatrixReduceSumFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const Tensor &in, Tensor *out,
const framework::ExecutionContext &ctx) {
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const std::vector<std::int64_t> in_dims = framework::vectorize(in.dims());
auto in_size = in_dims.size();
const std::vector<std::int64_t> out_dims =
framework::vectorize(out->dims());
auto out_size = out_dims.size();

std::vector<std::int64_t> out_bst_dims(in_size);

std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1);
std::copy(out_dims.data(), out_dims.data() + out_size,
out_bst_dims.data() + in_size - out_size);

std::vector<int> out_reduce_dims;
for (size_t idx = 0; idx <= in_size - 3; idx++) {
if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
out_reduce_dims.push_back(idx);
}
}
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
in, out, kps::IdentityFunctor<T>(), out_reduce_dims, stream);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
cholesky_solve,
ops::CholeskySolveKernel<paddle::platform::CUDADeviceContext, float>,
ops::CholeskySolveKernel<paddle::platform::CUDADeviceContext, double>);

REGISTER_OP_CUDA_KERNEL(
cholesky_solve_grad,
ops::CholeskySolveGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CholeskySolveGradKernel<paddle::platform::CUDADeviceContext, double>);

#endif // not PADDLE_WITH_HIP
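
As a cross-check of the dispatch above, here is a minimal self-contained CUDA
sketch of the same cusolverDn<t>potrs entry point (an assumed example, not
part of this PR; error checking elided). It solves A * X = B for A = 4*I_3
given the lower Cholesky factor L = 2*I_3, so X should come back as {1, 2, 3}.

#include <cstdio>
#include <cuda_runtime.h>
#include <cusolverDn.h>

int main() {
  double h_L[9] = {2, 0, 0, 0, 2, 0, 0, 0, 2};  // column-major L, A = L*L^T
  double h_B[3] = {4, 8, 12};                   // right-hand side B
  double *d_A, *d_B;
  int *d_info;
  cudaMalloc(&d_A, sizeof(h_L));
  cudaMalloc(&d_B, sizeof(h_B));
  cudaMalloc(&d_info, sizeof(int));
  cudaMemcpy(d_A, h_L, sizeof(h_L), cudaMemcpyHostToDevice);
  cudaMemcpy(d_B, h_B, sizeof(h_B), cudaMemcpyHostToDevice);

  cusolverDnHandle_t handle;
  cusolverDnCreate(&handle);
  // n = 3, nrhs = 1, lda = ldb = 3; d_A already holds the Cholesky factor.
  cusolverDnDpotrs(handle, CUBLAS_FILL_MODE_LOWER, 3, 1, d_A, 3, d_B, 3,
                   d_info);
  cudaDeviceSynchronize();

  cudaMemcpy(h_B, d_B, sizeof(h_B), cudaMemcpyDeviceToHost);
  printf("X = %g %g %g\n", h_B[0], h_B[1], h_B[2]);  // expect: X = 1 2 3

  cusolverDnDestroy(handle);
  cudaFree(d_A);
  cudaFree(d_B);
  cudaFree(d_info);
  return 0;
}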