Skip to content

Commit

Permalink
Fixing batched half precision complex GEMM (#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffburdick authored Oct 17, 2022
1 parent 5d9f8e0 commit b0398ca
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions include/matx/transforms/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ class matxMatMulHandle_t {
cublasLtMatrixLayoutDestroy(Adesc);
cublasLtMatmulDescDestroy(operationDesc);
}

matxFree(a_hp);
matxFree(b_hp);
matxFree(c_hp);
}

/**
Expand Down Expand Up @@ -412,6 +416,9 @@ class matxMatMulHandle_t {
// cuBLASLt layout descriptors for the B and C transforms; null until created.
cublasLtMatrixLayout_t BtransformDesc = nullptr;
cublasLtMatrixLayout_t CtransformDesc = nullptr;
// Heuristic result structure from cuBLASLt, value-initialized here.
cublasLtMatmulHeuristicResult_t heuristicResult = {};
// Handle-owned buffers that back the planar copies of A, B and C. They are
// filled in with matxAlloc during the planar conversion and released with
// matxFree when the handle is cleaned up.
void *c_hp = nullptr; // Make these void since they only work on complex types
void *a_hp = nullptr;
void *b_hp = nullptr;
// NOTE(review): 1 << 25 bytes is 32 MiB, which does not match the "16MB"
// stated below -- confirm whether the value or the comment is the intended one.
size_t workspaceSize = 1 << 25UL; // 16MB buffer suggested by cuBLAS team
// Workspace pointer; null until assigned elsewhere (presumably sized by
// workspaceSize -- verify against the allocation site).
void *workspace = nullptr;
// Parameter set associated with this handle (type detail::MatMulParams_t).
detail::MatMulParams_t params_;
Expand Down Expand Up @@ -634,21 +641,24 @@ class matxMatMulHandle_t {

auto a_shape = a.Shape();
*(a_shape.begin() + a.Rank() - 2) = a.Size(a.Rank() - 2) * 2;
auto a_planar = make_tensor<typename T1::value_type>(a_shape, MATX_ASYNC_DEVICE_MEMORY, stream);
matxAlloc(&a_hp, a.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream);
auto a_planar = make_tensor<typename T2::value_type>(reinterpret_cast<typename T2::value_type*>(a_hp), a_shape, false);

auto b_shape = b.Shape();
*(b_shape.begin() + b.Rank() - 2) = b.Size(b.Rank() - 2) * 2;
auto b_planar = make_tensor<typename T2::value_type>(b_shape, MATX_ASYNC_DEVICE_MEMORY, stream);
matxAlloc(&b_hp, b.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream);
auto b_planar = make_tensor<typename T3::value_type>(reinterpret_cast<typename T3::value_type*>(b_hp), b_shape, false);

auto c_shape = c.Shape();
*(c_shape.begin() + c.Rank() - 2) = c.Size(c.Rank() - 2) * 2;
auto c_planar = make_tensor<typename T3::value_type>(c_shape, MATX_ASYNC_DEVICE_MEMORY, stream);
matxAlloc(&c_hp, c.Bytes(), MATX_ASYNC_DEVICE_MEMORY, stream);
auto c_planar = make_tensor<typename T1::value_type>(reinterpret_cast<typename T1::value_type*>(c_hp), c_shape, false);

// Convert A/B to planar layout
(a_planar = planar(a)).run(stream);
(b_planar = planar(b)).run(stream);

// update poitners to planar data.
// update pointers to planar data.
// must use Reset because types for planar are different
a_adj.Reset(reinterpret_cast<T1 *>(a_planar.Data()));
b_adj.Reset(reinterpret_cast<T2 *>(b_planar.Data()));
Expand Down Expand Up @@ -701,11 +711,11 @@ class matxMatMulHandle_t {
}
else {
for (size_t iter = 0; iter < total_iter; iter++) {

// Get pointers into A/B/C for this round
auto ap = std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, idx);
auto bp = std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, idx);
auto cp = std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, idx);

auto res = cublasLtMatmul(
ltHandle, operationDesc, &salpha, (void *)ap,
Adesc, (void *)bp, Bdesc, &sbeta,
Expand Down

0 comments on commit b0398ca

Please sign in to comment.