-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pten]add pten conj kernel #38247
[pten]add pten conj kernel #38247
Conversation
Thanks for your contribution! |
paddle/pten/kernels/conj_kernel.cc
Outdated
#include "paddle/pten/kernels/conj_impl.h" | ||
using complex64 = ::paddle::platform::complex<float>; | ||
using complex128 = ::paddle::platform::complex<double>; | ||
PT_REGISTER_CTX_KERNEL(conj, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
放在kernels根目录的kernel是支持全设备的,这个还做不到,需要移到对应backend目录
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
按照现在的实现,咱们应该是conj functor在functions目录(hybird名字改回去),注册分别在cpu和cuda
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经移动到cpu/cuda 下了
paddle/pten/kernels/conj_impl.h
Outdated
namespace pten { | ||
|
||
template <typename T, typename ContextT> | ||
void Conj(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
头文件中应该只有声明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/pten/kernels/CMakeLists.txt
Outdated
@@ -1,3 +1,11 @@ | |||
if(WITH_GPU) | |||
nv_library(conj_kernel SRCS conj_kernel.cc conj_kernel.cu DEPS dense_tensor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里依赖参考下其他kernel的依赖,编译可能会出错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经修改编译依赖
paddle/pten/kernels/CMakeLists.txt
Outdated
nv_library(conj_kernel SRCS conj_kernel.cc conj_kernel.cu DEPS dense_tensor) | ||
elseif(WITH_ROCM) | ||
hip_library(conj_kernel SRCS conj_kernel.cc conj_kernel.cu DEPS dense_tensor) | ||
else() | ||
cc_library(conj_kernel SRCS conj_kernel.cc DEPS dense_tensor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可能需要增加kernel_context kernel_factory
的依赖
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经增加
paddle/pten/kernels/conj_kernel.cu
Outdated
|
||
#include "paddle/pten/core/kernel_registry.h" | ||
#include "paddle/pten/kernels/conj_impl.h" | ||
#include "paddle/pten/kernels/conj_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle/pten/kernels/conj_kernel.h
放在include最前面
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
output : Tensor | ||
infer_meta : | ||
func : UnchangedInferMeta | ||
param : [x] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里API的输入和InferMeta完全相同,可以不配置param,使用默认值即可
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done,thansk
param : [x] | ||
kernel : | ||
func : conj | ||
param : [x] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同InferMeta,可以不配置param,使用默认值即可
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后续还需要按照新的方式调整
PR types
Others
PR changes
Others
Describe
Add pten Conj API and kernel