
Add the backward support for QR #38824

Merged · 2 commits · Jan 10, 2022
Changes from 1 commit
123 changes: 121 additions & 2 deletions paddle/fluid/operators/qr_op.h
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
@@ -79,9 +80,11 @@ class QrCPUKernel : public framework::OpKernel<T> {
      q_data = q.mutable_data<math::Real<T>>(
          context.GetPlace(),
          size_t(batch_size * m * k * sizeof(math::Real<T>)));
      memset(q_data, 0, size_t(batch_size * m * k * sizeof(math::Real<T>)));
    }
    auto* r_data = r.mutable_data<math::Real<T>>(
        context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real<T>)));
    memset(r_data, 0, size_t(batch_size * k * n * sizeof(math::Real<T>)));

    // Implement QR by calling Eigen
    for (int i = 0; i < batch_size; ++i) {
@@ -126,8 +129,124 @@ template <typename DeviceContext, typename T>
class QrGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "QR doesn't have the backward kernel now and will be supported soon."));
    const framework::Tensor& Q = *ctx.Input<framework::Tensor>("Q");
    const framework::Tensor& R = *ctx.Input<framework::Tensor>("R");
    // Use a different name A instead of X
    const framework::Tensor& A = *ctx.Input<framework::Tensor>("X");
    const framework::Tensor& dQ =
        *ctx.Input<framework::Tensor>(framework::GradVarName("Q"));
    const framework::Tensor& dR =
        *ctx.Input<framework::Tensor>(framework::GradVarName("R"));
    // Use a different name dA instead of dX
    framework::Tensor& dA =
        *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    dA.mutable_data<math::Real<T>>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T>()(dev_ctx, &dA, T(0));

    auto dito =
        math::DeviceIndependenceTensorOperations<DeviceContext, T>(ctx);

    std::string mode = ctx.Attr<std::string>("mode");
    bool compute_q, reduced;
    std::tie(compute_q, reduced) = _parse_qr_mode(mode);
    if (!compute_q) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The derivative of qr is not implemented when mode='r'."));
    }

    auto a_dims = A.dims();
    int a_rank = a_dims.size();
    int m = a_dims[a_rank - 2];
    int n = a_dims[a_rank - 1];

    if ((m > n) && (!reduced)) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The derivative of qr is not implemented when mode='complete' and "
          "nrows > ncols."));
    }

    // m >= n case
    auto m_gt_n_case = [](
        const framework::ExecutionContext& ctx,
        math::DeviceIndependenceTensorOperations<DeviceContext, T>& dito,
        const Tensor& dQ, const Tensor& dR, const Tensor& A, const Tensor& Q,
        const Tensor& R) -> framework::Tensor {
      // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable
      // Programming Tensor Networks.
      // https://arxiv.org/abs/1903.09650, Section 3 (QR factorization).
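      //
      // For reduced QR with m >= n, the paper's result (as implemented
      // below) is:
      //   M  = R * dR^H - dQ^H * Q
      //   dA = (dQ + Q * copyltu(M)) * (R^H)^{-1}
      // where copyltu(M) = tril(M, 0) + tril(M, -1)^H.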

      // R * dR^H
      framework::Tensor R_term;
      if (ctx.HasInput(framework::GradVarName("R"))) {
        R_term = dito.Matmul(R, dito.Transpose(dR));
      } else {
        R_term = dito.Fill(framework::vectorize<int>(R.dims()), 0);
      }

      // dQ^H * Q
      framework::Tensor Q_term;
      if (ctx.HasInput(framework::GradVarName("Q"))) {
        Q_term = dito.Matmul(dito.Transpose(dQ), Q);
      } else {
        Q_term = dito.Fill(framework::vectorize<int>(R.dims()), 0);
      }

      framework::Tensor M_tmp1 = dito.Sub(R_term, Q_term);

      // copyltu: M = tril(M_tmp1, 0) + tril(M_tmp1, -1)^H
      framework::Tensor M_tril_0 = dito.TrilTriu(M_tmp1, 0, true);
      framework::Tensor M_tril_1 = dito.TrilTriu(M_tmp1, -1, true);
      framework::Tensor M = dito.Add(M_tril_0, dito.Transpose(M_tril_1));

      framework::Tensor rhs_term;
      if (ctx.HasInput(framework::GradVarName("Q"))) {
        rhs_term = dito.Add(dQ, dito.Matmul(Q, M));
      } else {
        rhs_term = dito.Matmul(Q, M);
      }

      // Solve dA * R^H = rhs_term for dA
      auto dA =
          dito.TriangularSolve(dito.Transpose(dito.Conj(dito.Transpose(R))),
                               dito.Transpose(rhs_term),
                               /*upper=*/true,
                               /*transpose=*/false,
                               /*unitriangular=*/false);

      return dito.Transpose(dA);
    };

    if (m >= n) {
      auto dA_tmp = m_gt_n_case(ctx, dito, dQ, dR, A, Q, R);
      framework::TensorCopy(dA_tmp, dA.place(), &dA);
    } else {
      // If m < n, partition the input as A = [X|Y] and R = [U|V],
      // calculate dX and dY individually, and concatenate them to get dA
      dA.mutable_data<math::Real<T>>(ctx.GetPlace());
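      // With A = [X|Y] (X: m x m, Y: m x (n - m)) and R = [U|V] (U: m x m,
      // upper triangular), A = Q * R gives X = Q * U and Y = Q * V. Hence
      // dY = Q * dV, and dX follows from the m >= n case applied to (Q, U)
      // with the Q-gradient augmented by Y * dV^H.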

      auto Y = dito.Slice(A, {-1}, {m}, {n});
      auto U = dito.Slice(R, {-1}, {0}, {m});
      framework::Tensor dY, dX, dV, dR_tmp, dQ_prime;

      if (ctx.HasInput(framework::GradVarName("R"))) {
        dV = dito.Slice(dR, {-1}, {m}, {n});
        dR_tmp = dito.Slice(dR, {-1}, {0}, {m});
        // Y * dV^H
        dQ_prime = dito.Matmul(Y, dito.Transpose(dV));
      } else {
        dV = dito.Fill(framework::vectorize<int>(Y.dims()), 0);
        dQ_prime = dito.Fill(framework::vectorize<int>(Q.dims()), 0);
      }

      if (ctx.HasInput(framework::GradVarName("Q"))) {
        dQ_prime = dito.Add(dQ_prime, dQ);
      }
      dX = m_gt_n_case(ctx, dito, dQ_prime, dR_tmp, A, Q, U);
      dY = dito.Matmul(Q, dV);
      // Concatenate dX and dY to get dA.
      auto dA_tmp = dito.ConcatTwoTensors(dX, dY, -1);
      framework::TensorCopy(dA_tmp, dA.place(), &dA);
    }
  }
};

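For intuition, here is a minimal standalone sketch (not part of this PR) that checks the m >= n backward formula above against central finite differences, calling Eigen directly — the same library the CPU forward kernel uses. It assumes the real-valued case and a toy loss L = sum(Q) + sum(R), so that dQ and dR are all-ones; the file name and every identifier in it are illustrative only.

```cpp
// qr_grad_check.cc -- hedged sketch; build with a C++11 compiler and
// the Eigen headers on the include path.
#include <Eigen/Dense>
#include <iostream>

using Mat = Eigen::MatrixXd;

// copyltu(M) = tril(M, 0) + tril(M, -1)^T, mirroring the TrilTriu calls
// in QrGradKernel (real case, so ^H == ^T).
static Mat CopyLtu(const Mat& M) {
  Mat lower = M.triangularView<Eigen::Lower>();
  Mat strict = M.triangularView<Eigen::StrictlyLower>();
  return lower + strict.transpose();
}

// Reduced QR of X, then the scalar loss sum(Q) + sum(R).
static double Loss(const Mat& X, int n) {
  Eigen::HouseholderQR<Mat> qr(X);
  Mat Q = qr.householderQ() * Mat::Identity(X.rows(), n);
  Mat R = qr.matrixQR().topRows(n).triangularView<Eigen::Upper>();
  return Q.sum() + R.sum();
}

int main() {
  const int m = 5, n = 3;
  Mat A = Mat::Random(m, n);

  Eigen::HouseholderQR<Mat> qr(A);
  Mat Q = qr.householderQ() * Mat::Identity(m, n);
  Mat R = qr.matrixQR().topRows(n).triangularView<Eigen::Upper>();

  // Analytic gradient: M = R dR^T - dQ^T Q,
  //                    dA = (dQ + Q copyltu(M)) R^{-T}.
  Mat dQ = Mat::Ones(m, n), dR = Mat::Ones(n, n);
  Mat M = R * dR.transpose() - dQ.transpose() * Q;
  Mat rhs = dQ + Q * CopyLtu(M);
  // Solve dA R^T = rhs, i.e. R dA^T = rhs^T with upper-triangular R.
  Mat dA =
      R.triangularView<Eigen::Upper>().solve(rhs.transpose()).transpose();

  // Central finite differences on the same loss.
  const double eps = 1e-6;
  Mat dA_fd(m, n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      Mat Ap = A, Am = A;
      Ap(i, j) += eps;
      Am(i, j) -= eps;
      dA_fd(i, j) = (Loss(Ap, n) - Loss(Am, n)) / (2.0 * eps);
    }
  }
  std::cout << "max abs diff: " << (dA - dA_fd).cwiseAbs().maxCoeff()
            << std::endl;
  return 0;
}
```

With eps = 1e-6 the two gradients should agree to roughly 1e-8. Note that the strictly-lower part of dR cannot matter: perturbing it changes M only by R * E^H with E strictly lower, which is strictly upper triangular and therefore discarded by copyltu.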
135 changes: 135 additions & 0 deletions paddle/fluid/operators/svd_helper.h
@@ -146,6 +146,93 @@ static std::vector<int> GetBroadcastShape(InTensors ins) {
  return broadcast_shape;
}

static inline framework::DDim ComputeAndCheckShapeForConcatOp(
    const bool is_runtime, const std::vector<framework::DDim>& inputs_dims,
    const size_t axis) {
  const size_t n = inputs_dims.size();
  auto out_dims = inputs_dims[0];
  size_t in_zero_dims_size = out_dims.size();
  for (size_t i = 1; i < n; i++) {
    PADDLE_ENFORCE_EQ(inputs_dims[i].size(), out_dims.size(),
                      platform::errors::InvalidArgument(
                          "The shape of input[0] and input[%d] "
                          "is expected to be equal. "
                          "But received input[0]'s shape = "
                          "[%s], input[%d]'s shape = [%s].",
                          i, inputs_dims[0], i, inputs_dims[i]));
    for (size_t j = 0; j < in_zero_dims_size; j++) {
      if (j == axis) {
        if (is_runtime) {
          out_dims[axis] += inputs_dims[i][j];
        } else {
          if (inputs_dims[i][j] == -1 || out_dims[j] == -1) {
            out_dims[axis] = -1;
          } else {
            out_dims[axis] += inputs_dims[i][j];
          }
        }
      } else {
        bool check_shape =
            is_runtime || (inputs_dims[0][j] > 0 && inputs_dims[i][j] > 0);
        if (check_shape) {
          // check all shapes at run time
          PADDLE_ENFORCE_EQ(inputs_dims[0][j], inputs_dims[i][j],
                            platform::errors::InvalidArgument(
                                "The %d-th dimension of input[0] and input[%d] "
                                "is expected to be equal. "
                                "But received input[0]'s shape = "
                                "[%s], input[%d]'s shape = [%s].",
                                j, i, inputs_dims[0], i, inputs_dims[i]));
        }
        if (!is_runtime && out_dims[j] == -1 && inputs_dims[i][j] > 0) {
          out_dims[j] = inputs_dims[i][j];
        }
      }
    }
  }
  return out_dims;
}

static inline int64_t ComputeAxisForConcatOp(int64_t axis, int64_t rank) {
  PADDLE_ENFORCE_EQ(
      axis >= -rank && axis < rank, true,
      platform::errors::InvalidArgument(
          "The axis is expected to be in range of [%d, %d), but got %d", -rank,
          rank, axis));
  if (axis < 0) {
    axis = axis + rank;
  }
  return axis > 0 ? axis : 0;
}
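// Example: for rank 3, axis -1 normalizes to 2 and axis -3 to 0, while
// axis 3 or -4 fails the range check above.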

// Prepared for the broadcast operation
static std::vector<int64_t> get_broadcast_batch_portion(
    std::vector<int64_t> x, std::vector<int64_t> y) {
  size_t size_x = x.size();
  size_t size_y = y.size();
  size_t size = std::max(size_x, size_y);
  std::vector<int64_t> batchPortion(size);

  ptrdiff_t i = (ptrdiff_t)size - 1;
  for (; i >= 0; --i) {
    ptrdiff_t offset = size - i - 1;
    ptrdiff_t dim_x = size_x - offset - 1;
    ptrdiff_t dim_y = size_y - offset - 1;
    int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
    int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;

    PADDLE_ENFORCE_EQ(
        (x_size == y_size || x_size == 1 || y_size == 1), true,
        platform::errors::PreconditionNotMet(
            "The size of tensor x (%d) must match the size of tensor y "
            "(%d) at non-singleton dimension %d.",
            x_size, y_size, i));

    batchPortion[i] = x_size != 1 ? x_size : y_size;
  }
  return batchPortion;
}
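// Example: batch shapes {2, 1, 4} and {3, 1} broadcast to {2, 3, 4}, while
// {2} vs {3} fails the non-singleton size check above.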

#define DITO_TRANSPOSE_RANK_CASE(N)             \
  case N: {                                     \
    math::Transpose<DeviceContext, T, N> trans; \
@@ -515,6 +602,54 @@ struct DeviceIndependenceTensorOperations {
    return CreateOpRunAndReturnTensor("tril_triu", inputs, attrs, out_shape);
  }

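  // Wraps the triangular_solve op, which solves x * out = y for a
  // triangular x; the leading batch dimensions of x and y are broadcast
  // via get_broadcast_batch_portion above.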
  framework::Tensor TriangularSolve(const framework::Tensor& x,
                                    const framework::Tensor& y, bool upper,
                                    bool transpose, bool unitriangular) {
    framework::AttributeMap attrs;
    attrs["upper"] = upper;
    attrs["transpose"] = transpose;
    attrs["unitriangular"] = unitriangular;
    NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}});
    auto x_dims = x.dims();
    auto y_dims = y.dims();
    auto y_dims_n = y_dims.size();
    std::vector<int64_t> x_dims_vec =
        paddle::framework::vectorize<int64_t>(x_dims);
    std::vector<int64_t> y_dims_vec =
        paddle::framework::vectorize<int64_t>(y_dims);
    std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(),
                                        x_dims_vec.end() - 2);
    std::vector<int64_t> y_dims_vec_cut(y_dims_vec.begin(),
                                        y_dims_vec.end() - 2);
    std::vector<int64_t> expand_batch_portion =
        get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
    std::vector<int64_t> y_broadcast_dims({expand_batch_portion});
    y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2],
                                                     y_dims_vec[y_dims_n - 1]});
    std::vector<int> out_shape(y_broadcast_dims.begin(),
                               y_broadcast_dims.end());
    return CreateOpRunAndReturnTensor("triangular_solve", inputs, attrs,
                                      out_shape);
  }

  framework::Tensor ConcatTwoTensors(const framework::Tensor& x,
                                     const framework::Tensor& y, int axis) {
    framework::AttributeMap attrs;
    attrs["axis"] = axis;
    std::vector<framework::DDim> inputs_dims({x.dims(), y.dims()});
    NameInTensorMap inputs({{"X", {&x, &y}}});
    size_t axis_ =
        ComputeAxisForConcatOp(static_cast<int64_t>(axis),
                               static_cast<int64_t>(inputs_dims[0].size()));
    framework::DDim out_dims =
        ComputeAndCheckShapeForConcatOp(true, inputs_dims, axis_);
    if (out_dims[axis_] < 0) {
      out_dims[axis_] = -1;
    }
    std::vector<int> out_shape = framework::vectorize<int>(out_dims);
    return CreateOpRunAndReturnTensor("concat", inputs, attrs, out_shape);
  }

  Tensor Conj(const Tensor& x) {
    Tensor out;
    auto* out_data = out.mutable_data<T>(x.dims(), context.GetPlace());
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -975,6 +975,7 @@ set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_stack_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_svd_op PROPERTIES TIMEOUT 80)
set_tests_properties(test_qr_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_deformable_psroi_pooling PROPERTIES TIMEOUT 120)
set_tests_properties(test_trilinear_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_static_runner_mnist PROPERTIES TIMEOUT 120)