Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto tune for cutlass #50809

Merged
merged 13 commits into from
Mar 15, 2023
8 changes: 5 additions & 3 deletions cmake/external/cutlass.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ ExternalProject_Add(
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND
rm -rf
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build &&
mkdir -p
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build/generated/gemm
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm
&& ${PYTHON_EXECUTABLE} -B
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/"
"${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build"
"${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build"
"${CMAKE_CUDA_COMPILER_VERSION}"
INSTALL_COMMAND ""
TEST_COMMAND "")
Expand Down
79 changes: 79 additions & 0 deletions paddle/phi/kernels/autotune/auto_tune_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,84 @@ class MatmulAutoTuner
}
};

// Returns the autotune cache holding the best-kernel indices for the
// FP32, NN-layout gather-gemm-scatter kernels. SFINAE restricts this
// overload to T == float; the phi::dtype::float16 overload below handles
// the half-precision case.
template <typename T>
typename std::enable_if<std::is_same<float, T>::value,
                        AlgorithmsCacheMap&>::type
GatherGemmScatterGetCache() {
  auto& cache_registry = autotune::AutoTuneCache::Instance();
  return cache_registry.Get(AlgorithmType::kGatherGemmScatterFP32NN);
}

template <typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value,
AlgorithmsCacheMap&>::type
GatherGemmScatterGetCache() {
return autotune::AutoTuneCache::Instance().Get(
AlgorithmType::kGatherGemmScatterFP16NN);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be more appropriate to move the GatherGemmScatterGetCache function into cache_base.h.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed GatherGemmScatterGetCache to GetGatherGemmScatter and moved it into cache.h as a member function of AutoTuneCache.


// Auto-tuner for the CUTLASS gather-gemm-scatter kernels: times every
// registered kernel once per problem `key` and caches the winner.
// (This span of the scraped page had GitHub review-comment text spliced
// into the middle of Run(); it is removed here so the class is valid C++.)
template <typename T, typename ReturnType, typename... Args>
class GatherGemmScatterAutoTuner
    : public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> {
 public:
  // Lazily creates the process-wide tuner for this (T, signature) pair and
  // registers `func` as the first candidate kernel. Thread-safe: the
  // one-time initialization is guarded by std::call_once.
  static GatherGemmScatterAutoTuner<T, ReturnType, Args...>* Instance(
      ReturnType (*func)(Args...)) {
    static std::once_flag gather_gemm_scatter_init_flag;
    static std::unique_ptr<GatherGemmScatterAutoTuner<T, ReturnType, Args...>>
        instance;
    std::call_once(gather_gemm_scatter_init_flag, [&] {
      // The original body also constructed `auto obj = MakeCallback<T>(func)`
      // and never used it; AddCallBack(func) alone registers the kernel.
      instance.reset(new GatherGemmScatterAutoTuner<T, ReturnType, Args...>);
      instance->AddCallBack(func);
    });
    return instance.get();
  }

  // Runs the fastest registered kernel for `key` on d = alpha*A*B + beta*C
  // with gathered A rows and scattered output rows. On a cache miss, all
  // kernels are timed via PickBestKernel and the winning index is memoized.
  //
  // NOTE(review): `b_indices` is accepted but never forwarded to the
  // kernels (only a_indices / c_d_indices are) — confirm whether it is
  // intentionally unused before removing it.
  void Run(const phi::GPUContext& ctx,
           const size_t key,
           const T* const a,
           const T* const b,
           const T* const c,
           T* const d,
           const int& m,
           const int& n,
           const int& k,
           const int32_t* a_indices,
           const int32_t* b_indices,
           const int32_t* c_d_indices,
           T alpha,
           T beta) {
    this->is_init_ = true;
    this->CheckKernelSize();
    auto& cache = GatherGemmScatterGetCache<T>();

    if (cache.Find(key)) {
      auto best_idx = cache.Get(key);
      this->kernels_[best_idx].Run(
          ctx, a, b, c, d, m, n, k, a_indices, c_d_indices, alpha, beta);
    } else {
      // Tune with alpha = 0 and beta = 1 so the timing runs leave the
      // destination `d` unchanged while picking the best kernel.
      // `ctx` appears twice on purpose: the first is PickBestKernel's
      // timing context, the second is forwarded to each candidate as its
      // own first argument (cf. kernels_[best_idx].Run(ctx, ...) above).
      auto best_idx = this->PickBestKernel(ctx,
                                           ctx,
                                           a,
                                           b,
                                           c,
                                           d,
                                           m,
                                           n,
                                           k,
                                           a_indices,
                                           c_d_indices,
                                           static_cast<T>(0),
                                           static_cast<T>(1));
      cache.Set(key, best_idx);
      this->kernels_[best_idx].Run(
          ctx, a, b, c, d, m, n, k, a_indices, c_d_indices, alpha, beta);
    }
  }
};

// Define the auto_tuner inital object.
#define DEFINE_AUTOTUNER_COMMON_OBJ(name) \
template <typename T, typename ReturnType, typename... Args> \
Expand Down Expand Up @@ -211,6 +289,7 @@ class MatmulAutoTuner

DEFINE_AUTOTUNER(Transpose)
DEFINE_AUTOTUNER_FN(Matmul)
DEFINE_AUTOTUNER_FN(GatherGemmScatter)

#undef DEFINE_AUTOTUNER_COMMON_OBJECT
#undef DEFINE_AUTOTUNER_FN
Expand Down
12 changes: 7 additions & 5 deletions paddle/phi/kernels/autotune/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ enum class AlgorithmType {
kConvBackwardFilter = 3,
kTranspose = 4,
kMatmul = 5,
// Autotune caches for the CUTLASS gather-gemm-scatter kernels
// (NN layout), one enumerator per element type.
kGatherGemmScatterFP16NN = 6,
kGatherGemmScatterFP32NN = 7,
// The scraped diff kept the pre-change enumerators (kAlgorithmCount = 6
// and kConv*V8 = 6/7/8) next to the updated ones, which would be
// duplicate/contradictory values; only the post-change values remain.
#if !defined(PADDLE_WITH_CUDNN_FRONTEND)
kAlgorithmCount = 8
#else
kConvForwardV8 = 8,
kConvBackwardDataV8 = 9,
kConvBackwardFilterV8 = 10,
kAlgorithmCount = 11
#endif
};

Expand Down
10 changes: 6 additions & 4 deletions paddle/phi/kernels/sparse/gpu/conv_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,

#ifdef PADDLE_WITH_CUTLASS
bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
if (in_channels % 8 != 0 || out_channels % 8 != 0) {
if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
if (std::is_same<T, float>::value) cutlass = false;
}
if (std::is_same<T, double>::value) cutlass = false;
if (!std::is_same<IntT, int32_t>::value) cutlass = false;

if (cutlass) {
auto* out_values = out->mutable_non_zero_elements();
T* out_values_ptr = out_values->data<T>();
Expand Down Expand Up @@ -160,7 +161,8 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
K,
gather_indices,
scatter_indices,
cutlass,
1.0f,
1.0f,
x.dtype());
}
} else {
Expand Down
Loading