diff --git a/include/matx/transforms/matmul.h b/include/matx/transforms/matmul.h index 31999aff8..57085ada0 100644 --- a/include/matx/transforms/matmul.h +++ b/include/matx/transforms/matmul.h @@ -705,15 +705,17 @@ class matxMatMulHandle_t { auto ap = std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, idx); auto bp = std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, idx); auto cp = std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, idx); + auto res = cublasLtMatmul( ltHandle, operationDesc, &salpha, (void *)ap, - Adesc, (void *)&bp, Bdesc, &sbeta, - (void *)&cp, Cdesc, (void *)&cp, + Adesc, (void *)bp, Bdesc, &sbeta, + (void *)cp, Cdesc, (void *)cp, Cdesc, &heuristicResult.algo, workspace, workspaceSize, stream); + MATX_ASSERT(res == CUBLAS_STATUS_SUCCESS, matxMatMulError); - // Update all but the last 2 indices + // Update all but the last 3 indices UpdateIndices(a_adj, idx, 3); } } diff --git a/test/00_transform/MatMul.cu b/test/00_transform/MatMul.cu index 855b6038d..e4304ba9e 100644 --- a/test/00_transform/MatMul.cu +++ b/test/00_transform/MatMul.cu @@ -255,7 +255,32 @@ TYPED_TEST(MatMulTestFloatTypes, MediumRectBatched) tensor_t c{{batches, m, n}}; this->pb->template InitAndRunTVGenerator( - "00_transforms", "matmul_operators", "run", {m, k, n, batches}); + "00_transforms", "matmul_operators", "run", {batches, m, k, n}); + + this->pb->NumpyToTensorView(a, "a"); + this->pb->NumpyToTensorView(b, "b"); + + matmul(c, a, b); + + MATX_TEST_ASSERT_COMPARE(this->pb, c, "c", this->thresh); + + MATX_EXIT_HANDLER(); +} + +TYPED_TEST(MatMulTestFloatTypes, MediumRectBatched4D) +{ + MATX_ENTER_HANDLER(); + // constexpr index_t batches = 5; + // constexpr index_t m = 128; + // constexpr index_t k = 256; + // constexpr index_t n = 512; + + auto a = make_tensor({5, 5, 128, 256}); + auto b = make_tensor({5, 5, 256, 512}); + auto c = make_tensor({5, 5, 128, 512}); + + this->pb->template 
InitAndRunTVGenerator( + "00_transforms", "matmul_operators", "run", {5, 5, 128, 256, 512}); this->pb->NumpyToTensorView(a, "a"); this->pb->NumpyToTensorView(b, "b"); diff --git a/test/test_vectors/generators/00_transforms.py b/test/test_vectors/generators/00_transforms.py index 74a793666..7558c817a 100755 --- a/test/test_vectors/generators/00_transforms.py +++ b/test/test_vectors/generators/00_transforms.py @@ -53,18 +53,17 @@ def conv2d(self): class matmul_operators: def __init__(self, dtype: str, size: List[int]): np.random.seed(1234) - batches = 1 if len(size) == 3 else size[-1] self.size = size self.dtype = dtype - if batches == 1: + if len(size) == 3: self.res = { - 'a': matx_common.randn_ndarray((size[0], size[1]), dtype), - 'b': matx_common.randn_ndarray((size[1], size[2]), dtype) + 'a': matx_common.randn_ndarray((size[-3], size[-2]), dtype), + 'b': matx_common.randn_ndarray((size[-2], size[-1]), dtype) } else: self.res = { - 'a': matx_common.randn_ndarray((batches, size[0], size[1]), dtype), - 'b': matx_common.randn_ndarray((batches, size[1], size[2]), dtype) + 'a': matx_common.randn_ndarray((*size[:-3], size[-3], size[-2]), dtype), + 'b': matx_common.randn_ndarray((*size[:-3], size[-2], size[-1]), dtype) } def run(self) -> Dict[str, np.ndarray]: