Optimize where_op and abs_grad_op by the elementwise interface #39609
Conversation
Thanks for your contribution!
Force-pushed from 6564ee3 to f376b92
Force-pushed from f376b92 to 0a79619
@@ -17,9 +17,33 @@
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/abs_grad_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
Files under phi must not include files from the fluid path; please fix this following the cast example.
Please describe the work more clearly in the PR description, e.g., which functors were added and which kernel is called.
Done.
auto functor = CondFunctor<T>();
std::vector<const framework::Tensor*> ins = {condition, X, Y};
std::vector<framework::Tensor*> outs = {out};
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
Suggest changing this to the phi::funcs style of call.
Done.
T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
    numel, cond_data, x_data, y_data, out_data);
auto functor = CondFunctor<T>();
std::vector<const framework::Tensor*> ins = {condition, X, Y};
The related framework::Tensor usages can be changed to DenseTensor in a follow-up.
@@ -20,6 +21,15 @@ namespace platform = paddle::platform;
namespace paddle {
namespace operators {

In the next PR, please add a description of what each function does.
return cond ? x : y;
}
};

template <typename T>
__global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x,
Can this function be removed now?
@@ -20,6 +21,15 @@ namespace platform = paddle::platform;
namespace paddle {
namespace operators {

template <typename T>
struct CondFunctor {
  HOSTDEVICE inline CondFunctor() {}
This is the default constructor; it does not need to be written explicitly.
@@ -154,6 +154,53 @@ struct AbsFunctor<T, NoComplex<T, Real<T>>> {
  int64_t numel_;
};

template <typename T>
struct AbsGradCUDAFunctor {
  HOSTDEVICE inline AbsGradCUDAFunctor() {}
The default constructor does not need to be defined explicitly.
};

template <>
struct AbsGradCUDAFunctor<phi::dtype::complex<float>> {
The functor definition can be simplified; see:
Paddle/paddle/phi/kernels/gpu/abs_kernel.cu
Lines 29 to 34 in bbe441f
template <typename T>
struct CudaAbsFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
  __device__ __forceinline__ phi::funcs::Real<T> operator()(const T x) const {
    return abs(x);
  }
};
PR types
Performance optimization
PR changes
OPs
Describe
This PR optimizes where_op and abs_grad_op through the elementwise interface. The elementwise interface bundles a series of performance-optimization techniques and delivers a general speedup for ops with elementwise behavior. By rewriting the functors, the element-by-element loops in the code were replaced with functor calls made through the elementwise interface.