-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Auto tune for cutlass #50809
Merged
Merged
Auto tune for cutlass #50809
Changes from 10 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
d037457
commit for saving, not work now :(
umiswing 63decbd
finally it pass compilation...
umiswing 5a1f9a3
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
umiswing 865f12a
change GetKey() to GenKey()
umiswing 5084bc6
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
umiswing a21618f
works for fp16 and fp32 on sm 80.
umiswing 630319d
clean the code.
umiswing 31984a0
remove scripts for sm 70
umiswing cd414d9
remove some comment
umiswing 4f24f11
remove some unused header.
umiswing 62c8120
restructure code.
umiswing 384c34e
restructure more codes.
umiswing 1b8072b
remove some unused codes.
umiswing File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -177,6 +177,84 @@ class MatmulAutoTuner | |
} | ||
}; | ||
|
||
template <typename T> | ||
typename std::enable_if<std::is_same<T, float>::value, | ||
AlgorithmsCacheMap&>::type | ||
GatherGemmScatterGetCache() { | ||
return autotune::AutoTuneCache::Instance().Get( | ||
AlgorithmType::kGatherGemmScatterFP32NN); | ||
} | ||
|
||
template <typename T> | ||
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value, | ||
AlgorithmsCacheMap&>::type | ||
GatherGemmScatterGetCache() { | ||
return autotune::AutoTuneCache::Instance().Get( | ||
AlgorithmType::kGatherGemmScatterFP16NN); | ||
} | ||
|
||
template <typename T, typename ReturnType, typename... Args> | ||
class GatherGemmScatterAutoTuner | ||
: public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> { | ||
public: | ||
static GatherGemmScatterAutoTuner<T, ReturnType, Args...>* Instance( | ||
ReturnType (*func)(Args...)) { | ||
static std::once_flag gather_gemm_scatter_init_flag; | ||
static std::unique_ptr<GatherGemmScatterAutoTuner<T, ReturnType, Args...>> | ||
instance; | ||
std::call_once(gather_gemm_scatter_init_flag, [&] { | ||
auto obj = MakeCallback<T>(func); | ||
instance.reset(new GatherGemmScatterAutoTuner<T, ReturnType, Args...>); | ||
instance->AddCallBack(func); | ||
}); | ||
return instance.get(); | ||
} | ||
void Run(const phi::GPUContext& ctx, | ||
const size_t key, | ||
const T* const a, | ||
const T* const b, | ||
const T* const c, | ||
T* const d, | ||
const int& m, | ||
const int& n, | ||
const int& k, | ||
const int32_t* a_indices, | ||
const int32_t* b_indices, | ||
const int32_t* c_d_indices, | ||
T alpha, | ||
T beta) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 既然类的模板参数中已经含有了变参模板 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已简化。 |
||
this->is_init_ = true; | ||
this->CheckKernelSize(); | ||
auto& cache = GatherGemmScatterGetCache<T>(); | ||
|
||
if (cache.Find(key)) { | ||
auto best_idx = cache.Get(key); | ||
this->kernels_[best_idx].Run( | ||
ctx, a, b, c, d, m, n, k, a_indices, c_d_indices, alpha, beta); | ||
|
||
} else { | ||
// Set alpha to 0 and beta to 1 to avoid changing the value of d when | ||
// picking the best kernel | ||
auto best_idx = this->PickBestKernel(ctx, | ||
ctx, | ||
a, | ||
b, | ||
c, | ||
d, | ||
m, | ||
n, | ||
k, | ||
a_indices, | ||
c_d_indices, | ||
static_cast<T>(0), | ||
static_cast<T>(1)); | ||
cache.Set(key, best_idx); | ||
this->kernels_[best_idx].Run( | ||
ctx, a, b, c, d, m, n, k, a_indices, c_d_indices, alpha, beta); | ||
} | ||
} | ||
}; | ||
|
||
// Define the auto_tuner initial object. | ||
#define DEFINE_AUTOTUNER_COMMON_OBJ(name) \ | ||
template <typename T, typename ReturnType, typename... Args> \ | ||
|
@@ -211,6 +289,7 @@ class MatmulAutoTuner | |
|
||
// Expand the tuner-definition macros for each supported kernel family.
DEFINE_AUTOTUNER(Transpose)
DEFINE_AUTOTUNER_FN(Matmul)
DEFINE_AUTOTUNER_FN(GatherGemmScatter)

// Remove the helper macros once they have been expanded. The first name must
// match the macro defined above (DEFINE_AUTOTUNER_COMMON_OBJ); the previous
// spelling DEFINE_AUTOTUNER_COMMON_OBJECT named a macro that does not exist,
// so the real helper macro was never undefined.
#undef DEFINE_AUTOTUNER_COMMON_OBJ
#undef DEFINE_AUTOTUNER_FN
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`GatherGemmScatterGetCache` 函数移动到 `cache_base.h` 目录下更合适。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我将 `GatherGemmScatterGetCache` 重命名为 `GetGatherGemmScatter`，移动到 `cache.h`，作为 `AutoTuneCache` 的成员函数了。