Skip to content

Commit

Permalink
Added 4D matmul unit test and fixed batching bug
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffburdick committed Oct 15, 2022
1 parent ce04d98 commit 11720a2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
8 changes: 5 additions & 3 deletions include/matx/transforms/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,15 +705,17 @@ class matxMatMulHandle_t {
auto ap = std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, idx);
auto bp = std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, idx);
auto cp = std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, idx);

auto res = cublasLtMatmul(
ltHandle, operationDesc, &salpha, (void *)ap,
Adesc, (void *)&bp, Bdesc, &sbeta,
(void *)&cp, Cdesc, (void *)&cp,
Adesc, (void *)bp, Bdesc, &sbeta,
(void *)cp, Cdesc, (void *)cp,
Cdesc, &heuristicResult.algo, workspace, workspaceSize,
stream);

MATX_ASSERT(res == CUBLAS_STATUS_SUCCESS, matxMatMulError);

// Update all but the last 2 indices
// Update all but the last 3 indices
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a_adj, idx, 3);
}
}
Expand Down
27 changes: 26 additions & 1 deletion test/00_transform/MatMul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,32 @@ TYPED_TEST(MatMulTestFloatTypes, MediumRectBatched)
tensor_t<TypeParam, 3> c{{batches, m, n}};

this->pb->template InitAndRunTVGenerator<TypeParam>(
"00_transforms", "matmul_operators", "run", {m, k, n, batches});
"00_transforms", "matmul_operators", "run", {batches, m, k, n});

this->pb->NumpyToTensorView(a, "a");
this->pb->NumpyToTensorView(b, "b");

matmul<decltype(c), decltype(a), decltype(b), PROVIDER_TYPE_CUBLASLT>(c, a, b);

MATX_TEST_ASSERT_COMPARE(this->pb, c, "c", this->thresh);

MATX_EXIT_HANDLER();
}

TYPED_TEST(MatMulTestFloatTypes, MediumRectBatched4D)
{
MATX_ENTER_HANDLER();
// constexpr index_t batches = 5;
// constexpr index_t m = 128;
// constexpr index_t k = 256;
// constexpr index_t n = 512;

auto a = make_tensor<TypeParam>({5, 5, 128, 256});
auto b = make_tensor<TypeParam>({5, 5, 256, 512});
auto c = make_tensor<TypeParam>({5, 5, 128, 512});

this->pb->template InitAndRunTVGenerator<TypeParam>(
"00_transforms", "matmul_operators", "run", {5, 5, 128, 256, 512});

this->pb->NumpyToTensorView(a, "a");
this->pb->NumpyToTensorView(b, "b");
Expand Down
11 changes: 5 additions & 6 deletions test/test_vectors/generators/00_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,17 @@ def conv2d(self):
class matmul_operators:
def __init__(self, dtype: str, size: List[int]):
np.random.seed(1234)
batches = 1 if len(size) == 3 else size[-1]
self.size = size
self.dtype = dtype
if batches == 1:
if len(size) == 3:
self.res = {
'a': matx_common.randn_ndarray((size[0], size[1]), dtype),
'b': matx_common.randn_ndarray((size[1], size[2]), dtype)
'a': matx_common.randn_ndarray((size[-3], size[-2]), dtype),
'b': matx_common.randn_ndarray((size[-2], size[-1]), dtype)
}
else:
self.res = {
'a': matx_common.randn_ndarray((batches, size[0], size[1]), dtype),
'b': matx_common.randn_ndarray((batches, size[1], size[2]), dtype)
'a': matx_common.randn_ndarray((*size[:-3], size[-3], size[-2]), dtype),
'b': matx_common.randn_ndarray((*size[:-3], size[-2], size[-1]), dtype)
}

def run(self) -> Dict[str, np.ndarray]:
Expand Down

0 comments on commit 11720a2

Please sign in to comment.