Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable preshuffled mixed dtype Cutlass Gemm #3722

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,54 @@ def cuda(self) -> bool:
return True


@register_quantize_op
class F8I4ShuffledGemm(F8I4RowwiseGemm):
def _int4_row_quantize(
self,
x: torch.Tensor,
group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_bit = 4 # Number of target bits.
to_quant = x.reshape(-1, group_size).to(torch.float)

max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
max_int = 2 ** (n_bit - 1)
min_int = -(2 ** (n_bit - 1))
scales = max_val.clamp(min=1e-6) / max_int

out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)

# Cast to int8 and restore shape.
out = out.to(dtype=torch.int8).reshape(x.shape)

# View scales as rows, groups.
scales = scales.view(x.shape[0], -1)

return out, scales

def quantize(self, x, w):
# Quantize both input tensors.
xq, x_scale = quantize_fp8_row(x)
wq, w_scale = self._int4_row_quantize(w)
# Pack int4 values together.
wq = self._pack_int4(wq)
# Shuffle weights and scales for faster compute.
wq, w_scale = torch.ops.fbgemm.preshuffle_i4(wq, w_scale)
return xq, wq, x_scale, w_scale

def compute(self, xq, wq, x_scale, w_scale):
out = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, w_scale)
return out

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale = self.quantize(x, w)
return self.compute(xq, wq, x_scale, w_scale)

@property
def name(self) -> str:
return "cutlass_f8i4_preshuffle"


@register_quantize_op
class BF16I4RowwiseGemm(F8I4RowwiseGemm):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ at::Tensor bf16i4bf16_rowwise_impl(
// threadblocks in a
// cluster
using CooperativeSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
cutlass::gemm::KernelTmaWarpSpecializedCooperative;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using CooperativeEpilogueSchedule =
cutlass::epilogue::TmaWarpSpecializedCooperative;
using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,8 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
// threadblocks in a
// cluster
using CooperativeSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
cutlass::gemm::KernelTmaWarpSpecializedCooperative;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using CooperativeEpilogueSchedule =
cutlass::epilogue::TmaWarpSpecializedCooperative;
using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,17 @@ struct GroupedGemmConfigs {
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;

// Implement rowwise scaling epilogue.
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcastPtrArray<
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue*,
ElementComputeEpilogue,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

using WScale = cutlass::epilogue::fusion::Sm90RowBroadcastPtrArray<
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue*,
ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ at::Tensor f8i4bf16_rowwise_impl(
// threadblocks in a
// cluster
using CooperativeSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
cutlass::gemm::KernelTmaWarpSpecializedCooperative;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using CooperativeEpilogueSchedule =
cutlass::epilogue::TmaWarpSpecializedCooperative;
using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
Expand Down Expand Up @@ -260,7 +259,7 @@ at::Tensor dispatch_f8i4bf16_rowwise_kernel(
return f8i4bf16_rowwise_impl<
128,
256,
64,
128,
2,
1,
1,
Expand All @@ -271,7 +270,7 @@ at::Tensor dispatch_f8i4bf16_rowwise_kernel(
return f8i4bf16_rowwise_impl<
128,
256,
64,
128,
2,
1,
1,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"

#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass/util/packed_stride.hpp"

namespace fbgemm_gpu {

at::Tensor f8i4bf16_shuffled(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale) {
// Get shape information from input tensors.
int M = XQ.size(0);
int K = XQ.size(1);
int N = WQ.size(0);
// Make sure w_scale is in proper format.
TORCH_CHECK(
w_scale.size(2) == 8,
"Weights and scales must be prepacked with preshuffle_i4.");
int num_groups = w_scale.size(1);
int group_size = K / num_groups;
// Allocate output.
at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

// Define input types.
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / cute::sizeof_bits<MmaType>::value;

// A Matrix configuration.
using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

// B Matrix Configuration.
using ElementB = QuantType;
using LayoutB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

// We need to manually swap and transpose inputs. Unclear how required this is
// though.
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutB>::type;

using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

// Define layout for shuffled weight tensor.
using LayoutAtomQuant =
decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(
LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));

using ElementScale = MmaType;

// Output Matrix configuration.
using ElementC = cutlass::bfloat16_t;
using LayoutC = cutlass::layout::RowMajor;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

// Core kernel configurations
using ElementAccumulator = float;
using ElementCompute = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// TODO tune these shapes.
using TileShape = cute::Shape<cute::_128, cute::_128, cute::Int<TileShapeK>>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
// TODO Should we use fast accum here?
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
// Might be the only epilogue schedule that supports swap + transpose.
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

// Define EVT for rowwise scaling.
using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementAccumulator,
ElementAccumulator,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
ElementC, // First stage output type.
ElementAccumulator, // First stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;

using EpilogueEVT =
cutlass::epilogue::fusion::Sm90EVT<Compute0, XScale, Accum>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
EpilogueTileType,
ElementAccumulator,
ElementAccumulator,
ElementC,
typename cutlass::layout::LayoutTranspose<LayoutC>::type,
AlignmentC,
ElementC,
typename cutlass::layout::LayoutTranspose<LayoutC>::type,
AlignmentC,
EpilogueSchedule,
EpilogueEVT>::CollectiveOp;

using CollectiveMainloopShuffled =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
LayoutB_Reordered,
AlignmentB,
ElementA,
LayoutA_Transpose,
AlignmentA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>,
CollectiveMainloopShuffled,
CollectiveEpilogue>;

using GemmShuffled =
cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

using StrideC = typename GemmKernelShuffled::StrideC;

/// Initialization
auto shape_B = cute::make_shape(N, K, 1);
StrideA stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
StrideC stride_C =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
LayoutB_Reordered layout_B_reordered =
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
StrideS stride_S = cutlass::make_cute_packed_stride(
StrideS{}, cute::make_shape(N, num_groups, 1));

// Define Gemm arguments.
typename GemmShuffled::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, 1},
{reinterpret_cast<ElementB*>(WQ.data_ptr()),
layout_B_reordered,
reinterpret_cast<ElementA*>(XQ.data_ptr()),
stride_A,
reinterpret_cast<cutlass::Array<ElementScale, 8>*>(w_scale.data_ptr()),
stride_S,
group_size},
{{},
reinterpret_cast<ElementC*>(Y.data_ptr()),
stride_C,
reinterpret_cast<ElementC*>(Y.data_ptr()),
stride_C}};

arguments.epilogue.thread = {
{reinterpret_cast<ElementAccumulator*>(x_scale.data_ptr())}, // x_scale
{}, // Accumulator
{}, // Multiplies
};

// Launch the workload.
GemmShuffled gemm;

// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}

status = gemm(at::cuda::getCurrentCUDAStream());

if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(
std::string("cutlass cannot run") +
cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

} // namespace fbgemm_gpu
Loading
Loading