Skip to content

Commit

Permalink
move elementwise mul grad (#40252)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanRisheng authored Mar 9, 2022
1 parent 0604df9 commit 452c75b
Show file tree
Hide file tree
Showing 12 changed files with 539 additions and 401 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ USE_OP(matmul_grad);
USE_OP(square);
USE_OP(transpose2_grad);
USE_OP(concat_grad);
USE_OP(elementwise_mul_grad);
USE_OP_ITSELF(elementwise_mul_grad);
USE_OP(sigmoid_grad);
USE_OP(tanh_grad);
USE_OP(sum);
Expand Down
41 changes: 0 additions & 41 deletions paddle/fluid/operators/elementwise/elementwise_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,47 +196,6 @@ struct MinGradXYFunctor {
}
};

template <typename T>
struct MulGradFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; }
};
template <typename T>
struct MulGradFunctor<Complex<T>> {
inline HOSTDEVICE Complex<T> operator()(const Complex<T> a,
const Complex<T> b) const {
Complex<T> b_conj(b.real, -b.imag);
return a * b_conj;
}
};

template <typename InT, typename OutT>
struct MulGradXYFunctor {
inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT a, const InT b,
const InT c) {
phi::Array<OutT, 2> outs;
// dx = dout * y
outs[0] = a * b;
// dy = dout * x
outs[1] = a * c;
return outs;
}
};

template <typename InT, typename OutT>
struct MulGradXYFunctor<Complex<InT>, Complex<OutT>> {
inline HOSTDEVICE phi::Array<Complex<OutT>, 2> operator()(
const Complex<InT> a, const Complex<InT> b, const Complex<InT> c) {
phi::Array<Complex<OutT>, 2> outs;
// dx = dout * y
Complex<InT> b_conj(b.real, -b.imag);
outs[0] = a * b_conj;
// dy = dout * x
Complex<InT> c_conj(c.real, -c.imag);
outs[1] = a * c_conj;
return outs;
}
};

// Ternary compare
template <typename T>
struct MaxGradXFunctor {
Expand Down
49 changes: 0 additions & 49 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,55 +173,6 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<float>>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
bool>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_triple_grad,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
bool>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);

REGISTER_OP_VERSION(elementwise_mul)
.AddCheckpoint(
Expand Down
68 changes: 0 additions & 68 deletions paddle/fluid/operators/elementwise/elementwise_mul_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,33 +63,6 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
}
};

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMulGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
const auto place = ctx.GetPlace();

if (dx != nullptr && dy != nullptr) {
std::vector<const framework::Tensor*> ins = {dout, y, x};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const framework::Tensor*> ins = {dout, y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
dx, MulGradFunctor<T>());
} else if (dx == nullptr && dy != nullptr) {
std::vector<const framework::Tensor*> ins = {dout, x};
GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
dy, MulGradFunctor<T>());
}
}

} // namespace operators
} // namespace paddle

Expand All @@ -103,44 +76,3 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::bfloat16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::bfloat16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::bfloat16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_triple_grad,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
plat::bfloat16>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
Loading

0 comments on commit 452c75b

Please sign in to comment.