Skip to content

Commit

Permalink
[Phi] Move mean op kernel into phi (#40872)
Browse files Browse the repository at this point in the history
* add mean phi kernel

* remove original mean kernel

* add alias name
  • Loading branch information
chenwhql authored Mar 24, 2022
1 parent 6d3db9c commit 8df9176
Show file tree
Hide file tree
Showing 15 changed files with 330 additions and 214 deletions.
21 changes: 2 additions & 19 deletions paddle/fluid/operators/mean_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mean_op.h"
#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

Expand Down Expand Up @@ -94,21 +95,3 @@ REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
ops::MeanGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp,
ops::MeanGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MeanKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
mean_grad, ops::MeanGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::MeanGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
119 changes: 0 additions & 119 deletions paddle/fluid/operators/mean_op.cu

This file was deleted.

70 changes: 0 additions & 70 deletions paddle/fluid/operators/mean_op.h

This file was deleted.

4 changes: 3 additions & 1 deletion paddle/fluid/operators/mean_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class MeanMLUKernel : public framework::OpKernel<T> {
public:
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/operators/mean_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class MeanNPUKernel : public framework::OpKernel<T> {
public:
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/operators/mean_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mean_op.h"
#ifdef PADDLE_WITH_XPU
#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class MeanXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License. */

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/mean_op.h"

namespace paddle {
namespace operators {
Expand Down
2 changes: 0 additions & 2 deletions paddle/phi/core/compat/op_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"matmul",
"matmul_grad",
"matmul_grad_grad",
"mean",
"mean_grad",
"max",
"max_grad",
"min",
Expand Down
51 changes: 51 additions & 0 deletions paddle/phi/kernels/cpu/mean_all_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/mean_all_grad_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

template <typename T, typename Context>
void MeanAllGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
PADDLE_ENFORCE_EQ(out_grad.numel(),
1UL,
phi::errors::InvalidArgument(
"Mean Gradient should be scalar. But received "
"Out@Grad's elements num is %d.",
out_grad.numel()));
dev_ctx.template Alloc<T>(x_grad);

T ig_size = static_cast<T>(x_grad->numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
EigenVector<T>::Flatten(*x_grad).device(*dev_ctx.eigen_device()) =
(EigenVector<T>::From(out_grad) / ig_size).broadcast(bcast);
}

} // namespace phi

PD_REGISTER_KERNEL(mean_all_grad,
CPU,
ALL_LAYOUT,
phi::MeanAllGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
45 changes: 45 additions & 0 deletions paddle/phi/kernels/cpu/mean_all_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/mean_all_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

template <typename T, typename Context>
void MeanAllKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);

auto X = EigenVector<T>::Flatten(x);
auto y = EigenScalar<T>::From(*out);
auto& place = *dev_ctx.eigen_device();

y.device(place) = X.mean();
}

} // namespace phi

PD_REGISTER_KERNEL(mean_all,
CPU,
ALL_LAYOUT,
phi::MeanAllKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
Loading

0 comments on commit 8df9176

Please sign in to comment.