
Commit af0eca9

Revert "Replace EigenBroadcast with ElementwiseBroadcast in ReduceGrad (#38959)"

This reverts commit 9059ef6.
AnnaTrainingG authored Jan 25, 2022
1 parent 8bb509d commit af0eca9
Showing 7 changed files with 21 additions and 88 deletions.
10 changes: 8 additions & 2 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu
@@ -17,9 +17,15 @@

 template <typename T>
 using CUDAReduceMeanGradKernel =
-    ops::ReduceCudaGradKernel<T, kps::DivideFunctor>;
+    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
+                          ops::MeanGradFunctor, true>;
+
+using FP16CUDAReduceMeanGradKernel =
+    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext,
+                          paddle::platform::float16, ops::FP16MeanGradFunctor,
+                          true>;

 REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
-                        CUDAReduceMeanGradKernel<paddle::platform::float16>,
+                        FP16CUDAReduceMeanGradKernel,
                         CUDAReduceMeanGradKernel<float>,
                         CUDAReduceMeanGradKernel<double>);
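As context for the restored registration above: the backward of a mean reduction broadcasts the upstream gradient back over the reduced axes and scales it by one over the number of reduced elements. The following is a minimal standalone sketch of that computation in plain C++; it does not use Paddle's ReduceGradKernel or any other Paddle API, and the shapes and values are made up for illustration.

// Standalone illustration (not Paddle code): gradient of a mean reduction over
// the last axis of a [rows x cols] matrix. Each d_x element is the
// corresponding d_out value divided by the number of reduced elements.
#include <cstdio>
#include <vector>

int main() {
  const int rows = 2, cols = 3;             // x is [2 x 3], y = mean(x, axis=1) is [2]
  std::vector<float> d_out = {1.0f, 2.0f};  // upstream gradient, one value per row
  std::vector<float> d_x(rows * cols);

  const float reduce_num = static_cast<float>(cols);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      // Broadcast d_out[r] across the reduced axis and apply the 1/N factor.
      d_x[r * cols + c] = d_out[r] / reduce_num;
    }
  }

  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) printf("%.4f ", d_x[r * cols + c]);
    printf("\n");  // prints 0.3333 ... on the first row, 0.6667 ... on the second
  }
  return 0;
}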
58 changes: 4 additions & 54 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -623,12 +623,11 @@ class ReduceGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    int out_dtype = ctx.Attr<int>("out_dtype");
+    int in_dtype = ctx.Attr<int>("in_dtype");
     auto input_data_type =
-        (out_dtype >= 0)
-            ? static_cast<framework::proto::VarType::Type>(out_dtype)
-            : OperatorWithKernel::IndicateVarDataType(
-                  ctx, framework::GradVarName("Out"));
+        (in_dtype >= 0) ? static_cast<framework::proto::VarType::Type>(in_dtype)
+                        : OperatorWithKernel::IndicateVarDataType(
+                              ctx, framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
     auto CanMKLDNNReduceGradBeUsed = [&]() {
       auto dx_dims = ctx.Input<Tensor>("X")->dims();
@@ -737,55 +736,6 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
                         pt_out.get());
   }
 };
-
-template <typename T, template <typename, typename> class TransformOp>
-class ReduceCudaGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool reduce_all = context.Attr<bool>("reduce_all");
-    std::vector<int> dims = context.Attr<std::vector<int>>("dim");
-    auto* in_x = context.Input<Tensor>("X");
-    auto* d_out =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto out_dtype = context.Attr<int>("in_dtype");
-    // get reduce_dim and reduce_num for reduce_mean_grad
-    int dim_size = in_x->dims().size();
-    std::vector<int> reduce_dims = GetReduceDim(dims, dim_size, reduce_all);
-    auto update_dims = vectorize(d_x->dims());
-    int reduce_num = 1;
-    for (auto i : reduce_dims) {
-      reduce_num *= (in_x->dims())[i];
-      update_dims[i] = 1;
-    }
-    // make new tensor
-    framework::Tensor new_d_out(d_out->type());
-    new_d_out.ShareDataWith(*d_out);
-    new_d_out.Resize(paddle::framework::make_ddim(update_dims));
-    auto& dev_ctx = context.cuda_device_context();
-    if (out_dtype > 0) {
-      d_x->mutable_data(
-          dev_ctx.GetPlace(),
-          static_cast<framework::proto::VarType::Type>(out_dtype));
-    } else {
-      d_x->mutable_data(
-          dev_ctx.GetPlace(),
-          static_cast<framework::proto::VarType::Type>(d_out->type()));
-    }
-    auto pt_d_out = paddle::experimental::MakePtenDenseTensor(new_d_out);
-    auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x);
-    auto pt_out_dtype = pten::TransToPtenDataType(
-        static_cast<framework::proto::VarType::Type>(out_dtype));
-    if (out_dtype <= 0) {
-      pt_out_dtype = pten::TransToPtenDataType(
-          static_cast<framework::proto::VarType::Type>(d_out->type()));
-    }
-    using MPType = typename kps::details::MPTypeTrait<T>::Type;
-    pten::ReduceGrad<T, TransformOp<T, MPType>>(
-        dev_ctx, pt_d_out.get(), pt_d_x.get(), pt_out_dtype,
-        TransformOp<T, MPType>(reduce_num));
-  }
-};
 #endif

 }  // namespace operators
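The block removed above is the ElementwiseBroadcast-based ReduceCudaGradKernel. Its shape bookkeeping is straightforward: every reduced axis of d_x's shape is set to 1 (update_dims), the reduced extents are multiplied into reduce_num, and d_out is then viewed with that shape so it can be broadcast back over the input. Below is a standalone sketch of just that bookkeeping, with made-up shapes and no Paddle types.

// Standalone sketch of the shape bookkeeping done by the removed
// ReduceCudaGradKernel: set each reduced axis to 1 in update_dims and
// accumulate the product of the reduced extents in reduce_num.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> x_dims = {4, 5, 6};    // shape of the forward input X
  std::vector<int> reduce_dims = {0, 2};  // axes that were reduced

  std::vector<int> update_dims = x_dims;  // d_out is later viewed with this shape
  int reduce_num = 1;
  for (int axis : reduce_dims) {
    reduce_num *= x_dims[axis];
    update_dims[axis] = 1;
  }

  printf("reduce_num = %d\n", reduce_num);  // 4 * 6 = 24
  printf("update_dims = [%d, %d, %d]\n",    // [1, 5, 1]
         update_dims[0], update_dims[1], update_dims[2]);
  return 0;
}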
2 changes: 1 addition & 1 deletion paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -50,7 +50,7 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {

   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
-    int in_dtype = ctx.Attr<int>("out_dtype");
+    int in_dtype = ctx.Attr<int>("in_dtype");
     if (in_dtype >= 0) {
       return framework::OpKernelType(
           static_cast<framework::proto::VarType::Type>(in_dtype),
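The restored logic in this file (and in reduce_op.h above) reads the "in_dtype" attribute and only falls back to the dtype inferred from Out@GRAD when the attribute is negative. Here is a hedged, standalone sketch of that "attribute overrides inferred dtype" pattern; the enum values below are placeholders, not Paddle's real VarType codes.

// Standalone sketch of the dtype-selection pattern restored above: a
// non-negative attribute value overrides the dtype inferred from the input.
#include <cstdio>

enum class DType { kUndefined = -1, kFloat32 = 5, kFloat64 = 6 };  // placeholder codes

DType SelectKernelDType(int attr_dtype, DType inferred) {
  // Mirrors `(in_dtype >= 0) ? static_cast<...>(in_dtype) : IndicateVarDataType(...)`.
  return attr_dtype >= 0 ? static_cast<DType>(attr_dtype) : inferred;
}

int main() {
  printf("%d\n", static_cast<int>(SelectKernelDType(-1, DType::kFloat32)));  // 5: inferred wins
  printf("%d\n", static_cast<int>(SelectKernelDType(6, DType::kFloat32)));   // 6: attribute wins
  return 0;
}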
2 changes: 1 addition & 1 deletion paddle/fluid/operators/reduce_ops/reduce_sum_op.h
@@ -74,7 +74,7 @@ class ReduceSumGradKernel : public framework::OpKernel<T> {
     auto dims = context.Attr<std::vector<int>>("dim");
     if (context.GetPlace().GetType() == platform::CPUPlace().GetType() &&
         dims.size() == 1) {
-      int in_dtype = context.Attr<int>("out_dtype");
+      int in_dtype = context.Attr<int>("in_dtype");

       if (in_dtype >= 0) {
         Tensor tmp_tensor;
3 changes: 2 additions & 1 deletion paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu
@@ -17,7 +17,8 @@

 template <typename T>
 using CUDAReduceSumGradKernel =
-    ops::ReduceCudaGradKernel<T, kps::IdentityFunctor>;
+    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
+                          ops::SumGradFunctor, true>;

 REGISTER_OP_CUDA_KERNEL(
     reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
13 changes: 3 additions & 10 deletions paddle/pten/kernels/gpu/elementwise.h
@@ -134,19 +134,12 @@ struct DimensionsTransform {
   explicit DimensionsTransform(const std::vector<const DenseTensor *> &ins,
                                const pten::framework::DDim &dims,
                                int axis) {
-    const int N = max(static_cast<int>(ins.size()), 2);
+    const int N = ins.size();
     dim_size = dims.size();
     out_dims = pten::framework::vectorize<int64_t>(dims);
     in_dims.resize(N);
-    if (ins.size() == 1) {
-      // when ins.size() = 1, broadcast input to output
-      in_dims[0] = pten::framework::vectorize<int64_t>(ins[0]->dims());
-      in_dims[1] = out_dims;
-      // Add out_dims to in_dims to avoid errors in dims merging
-    } else {
-      for (int j = 0; j < N; ++j) {
-        in_dims[j] = pten::framework::vectorize<int64_t>(ins[j]->dims());
-      }
+    for (int j = 0; j < N; ++j) {
+      in_dims[j] = pten::framework::vectorize<int64_t>(ins[j]->dims());
     }
     InputDimensionsExtend(N, axis);

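The deleted branch above padded in_dims with out_dims when there was a single input, so the later dimension-merging logic always saw two shapes. As general background, broadcasting aligns shapes by prepending size-1 dimensions up to the output rank; the snippet below illustrates only that alignment rule and is not the actual InputDimensionsExtend implementation.

// Hedged, standalone illustration of shape alignment for broadcasting:
// prepend 1s until each input shape has the same rank as the output shape.
#include <cstdio>
#include <cstdint>
#include <vector>

std::vector<int64_t> AlignToRank(std::vector<int64_t> dims, size_t rank) {
  // Prepend 1s until the shape has `rank` dimensions.
  dims.insert(dims.begin(), rank - dims.size(), 1);
  return dims;
}

int main() {
  std::vector<int64_t> out_dims = {8, 4, 5};
  std::vector<int64_t> in0 = {4, 5};     // broadcast along dim 0
  std::vector<int64_t> in1 = {8, 1, 5};  // broadcast along dim 1

  for (const auto &dims : {AlignToRank(in0, out_dims.size()),
                           AlignToRank(in1, out_dims.size())}) {
    for (auto d : dims) printf("%lld ", static_cast<long long>(d));
    printf("\n");  // prints "1 4 5" then "8 1 5"
  }
  return 0;
}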
21 changes: 2 additions & 19 deletions paddle/pten/kernels/gpu/reduce.h
@@ -45,7 +45,8 @@ namespace cub = hipcub;
 #include "paddle/pten/api/ext/dispatch.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/gpu/elementwise.h"
+#include "paddle/pten/kernels/funcs/elementwise_base.h"
+
 // Reduce split or not, Whether to use ReduceHigherDim
 #define REDUCE_SPLIT_BOUNDARY 512
 #define REDUCE_VEC_SIZE 4
@@ -1253,24 +1254,6 @@ void Reduce(const GPUContext& dev_ctx,
         x, out, TransformOp<T, MPType>(reduce_num), reduce_dims, stream);
   }
 }
-
-template <typename InT, typename Functor>
-void ReduceGrad(const GPUContext& dev_ctx,
-                DenseTensor* d_out,
-                DenseTensor* d_x,
-                DataType out_dtype,
-                Functor functor) {
-  std::vector<const DenseTensor*> inputs = {d_out};
-  std::vector<DenseTensor*> outputs = {d_x};
-  PD_VISIT_ALL_TYPES(
-      out_dtype, "LaunchBroadcastElementwiseCudaKernel", ([&] {
-        LaunchBroadcastElementwiseCudaKernel<pten::ElementwiseType::kUnary,
-                                             InT,
-                                             data_t>(
-            dev_ctx, inputs, &outputs, 0, functor);
-      }));
-}
-
 }  // namespace pten

 #endif
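The removed pten::ReduceGrad helper expressed the backward pass as a unary broadcast elementwise launch: d_out is the single input, d_x the output, and a transform such as a divide-by-reduce_num functor is applied while broadcasting. Below is a minimal CPU-only sketch of that broadcast-plus-transform pattern, with invented shapes and no pten APIs.

// Standalone sketch (no Paddle / pten APIs) of the pattern the removed
// ReduceGrad helper expressed: broadcast a smaller gradient tensor across the
// reduced axis while applying an elementwise transform.
#include <cstdio>
#include <vector>

template <typename T>
struct DivideByN {
  explicit DivideByN(int n) : n_inv(static_cast<T>(1) / static_cast<T>(n)) {}
  T operator()(T x) const { return x * n_inv; }
  T n_inv;
};

int main() {
  const int rows = 2, cols = 4;  // d_x is [2 x 4], d_out is [2 x 1]
  std::vector<float> d_out = {8.0f, 4.0f};
  std::vector<float> d_x(rows * cols);

  DivideByN<float> transform(cols);  // reduce_num == cols
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      d_x[r * cols + c] = transform(d_out[r]);  // broadcast, then transform

  for (float v : d_x) printf("%.2f ", v);  // 2.00 x4 then 1.00 x4
  printf("\n");
  return 0;
}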

1 comment on commit af0eca9

@paddle-bot-old
Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
