-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Phi]Move kron kernel to phi #40427
[Phi]Move kron kernel to phi #40427
Conversation
Thanks for your contribution! |
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/impl/kron_grad_kernel_impl.h" | ||
#include "paddle/phi/kernels/kron_grad_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
头文件放到第一行
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/impl/kron_kernel_impl.h" | ||
#include "paddle/phi/kernels/kron_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
namespace phi { | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否可以用phi::dtype 命名空间下的complex ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在phi下最好还是不使用paddle::xxx相关namespace的别名,会增加后续替换的难度
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
const plat::complex<T>* dout_; | ||
const plat::complex<T>* A_; | ||
const plat::complex<T>* B_; | ||
plat::complex<T>* dout_a_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
plat->phi::dtype
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
p_dout_y = dout_y.data<T>(); | ||
} | ||
|
||
plat::ForRange<Context> for_range(dev_ctx, numel); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用phi下的ForRange
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>( | ||
*ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream); | ||
} | ||
if (dy) { | ||
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>( | ||
*ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1}, stream); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TensorReduceImpl可以使用phi下的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/platform/for_range.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_registry.h这里应该不需要了
for_range.h使用phi下的
#include "paddle/phi/kernels/funcs/eigen/common.h" | ||
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" | ||
#if defined(__NVCC__) || defined(__HIPCC__) | ||
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reduce_op.cu.h
可以使用paddle/phi/kernels/funcs/reduce_function.h
代替
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
paddle/phi/ops/compat/kron_sig.cc
Outdated
KernelSignature KronOpArgumentMapping(const ArgumentMappingContext& ctx) { | ||
return KernelSignature("kron", {"X", "Y"}, {}, {"Out"}); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的前向ArgumentMapping看上去没有特殊case,感觉可以不写,试试直接使用默认的参数映射能不能work?
auto stream = dev_ctx.stream(); // it is a cuda device_context | ||
auto* ctx = reinterpret_cast<const plat::CUDADeviceContext*>(&dev_ctx); | ||
if (dx) { | ||
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ops::TensorReduceImpl 已迁移,这里可用 phi::funcs::ReduceKernel
*ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream); | ||
} | ||
if (dy) { | ||
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
#include <algorithm> | ||
#include <vector> | ||
|
||
#include "paddle/fluid/framework/op_registry.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
phi目录下用不到原来的op 注册头文件
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
#include <vector> | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/platform/for_range.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
引用这个 paddle/phi/kernels/funcs/for_range.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
#include "paddle/phi/kernels/funcs/eigen/common.h" | ||
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" | ||
#if defined(__NVCC__) || defined(__HIPCC__) | ||
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ops::TensorReduceImpl 已迁移,这里可用 phi::funcs::ReduceKernel , 此头文件可以不用了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
p_shape_y = dim_y.Get(); | ||
#endif | ||
|
||
paddle::platform::ForRange<Context> for_range(dev_ctx, numel); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用Phi下的ForRange替代
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Describe
[Phi]Move kron kernel to phi