From 4622d0b45c35801d4375848a1d25fcd59793dcc1 Mon Sep 17 00:00:00 2001 From: cliffburdick Date: Sun, 16 Oct 2022 19:32:24 -0700 Subject: [PATCH] Fixing batched half precision complex GEMM --- include/matx/transforms/matmul.h | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/include/matx/transforms/matmul.h b/include/matx/transforms/matmul.h index 57085ada..198a98b5 100644 --- a/include/matx/transforms/matmul.h +++ b/include/matx/transforms/matmul.h @@ -343,6 +343,10 @@ class matxMatMulHandle_t { cublasLtMatrixLayoutDestroy(Adesc); cublasLtMatmulDescDestroy(operationDesc); } + + matxFree(a_hp); + matxFree(b_hp); + matxFree(c_hp); } /** @@ -412,6 +416,9 @@ class matxMatMulHandle_t { cublasLtMatrixLayout_t BtransformDesc = nullptr; cublasLtMatrixLayout_t CtransformDesc = nullptr; cublasLtMatmulHeuristicResult_t heuristicResult = {}; + void *c_hp = nullptr; // Make these void since they only work on complex types + void *a_hp = nullptr; + void *b_hp = nullptr; size_t workspaceSize = 1 << 25UL; // 16MB buffer suggested by cuBLAS team void *workspace = nullptr; detail::MatMulParams_t params_; @@ -634,21 +641,24 @@ class matxMatMulHandle_t { auto a_shape = a.Shape(); *(a_shape.begin() + a.Rank() - 2) = a.Size(a.Rank() - 2) * 2; - auto a_planar = make_tensor(a_shape, MATX_ASYNC_DEVICE_MEMORY, stream); + matxAlloc(&a_hp, a.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); + auto a_planar = make_tensor(reinterpret_cast(a_hp), a_shape, false); auto b_shape = b.Shape(); *(b_shape.begin() + b.Rank() - 2) = b.Size(b.Rank() - 2) * 2; - auto b_planar = make_tensor(b_shape, MATX_ASYNC_DEVICE_MEMORY, stream); + matxAlloc(&b_hp, b.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); + auto b_planar = make_tensor(reinterpret_cast(b_hp), b_shape, false); auto c_shape = c.Shape(); *(c_shape.begin() + c.Rank() - 2) = c.Size(c.Rank() - 2) * 2; - auto c_planar = make_tensor(c_shape, MATX_ASYNC_DEVICE_MEMORY, stream); + matxAlloc(&c_hp, c.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream); + auto c_planar = make_tensor(reinterpret_cast(c_hp), c_shape, false); // Convert A/B to planar layout (a_planar = planar(a)).run(stream); (b_planar = planar(b)).run(stream); - // update poitners to planar data. + // update pointers to planar data. // must use Reset because types for planar are different a_adj.Reset(reinterpret_cast(a_planar.Data())); b_adj.Reset(reinterpret_cast(b_planar.Data())); @@ -701,11 +711,11 @@ class matxMatMulHandle_t { } else { for (size_t iter = 0; iter < total_iter; iter++) { + // Get pointers into A/B/C for this round auto ap = std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, idx); auto bp = std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, idx); auto cp = std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, idx); - auto res = cublasLtMatmul( ltHandle, operationDesc, &salpha, (void *)ap, Adesc, (void *)bp, Bdesc, &sbeta,