diff --git a/include/matx/transforms/matmul.h b/include/matx/transforms/matmul.h index 3a5c40339..2bcb15f10 100644 --- a/include/matx/transforms/matmul.h +++ b/include/matx/transforms/matmul.h @@ -77,6 +77,45 @@ union MatMulScaleType_t { double cf64[2]; }; +/* Compile-time predicate (constexpr bool) reporting whether a combination of GEMM operand types is accepted for provider PROV. Only PROVIDER_TYPE_CUBLASLT is actually checked; every other provider yields true. NOTE(review): the template argument lists look stripped in this copy of the patch, so which types each std::is_same_v compares cannot be confirmed from here. */ template +constexpr bool CompatibleGemmTypes() { + if constexpr (!std::is_same_v && + !std::is_same_v && + !std::is_same_v) { + return false; + } + + if (PROV == PROVIDER_TYPE_CUBLASLT) { + if constexpr (std::is_same_v && + std::is_same_v) { + // List of accepted types when A/B/C match + return std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v> || + std::is_same_v || + std::is_same_v || + std::is_same_v; + + } + // Accumulator type different from A/B + else if constexpr ( std::is_same_v && + !std::is_same_v) { + return (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v); + } /* NOTE(review): if PROV is PROVIDER_TYPE_CUBLASLT and neither constexpr condition above holds, control reaches the end of the function without a return statement; an explicit return false at this point is presumably the intended result - confirm. */ + } + else { + // For now return true for other providers until we support more + return true; + } +} + /** * Parameters needed to execute a GEMM. 
For the most part, these are very * similar to that of a standard GEMM call @@ -834,7 +873,7 @@ class matxMatMulHandle_t { static_cast( params_.ldc)}, // Tensor-ref for destination matrix D (may be // different memory than source C matrix) - {alpha, beta}); // Scalars used in the Epilogue + {static_cast(alpha), static_cast(beta)}); // Scalars used in the Epilogue CutlassGemm gemm_operator; cutlass::Status status = gemm_operator(args, nullptr, stream); @@ -895,7 +934,7 @@ class matxMatMulHandle_t { params_.ldc)}, // Tensor-ref for destination matrix D (may // be different memory than source C matrix) c_adj.Stride(RANK - 3), // Batch Stride C - {alpha, beta}, + {static_cast(alpha), static_cast(beta)}, params_.batch // Batch Dimension ); // Scalars used in the Epilogue @@ -1118,6 +1157,10 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A, auto A_ = as_type(A); auto B_ = as_type(B); + static_assert(detail::CompatibleGemmTypes(), + "Combination of A/B/C types are not supported"); + + // CublasLt does not support operators and certain transpose modes. // Grab a suppported tensor here and copy in if necessary. auto c = getCublasSupportedTensor(C, stream);