Move transpose to pten #39327
Thanks for your contribution!
… move_transpose_to_pten
@@ -43,7 +43,7 @@ def check_network_convergence(cls,
 get_data_from_feeder=None,
 use_parallel_executor=True,
 use_reduce=False,
-use_ir_memory_optimize=True,
+use_ir_memory_optimize=False,
Will changing this leave some cases uncovered?
This causes memory optimize to fail. I asked around, and nobody uses this component anymore.
auto* out_ptr = out->data<T>();

// copy in_stride, out_stride, axis to gpu device
const paddle::platform::CUDAPlace& cuda_place = context.GetPlace();
Suggest using phi::CUDAPlace?
// copy in_stride, out_stride, axis to gpu device
const paddle::platform::CUDAPlace& cuda_place = context.GetPlace();
paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
Same as above.
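For readers unfamiliar with why a transpose kernel needs `in_stride`, `out_stride`, and `axis` on the device, here is a minimal CPU sketch of the usual stride-based index mapping. It is an illustration only and not the kernel from this PR: the helper names `RowMajorStrides` and `TransposeWithStrides` are hypothetical, and a row-major layout is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: row-major strides, where strides[i] is the product of
// all dimensions after i.
std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& dims) {
  std::vector<int64_t> strides(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  return strides;
}

// Hypothetical CPU reference (not Paddle's implementation): output dimension
// i takes its extent from input dimension axis[i].
std::vector<float> TransposeWithStrides(const std::vector<float>& in,
                                        const std::vector<int64_t>& in_dims,
                                        const std::vector<int32_t>& axis) {
  assert(in_dims.size() == axis.size());
  const std::size_t ndims = in_dims.size();

  std::vector<int64_t> out_dims(ndims);
  for (std::size_t i = 0; i < ndims; ++i) {
    out_dims[i] = in_dims[axis[i]];
  }
  const std::vector<int64_t> in_stride = RowMajorStrides(in_dims);
  const std::vector<int64_t> out_stride = RowMajorStrides(out_dims);

  std::vector<float> out(in.size());
  for (std::size_t out_idx = 0; out_idx < out.size(); ++out_idx) {
    // Decompose the flat output index into coordinates, then rebuild the
    // flat input index using the stride of the permuted input dimension.
    int64_t remaining = static_cast<int64_t>(out_idx);
    int64_t in_idx = 0;
    for (std::size_t i = 0; i < ndims; ++i) {
      const int64_t coord = remaining / out_stride[i];
      remaining -= coord * out_stride[i];
      in_idx += coord * in_stride[axis[i]];
    }
    out[out_idx] = in[in_idx];
  }
  return out;
}
```

A GPU kernel would typically run the inner loop once per thread, which is why the three small arrays have to be reachable from device code.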
const int ndims, const Tensor& in,
const std::vector<int32_t> perm, Tensor* out) {
void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, const int ndims,
const Tensor& in, const std::vector<int32_t> perm,
A small optimization: could the perm parameter be passed by reference (&)?
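To illustrate the suggestion, here is a minimal sketch of a function that takes `perm` by const reference, so the caller's vector is read in place rather than copied on each call. `PermutedDims` is a hypothetical helper written for this example, not a function from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of Paddle): computes the output shape of a
// transpose. Both vectors are taken by const reference, which avoids a heap
// allocation and copy per call while still forbidding modification.
std::vector<int64_t> PermutedDims(const std::vector<int64_t>& in_dims,
                                  const std::vector<int32_t>& perm) {
  assert(in_dims.size() == perm.size());
  std::vector<int64_t> out_dims(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i) {
    out_dims[i] = in_dims[perm[i]];
  }
  return out_dims;
}
```

With `const std::vector<int32_t> perm` (by value), every call copies the vector even though the function only reads it. The by-reference form keeps the same call sites working, including calls that pass a temporary.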
#include <vector>
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/transpose_kernel.h"
Putting "paddle/phi/kernels/transpose_kernel.h" at the top would follow the style guide more closely.
#pragma once

#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/transpose_grad_kernel.h"
Is #include "paddle/phi/kernels/transpose_grad_kernel.h" unnecessary here?
… move_transpose_to_pten
// copy in_stride, out_stride, axis to gpu device
const phi::GPUPlace& cuda_place = context.GetPlace();
phi::CPUPlace cpu_place = paddle::platform::CPUPlace();
paddle::platform::CPUPlace -> phi::CPUPlace
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
It feels a bit odd for the forward kernel to include the backward (grad) impl.
// copy in_stride, out_stride, axis to gpu device
const phi::GPUPlace& cuda_place = context.GetPlace();
phi::CPUPlace cpu_place = paddle::platform::CPUPlace();
paddle::platform -> phi
LGTM
PR types
Breaking changes
PR changes
OPs
Describe
move transpose op to pten