Auto tune for cutlass #50809

Merged · 13 commits merged into PaddlePaddle:develop · Mar 15, 2023

Conversation

umiswing (Member) commented Feb 23, 2023

PR types

New features

PR changes

OPs

Describe

  1. Added auto-tuning of the cutlass gather-gemm-scatter fusion to auto tune. Tuning is enabled by default.
  2. The sparse conv3d implementation involves GEMMs of shape (m, n, k). m depends on the number of features and varies widely. To avoid a large number of repeated searches, the shape (m, n, k) is mapped to (m / features_num_range, n, k); features_num_range is currently set to 1e4 and may be adjusted later based on inference and training behavior (see the sketch after this list).
  3. Removed the hand-written rules.
  4. Because cutlass's gemm-scatter implementation is broken on sm 70, the sm 70 part was removed from the generation rules. This PR supports sm 80; sm 75 support will be added later.
  5. Data types supported by auto-tuning: fp16, fp32.
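
A minimal sketch of the shape bucketing described in item 2 (GenKey here is a stand-in for autotune::GenKey; everything else is illustrative):

```cpp
#include <cstddef>
#include <functional>

// Toy sketch: m is divided by features_num_range so that GEMMs whose m only
// differs within one bucket share a single cached tuning result.
constexpr size_t features_num_range = 10000;  // 1e4, as described above

size_t GenKey(size_t m_bucket, size_t n, size_t k) {
  size_t seed = std::hash<size_t>{}(m_bucket);
  seed ^= std::hash<size_t>{}(n) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  seed ^= std::hash<size_t>{}(k) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  return seed;
}

size_t MakeGatherGemmScatterKey(size_t m, size_t n, size_t k) {
  return GenKey(m / features_num_range, n, k);  // e.g. m = 12345 -> bucket 1
}
```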

paddle-bot (bot) commented Feb 23, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

umiswing requested review from JamesLim-sy and Xreki on Mar 7, 2023, 08:32
GatherGemmScatterGetCache() {
return autotune::AutoTuneCache::Instance().Get(
AlgorithmType::kGatherGemmScatterFP16NN);
}
Contributor:

It would be more appropriate to move the GatherGemmScatterGetCache function into cache_base.h.

umiswing (Member, Author):

I renamed GatherGemmScatterGetCache to GetGatherGemmScatter, moved it into cache.h, and made it a member function of AutoTuneCache.
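
A minimal, self-contained sketch of that change, assuming simplified types (these are not Paddle's actual classes; the real AutoTuneCache in cache.h stores kernel caches per AlgorithmType):

```cpp
#include <cstddef>
#include <unordered_map>

// Toy sketch: the free GatherGemmScatterGetCache() helper becomes a member
// of the AutoTuneCache singleton, looked up by AlgorithmType.
enum class AlgorithmType { kGatherGemmScatterFP16NN, kGatherGemmScatterFP32NN };

class AutoTuneCache {
 public:
  static AutoTuneCache& Instance() {
    static AutoTuneCache cache;
    return cache;
  }
  // Member replacing the former GatherGemmScatterGetCache() free function.
  std::unordered_map<size_t, int>& GetGatherGemmScatter(AlgorithmType type) {
    return caches_[static_cast<int>(type)];
  }

 private:
  std::unordered_map<int, std::unordered_map<size_t, int>> caches_;
};
```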

const int32_t* b_indices,
const int32_t* c_d_indices,
T alpha,
T beta) {
Contributor:

Since the class template parameters already include the variadic pack typename... Args, the Run function can be simplified following class MatmulAutoTuner, and the function body can be simplified in the same way.

umiswing (Member, Author):

Simplified.
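
A minimal sketch of the kind of simplification this enables, using a toy tuner (SimpleAutoTuner is illustrative, not MatmulAutoTuner itself):

```cpp
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

// Toy sketch: the variadic pack Args... lets Run forward every kernel
// argument without spelling out each pointer/scalar parameter separately.
template <typename... Args>
class SimpleAutoTuner {
 public:
  void AddCallBack(std::function<void(Args...)> kernel) {
    kernels_.push_back(std::move(kernel));
  }

  void Run(size_t key, Args... args) {
    auto it = best_.find(key);
    size_t idx = (it != best_.end()) ? it->second : Tune(key, args...);
    kernels_[idx](args...);
  }

 private:
  size_t Tune(size_t key, Args... args) {
    // The real tuner times every candidate; this sketch just records index 0.
    best_[key] = 0;
    return 0;
  }

  std::vector<std::function<void(Args...)>> kernels_;
  std::unordered_map<size_t, size_t> best_;
};
```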

static_cast<const int32_t*>(c_d_indices),
static_cast<cutlass::half_t>(1),
static_cast<cutlass::half_t>(1));
GatherGemmScatter(dev_ctx,
Contributor:

Since static void dispatchKernel(const GPUContext& dev_ctx, ...) now carries the template parameter template <typename T>, the const phi::DataType type argument should no longer need to be passed in, and with the template parameter, branches such as if (type == phi::DataType::FLOAT16) { can also be replaced.

umiswing (Member, Author):

Simplified. dispatchKernel has been removed; instead, GatherGemmScatterDriver (the former GatherGemmScatter) is written with partial template specialization.
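
A minimal sketch of that pattern, with assumed names (GatherGemmScatterDriverSketch and Half are stand-ins, not the real Paddle or CUTLASS types):

```cpp
#include <cstdio>

// Toy sketch: the runtime branch `if (type == phi::DataType::FLOAT16)` is
// replaced by selecting a (partial) template specialization per element type.
struct Half {};  // stand-in for cutlass::half_t / phi::dtype::float16

template <typename T, bool TransposeA = false>
struct GatherGemmScatterDriverSketch {
  static void Run() { std::printf("fp32 candidate kernels\n"); }
};

// Partial specialization: half precision picks its own kernel list.
template <bool TransposeA>
struct GatherGemmScatterDriverSketch<Half, TransposeA> {
  static void Run() { std::printf("fp16 candidate kernels\n"); }
};

int main() {
  GatherGemmScatterDriverSketch<float>::Run();
  GatherGemmScatterDriverSketch<Half>::Run();
  return 0;
}
```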

nullptr,
static_cast<const int32_t*>(c_d_indices),
static_cast<float>(alpha),
static_cast<float>(beta));
Contributor:

With the FLOAT64 branch removed, will anything be added to cover that case?

umiswing (Member, Author) on Mar 7, 2023:

The approach in this PR is that fp64 does not take the fused path, because models do not use fp64.

for (auto i = 1; i < fp16_kernels.size(); i++)
tuner->AddCallBack(fp16_kernels[i]);

size_t key = autotune::GenKey(m / features_num_range, n, k);
Contributor:

I suggest converting the template parameter T to phi::DataType and passing it into GenKey as well, because this call path appears to serve both fp16 and fp32, and for the same input M, N, K the best implementation may differ between the two types.

umiswing (Member, Author) on Mar 7, 2023:

The fp16 and fp32 kernels provided by cutlass are not interchangeable, so phi::DataType is not passed in here. Instead, partial template specialization is used, and fp16 and fp32 each search within their own candidate kernel lists.
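
A minimal sketch of why the key does not need to encode the dtype, using assumed names (Candidates, RegisterCandidate are illustrative):

```cpp
#include <vector>

// Toy sketch: each element type T keeps its own static list of candidate
// kernels, so an fp16 search never sees fp32 kernels and vice versa, and the
// (m / features_num_range, n, k) key can be reused per type without collision.
using KernelFn = void (*)();

template <typename T>
std::vector<KernelFn>& Candidates() {
  static std::vector<KernelFn> kernels;  // one list per element type
  return kernels;
}

template <typename T>
void RegisterCandidate(KernelFn fn) {
  Candidates<T>().push_back(fn);
}
```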

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm/all_gemm_operations.h"
Contributor:

Was the header file all_gemm_operations.h forgotten in this PR?

umiswing (Member, Author) on Mar 7, 2023:

This header file is generated at build time; see PR #50364.

umiswing changed the title from "[WIP] Auto tune for cutlass" to "Auto tune for cutlass" on Mar 7, 2023
void Run(const phi::GPUContext& ctx,
const size_t key,
T const alpha,
T const beta,
Contributor:

It looks like the alpha and beta parameters could also be carried in Args... args.

umiswing (Member, Author):

@JamesLim-sy
GatherGemmScatter computes matrix_c = alpha * matrix_a * matrix_b + beta * matrix_c. To avoid changing the value of matrix_c during PickBestKernel, we need alpha = 0 and beta = 1, so alpha and beta are unpacked here rather than kept in the pack.
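
A minimal sketch of that reasoning, with assumed names (PickBestKernelSketch is illustrative, not the real tuner API):

```cpp
#include <cstddef>

// Toy sketch: the epilogue computes C = alpha * (A * B) + beta * C, so
// benchmarking with alpha = 0 and beta = 1 leaves the output C unchanged
// while candidate kernels are timed; the caller then replays the winner
// with the user-supplied alpha/beta.
template <typename T, typename GemmFn>
void PickBestKernelSketch(GemmFn gemm, size_t key) {
  const T tuning_alpha = static_cast<T>(0);  // zero out the A*B contribution
  const T tuning_beta = static_cast<T>(1);   // write C back onto itself
  gemm(tuning_alpha, tuning_beta);           // timed run, output preserved
  (void)key;  // the real code records the timing under this cache key
}
```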

JamesLim-sy (Contributor) left a comment:

LGTM, good work!

zkh2016 merged commit 12d43da into PaddlePaddle:develop on Mar 15, 2023