Skip to content

Commit

Permalink
simplify GKO_REGISTER_OPERATION dispatch
Browse files Browse the repository at this point in the history
Use the parameter type instead of an int parameter.
  • Loading branch information
upsj committed Aug 31, 2021
1 parent 01682a1 commit 8a2221c
Showing 1 changed file with 33 additions and 32 deletions.
65 changes: 33 additions & 32 deletions include/ginkgo/core/base/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,6 @@ class Operation {
namespace detail {


constexpr static int ref_exec_tag = 0;
constexpr static int omp_exec_tag = 1;
constexpr static int cuda_exec_tag = 2;
constexpr static int hip_exec_tag = 3;
constexpr static int dpcpp_exec_tag = 4;


/**
* The RegisteredOperation class wraps a functor that will be called with a tag
* parameter based on the dynamic type of the executor that runs it.
Expand Down Expand Up @@ -326,27 +319,27 @@ class RegisteredOperation : public Operation {

void run(std::shared_ptr<const ReferenceExecutor> exec) const override
{
op_(exec, ref_exec_tag);
op_(exec);
}

void run(std::shared_ptr<const OmpExecutor> exec) const override
{
op_(exec, omp_exec_tag);
op_(exec);
}

void run(std::shared_ptr<const CudaExecutor> exec) const override
{
op_(exec, cuda_exec_tag);
op_(exec);
}

void run(std::shared_ptr<const HipExecutor> exec) const override
{
op_(exec, hip_exec_tag);
op_(exec);
}

void run(std::shared_ptr<const DpcppExecutor> exec) const override
{
op_(exec, dpcpp_exec_tag);
op_(exec);
}

private:
Expand Down Expand Up @@ -440,43 +433,51 @@ RegisteredOperation<Closure> make_register_operation(const char *name,
*/
#define GKO_REGISTER_OPERATION(_name, _kernel) \
template <typename... Args> \
auto make_##_name(Args &&... args) \
auto make_##_name(Args &&...args) \
{ \
return ::gko::detail::make_register_operation( \
#_name, sizeof...(Args), \
[&args...](std::shared_ptr<const ::gko::Executor> exec, int tag) { \
switch (tag) { \
case ::gko::detail::ref_exec_tag: \
#_name, sizeof...(Args), [&args...](auto exec) { \
if (std::is_same< \
decltype(exec), \
std::shared_ptr<const ::gko::ReferenceExecutor>>:: \
value) { \
::gko::kernels::reference::_kernel( \
std::static_pointer_cast< \
std::dynamic_pointer_cast< \
const ::gko::ReferenceExecutor>(exec), \
std::forward<Args>(args)...); \
return; \
case ::gko::detail::omp_exec_tag: \
} else if (std::is_same< \
decltype(exec), \
std::shared_ptr<const ::gko::OmpExecutor>>:: \
value) { \
::gko::kernels::omp::_kernel( \
std::static_pointer_cast<const ::gko::OmpExecutor>( \
std::dynamic_pointer_cast<const ::gko::OmpExecutor>( \
exec), \
std::forward<Args>(args)...); \
return; \
case ::gko::detail::cuda_exec_tag: \
} else if (std::is_same< \
decltype(exec), \
std::shared_ptr<const ::gko::CudaExecutor>>:: \
value) { \
::gko::kernels::cuda::_kernel( \
std::static_pointer_cast<const ::gko::CudaExecutor>( \
std::dynamic_pointer_cast<const ::gko::CudaExecutor>( \
exec), \
std::forward<Args>(args)...); \
return; \
case ::gko::detail::hip_exec_tag: \
} else if (std::is_same< \
decltype(exec), \
std::shared_ptr<const ::gko::HipExecutor>>:: \
value) { \
::gko::kernels::hip::_kernel( \
std::static_pointer_cast<const ::gko::HipExecutor>( \
std::dynamic_pointer_cast<const ::gko::HipExecutor>( \
exec), \
std::forward<Args>(args)...); \
return; \
case ::gko::detail::dpcpp_exec_tag: \
} else if (std::is_same< \
decltype(exec), \
std::shared_ptr<const ::gko::DpcppExecutor>>:: \
value) { \
::gko::kernels::dpcpp::_kernel( \
std::static_pointer_cast<const ::gko::DpcppExecutor>( \
std::dynamic_pointer_cast<const ::gko::DpcppExecutor>( \
exec), \
std::forward<Args>(args)...); \
return; \
default: \
} else { \
GKO_NOT_IMPLEMENTED; \
} \
}); \
Expand Down

0 comments on commit 8a2221c

Please sign in to comment.