From 0b9aa7c004cc0a9e82040e9fc25c010dd4c759f9 Mon Sep 17 00:00:00 2001 From: caozhou Date: Tue, 22 Mar 2022 07:56:33 +0000 Subject: [PATCH] add parse qr mode func --- paddle/fluid/operators/multiplex_op.cc | 6 ++-- paddle/fluid/operators/qr_op.cc | 4 +-- paddle/fluid/operators/tril_triu_op.cc | 7 ++-- paddle/phi/infermeta/unary.cc | 4 +-- paddle/phi/kernels/cpu/qr_kernel.cc | 24 ++------------ paddle/phi/kernels/funcs/parse_qr_mode.h | 41 ++++++++++++++++++++++++ 6 files changed, 53 insertions(+), 33 deletions(-) create mode 100644 paddle/phi/kernels/funcs/parse_qr_mode.h diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc index fa2f02753fee4..4e6ad35e612b7 100644 --- a/paddle/fluid/operators/multiplex_op.cc +++ b/paddle/fluid/operators/multiplex_op.cc @@ -14,10 +14,10 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" + #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/multiary.h" @@ -131,8 +131,8 @@ class MultiplexGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; -DELCARE_INFER_SHAPE_FUNCTOR(multiplex, MultiplexInferShapeFunctor, - PT_INFER_META(phi::MultiplexInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(multiplex, MultiplexInferShapeFunctor, + PD_INFER_META(phi::MultiplexInferMeta)); REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, ops::MultiplexGradMaker, diff --git a/paddle/fluid/operators/qr_op.cc b/paddle/fluid/operators/qr_op.cc index 21e639b635adc..02d5e5f03f02e 100644 --- a/paddle/fluid/operators/qr_op.cc +++ b/paddle/fluid/operators/qr_op.cc @@ -103,8 +103,8 @@ class QrGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; -DELCARE_INFER_SHAPE_FUNCTOR(qr, QrInferShapeFunctor, - PT_INFER_META(phi::QrInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(qr, QrInferShapeFunctor, + PD_INFER_META(phi::QrInferMeta)); REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker, ops::QrGradMaker, diff --git a/paddle/fluid/operators/tril_triu_op.cc b/paddle/fluid/operators/tril_triu_op.cc index 93b7b12074bae..b941fa3d03ae1 100644 --- a/paddle/fluid/operators/tril_triu_op.cc +++ b/paddle/fluid/operators/tril_triu_op.cc @@ -13,10 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/framework/op_registry.h" - #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" + #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/unary.h" @@ -90,8 +89,8 @@ class TrilTriuGradOpMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; namespace plat = paddle::platform; -DELCARE_INFER_SHAPE_FUNCTOR(tril_triu, TrilTriuInferShapeFunctor, - PT_INFER_META(phi::TrilTriuInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(tril_triu, TrilTriuInferShapeFunctor, + PD_INFER_META(phi::TrilTriuInferMeta)); REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker, ops::TrilTriuGradOpMaker, ops::TrilTriuGradOpMaker, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 9a6fd1c1c47c3..51b77d92b7bfa 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -22,9 +22,9 @@ limitations under the License. */ #include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/kernels/funcs/parse_qr_mode.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/unfold_functor.h" -#include "paddle/phi/kernels/qr_kernel.h" namespace phi { @@ -167,7 +167,7 @@ void QrInferMeta(const MetaTensor& x, int m = x_dims[x_rank - 2]; int n = x_dims[x_rank - 1]; int min_mn = std::min(m, n); - std::tie(compute_q, reduced_mode) = phi::ParseQrMode(mode); + std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode); if (compute_q) { int k = reduced_mode ? min_mn : m; diff --git a/paddle/phi/kernels/cpu/qr_kernel.cc b/paddle/phi/kernels/cpu/qr_kernel.cc index e2e32567441ae..b0e82cedb6b8b 100644 --- a/paddle/phi/kernels/cpu/qr_kernel.cc +++ b/paddle/phi/kernels/cpu/qr_kernel.cc @@ -19,30 +19,10 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/parse_qr_mode.h" namespace phi { -static inline std::tuple ParseQrMode(const std::string& mode) { - bool compute_q; - bool reduced; - if (mode == "reduced") { - compute_q = true; - reduced = true; - } else if (mode == "complete") { - compute_q = true; - reduced = false; - } else if (mode == "r") { - compute_q = false; - reduced = true; - } else { - PADDLE_THROW(errors::InvalidArgument( - "QR received unrecognized mode '%s'" - " but expected one of 'reduced' (default), 'r', or 'complete'", - mode)); - } - return std::make_tuple(compute_q, reduced); -} - template void QrKernel(const Context& ctx, const DenseTensor& x, @@ -51,7 +31,7 @@ void QrKernel(const Context& ctx, DenseTensor* r) { bool compute_q; bool reduced_mode; - std::tie(compute_q, reduced_mode) = ParseQrMode(mode); + std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode); auto numel = x.numel(); PADDLE_ENFORCE_GT( numel, 0, errors::PreconditionNotMet("The input of QR is empty.")); diff --git a/paddle/phi/kernels/funcs/parse_qr_mode.h b/paddle/phi/kernels/funcs/parse_qr_mode.h new file mode 100644 index 0000000000000..adf64759d3ad6 --- /dev/null +++ b/paddle/phi/kernels/funcs/parse_qr_mode.h @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace phi { +namespace funcs { + +static inline std::tuple ParseQrMode(const std::string& mode) { + bool compute_q; + bool reduced; + if (mode == "reduced") { + compute_q = true; + reduced = true; + } else if (mode == "complete") { + compute_q = true; + reduced = false; + } else if (mode == "r") { + compute_q = false; + reduced = true; + } else { + PADDLE_THROW(errors::InvalidArgument( + "QR received unrecognized mode '%s'" + " but expected one of 'reduced' (default), 'r', or 'complete'", + mode)); + } + return std::make_tuple(compute_q, reduced); +} +} // namespace funcs +} // namespace phi