Auto tune for cutlass #50809
Conversation
Your PR was submitted successfully. Thank you for contributing to this open-source project!
GatherGemmScatterGetCache() {
  return autotune::AutoTuneCache::Instance().Get(
      AlgorithmType::kGatherGemmScatterFP16NN);
}
It would be more appropriate to move the GatherGemmScatterGetCache function into cache_base.h.
I renamed GatherGemmScatterGetCache to GetGatherGemmScatter and moved it into cache.h as a member function of AutoTuneCache.
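The resulting layout can be sketched roughly as follows. This is a minimal stand-alone sketch: the map type, the AlgorithmType values, and the cache payload are simplified placeholders, not Paddle's actual definitions.

```cpp
#include <cstddef>
#include <unordered_map>

// Simplified stand-ins for Paddle's real types.
enum class AlgorithmType { kGatherGemmScatterFP16NN, kGatherGemmScatterFP32NN };
using CacheMap = std::unordered_map<size_t, int>;  // key -> best kernel index

// Sketch of AutoTuneCache after the review: the former free function
// GatherGemmScatterGetCache() becomes the member GetGatherGemmScatter().
class AutoTuneCache {
 public:
  static AutoTuneCache& Instance() {
    static AutoTuneCache instance;  // process-wide singleton
    return instance;
  }

  CacheMap& Get(AlgorithmType type) { return caches_[static_cast<int>(type)]; }

  // Member replacing the former free function GatherGemmScatterGetCache().
  CacheMap& GetGatherGemmScatter(AlgorithmType type) { return Get(type); }

 private:
  std::unordered_map<int, CacheMap> caches_;  // one sub-cache per algorithm
};
```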
const int32_t* b_indices,
const int32_t* c_d_indices,
T alpha,
T beta) {
Since the class template parameters already include the variadic pack typename... Args, the Run function can be written more simply, following class MatmulAutoTuner; the function body can be simplified in the same way.
Simplified.
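The simplification the reviewer points at can be illustrated with a toy tuner: once the class carries typename... Args, Run just forwards the pack instead of spelling out every pointer and scalar parameter. The class below is an illustrative stand-in, not the PR's actual MatmulAutoTuner.

```cpp
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Toy tuner: the parameter pack Args... replaces a long hand-written
// argument list in Run, mirroring the MatmulAutoTuner pattern.
template <typename ReturnType, typename... Args>
class AutoTuner {
 public:
  using Func = std::function<ReturnType(Args...)>;

  void AddCallBack(Func f) { kernels_.push_back(std::move(f)); }

  // Run simply forwards the pack to the chosen candidate kernel.
  ReturnType Run(size_t best, Args... args) {
    return kernels_[best](std::forward<Args>(args)...);
  }

 private:
  std::vector<Func> kernels_;  // candidate kernels registered for tuning
};
```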
static_cast<const int32_t*>(c_d_indices),
static_cast<cutlass::half_t>(1),
static_cast<cutlass::half_t>(1));
GatherGemmScatter(dev_ctx,
static void dispatchKernel(const GPUContext& dev_ctx, ...) already carries the template parameter template <typename T>, so the const phi::DataType type argument should no longer need to be passed in; with the template parameter, branch statements like if (type == phi::DataType::FLOAT16) { can also be replaced.
Simplified: dispatchKernel has been removed in favor of partial template specialization of GatherGemmScatterDriver (formerly GatherGemmScatter).
nullptr,
static_cast<const int32_t*>(c_d_indices),
static_cast<float>(alpha),
static_cast<float>(beta));
Now that the FLOAT64 branch has been deleted, will any measure be taken to compensate for it?
The approach in this PR is that fp64 does not take the fused path, because models do not use fp64.
for (auto i = 1; i < fp16_kernels.size(); i++)
  tuner->AddCallBack(fp16_kernels[i]);

size_t key = autotune::GenKey(m / features_num_range, n, k);
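The bucketed key in this excerpt, GenKey(m / features_num_range, n, k), can be sketched as a variadic hash-combine. This is only an illustration of the bucketing idea; Paddle's actual autotune::GenKey may hash differently.

```cpp
#include <cstddef>
#include <functional>

// Illustrative stand-in for autotune::GenKey: fold the (bucketed) problem
// sizes into a single cache key. Dividing m by features_num_range maps
// nearby m values to the same bucket, so the cache does not grow one
// entry per exact shape.
template <typename... Args>
size_t GenKey(Args... args) {
  size_t seed = 0;
  // boost-style hash_combine folded over every argument (C++17 fold).
  ((seed ^= std::hash<size_t>{}(static_cast<size_t>(args)) + 0x9e3779b9 +
            (seed << 6) + (seed >> 2)),
   ...);
  return seed;
}
```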
I suggest converting the template parameter T to phi::DataType and passing it into GenKey as well: this call site appears to serve both fp16 and fp32, and for the same inputs M, N, K the best implementation may differ between the two types.
The fp16 and fp32 kernels provided by cutlass are not interchangeable, so phi::DataType is not passed in here; instead, partial template specialization is used, and fp16 and fp32 each search among their own candidate kernels.
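The per-type candidate lists described in this reply can be sketched as a traits class: each precision owns its own kernel list, so tuning for one type never considers the other's kernels and the cache key needs no DataType component. The names and kernel ids below are placeholders (and full specializations are used here for brevity where the PR uses partial ones).

```cpp
#include <cstddef>
#include <vector>

struct fp16_t {};  // stand-in for cutlass::half_t

// Primary template deliberately undefined: there is no generic list,
// only per-precision specializations.
template <typename T>
struct KernelCandidates;

template <>
struct KernelCandidates<fp16_t> {
  static const std::vector<int>& List() {
    static const std::vector<int> kernels = {0, 1, 2};  // fp16-only kernel ids
    return kernels;
  }
};

template <>
struct KernelCandidates<float> {
  static const std::vector<int>& List() {
    static const std::vector<int> kernels = {10, 11};  // fp32-only kernel ids
    return kernels;
  }
};
```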
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm/all_gemm_operations.h"
Was the header all_gemm_operations.h forgotten and left out of this PR?
This header is generated at build time; see PR50364.
void Run(const phi::GPUContext& ctx,
         const size_t key,
         T const alpha,
         T const beta,
It looks like the two parameters alpha and beta could also be carried in Args... args.
@JamesLim-sy GatherGemmScatter performs an update matrix_c = alpha * matrix_a * matrix_b + beta * matrix_c. To avoid changing the value of matrix_c during PickBestKernel, we need alpha = 0; beta = 1;. That is why alpha and beta are unpacked here.
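The reasoning can be checked with a scalar model of the standard GEMM epilogue (the function below is purely illustrative):

```cpp
// Scalar model of the GEMM epilogue C = alpha * A * B + beta * C.
// With alpha = 0 and beta = 1 the update is the identity C = C, which
// is why PickBestKernel can time candidate kernels without corrupting
// the real output matrix_c.
double GemmEpilogue(double alpha, double a, double b, double beta, double c) {
  return alpha * a * b + beta * c;
}
```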
LGTM, good work!
PR types: New features
PR changes: OPs
Describe