Skip to content

Commit

Permalink
add parse qr mode func
Browse files Browse the repository at this point in the history
  • Loading branch information
Caozhou1995 committed Mar 22, 2022
1 parent 0f367d8 commit 0b9aa7c
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 33 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/operators/multiplex_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ limitations under the License. */

#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"

#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

Expand Down Expand Up @@ -131,8 +131,8 @@ class MultiplexGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(multiplex, MultiplexInferShapeFunctor,
PT_INFER_META(phi::MultiplexInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(multiplex, MultiplexInferShapeFunctor,
PD_INFER_META(phi::MultiplexInferMeta));

REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
ops::MultiplexGradMaker<paddle::framework::OpDesc>,
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/qr_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ class QrGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(qr, QrInferShapeFunctor,
PT_INFER_META(phi::QrInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(qr, QrInferShapeFunctor,
PD_INFER_META(phi::QrInferMeta));

REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker,
ops::QrGradMaker<paddle::framework::OpDesc>,
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/operators/tril_triu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include "paddle/fluid/framework/op_registry.h"

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"

#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

Expand Down Expand Up @@ -90,8 +89,8 @@ class TrilTriuGradOpMaker : public framework::SingleGradOpMaker<T> {

namespace ops = paddle::operators;
namespace plat = paddle::platform;
DELCARE_INFER_SHAPE_FUNCTOR(tril_triu, TrilTriuInferShapeFunctor,
PT_INFER_META(phi::TrilTriuInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(tril_triu, TrilTriuInferShapeFunctor,
PD_INFER_META(phi::TrilTriuInferMeta));
REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ limitations under the License. */
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/qr_kernel.h"

namespace phi {

Expand Down Expand Up @@ -167,7 +167,7 @@ void QrInferMeta(const MetaTensor& x,
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
std::tie(compute_q, reduced_mode) = phi::ParseQrMode(mode);
std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);

if (compute_q) {
int k = reduced_mode ? min_mn : m;
Expand Down
24 changes: 2 additions & 22 deletions paddle/phi/kernels/cpu/qr_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,10 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"

namespace phi {

static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true;
} else {
PADDLE_THROW(errors::InvalidArgument(
"QR received unrecognized mode '%s'"
" but expected one of 'reduced' (default), 'r', or 'complete'",
mode));
}
return std::make_tuple(compute_q, reduced);
}

template <typename T, typename Context>
void QrKernel(const Context& ctx,
const DenseTensor& x,
Expand All @@ -51,7 +31,7 @@ void QrKernel(const Context& ctx,
DenseTensor* r) {
bool compute_q;
bool reduced_mode;
std::tie(compute_q, reduced_mode) = ParseQrMode(mode);
std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);
auto numel = x.numel();
PADDLE_ENFORCE_GT(
numel, 0, errors::PreconditionNotMet("The input of QR is empty."));
Expand Down
41 changes: 41 additions & 0 deletions paddle/phi/kernels/funcs/parse_qr_mode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

namespace phi {
namespace funcs {

static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true;
} else {
PADDLE_THROW(errors::InvalidArgument(
"QR received unrecognized mode '%s'"
" but expected one of 'reduced' (default), 'r', or 'complete'",
mode));
}
return std::make_tuple(compute_q, reduced);
}
} // namespace funcs
} // namespace phi

0 comments on commit 0b9aa7c

Please sign in to comment.