Enable preshuffled mixed dtype Cutlass Gemm (#3722)
Summary:

WIP to enable the new optimized preshuffled fp8xint4 GEMM.

The example compiles and runs, but the results are problematic: the outputs are either completely incorrect, contain NaNs, or the kernel hits an illegal memory access. I'm not yet sure why.
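For context, the intended end-to-end flow, as wired up by the new F8I4ShuffledGemm bench op below (a sketch only, given the problems above; the driver tensors and shapes are illustrative):

    import torch

    op = F8I4ShuffledGemm()  # bench op added in quantize_ops.py below
    x = torch.randn(1024, 5120, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(8192, 5120, device="cuda", dtype=torch.bfloat16)
    # Rowwise fp8 quantize x; groupwise int4 quantize, pack, and preshuffle w.
    xq, wq, x_scale, w_scale = op.quantize(x, w)
    # Preshuffled mixed-dtype GEMM producing a bf16 output of shape (1024, 8192).
    y = op.compute(xq, wq, x_scale, w_scale)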

Differential Revision: D69955197
jwfromm authored and facebook-github-bot committed Feb 22, 2025
1 parent 3209bb4 commit bbca782
Showing 4 changed files with 371 additions and 0 deletions.
48 changes: 48 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -1121,6 +1121,54 @@ def cuda(self) -> bool:
return True


@register_quantize_op
class F8I4ShuffledGemm(F8I4RowwiseGemm):
def _int4_row_quantize(
self,
x: torch.Tensor,
group_size: int = 128,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
n_bit = 4 # Number of target bits.
to_quant = x.reshape(-1, group_size).to(torch.float)

max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
max_int = 2 ** (n_bit - 1)
min_int = -(2 ** (n_bit - 1))
scales = max_val.clamp(min=1e-6) / max_int

out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)

# Cast to int8 and restore shape.
out = out.to(dtype=torch.int8).reshape(x.shape)

# View scales as rows, groups.
scales = scales.view(x.shape[0], -1)

return out, scales
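    # Worked example of the scheme above (illustrative): with group_size = 4,
    # the row [0.3, -1.2, 0.7, 2.4] has max_val = 2.4, so scale = 2.4 / 8 = 0.3;
    # dividing and rounding gives [1, -4, 2, 8], which clamps to [1, -4, 2, 7].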

def quantize(self, x, w):
# Quantize both input tensors.
xq, x_scale = quantize_fp8_row(x)
wq, w_scale = self._int4_row_quantize(w)
# Pack int4 values together.
wq = self._pack_int4(wq)
# Shuffle weights and scales for faster compute.
wq, w_scale = torch.ops.fbgemm.preshuffle_i4(wq, w_scale)
return xq, wq, x_scale, w_scale

def compute(self, xq, wq, x_scale, w_scale):
out = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, w_scale)
return out

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale = self.quantize(x, w)
return self.compute(xq, wq, x_scale, w_scale)

@property
def name(self) -> str:
return "cutlass_f8i4_preshuffle"


@register_quantize_op
class BF16I4RowwiseGemm(F8I4RowwiseGemm):
"""
@@ -0,0 +1,233 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"

#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass/util/packed_stride.hpp"

namespace fbgemm_gpu {

at::Tensor f8i4bf16_shuffled(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale) {
// Get shape information from input tensors.
int M = XQ.size(0);
int K = XQ.size(1);
int N = WQ.size(0);
// Make sure w_scale is in proper format.
TORCH_CHECK(
w_scale.size(2) == 8,
"Weights and scales must be prepacked with preshuffle_i4.");
int num_groups = w_scale.size(1);
int group_size = K / num_groups;
// Allocate output.
at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

// Define input types.
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
constexpr int TileShapeK = 128 * 8 / cute::sizeof_bits<MmaType>::value;
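  // With an 8-bit MmaType this evaluates to 128, i.e. a 128-byte K extent per tile.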

// A Matrix configuration.
using ElementA = MmaType;
using LayoutA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

// B Matrix Configuration.
using ElementB = QuantType;
using LayoutB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // We need to manually swap and transpose the inputs; it's unclear how
  // necessary this is.
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutB>::type;

using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

// Define layout for shuffled weight tensor.
using LayoutAtomQuant =
decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(
LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));

using ElementScale = MmaType;

// Output Matrix configuration.
using ElementC = cutlass::bfloat16_t;
using LayoutC = cutlass::layout::RowMajor;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

// Core kernel configurations
using ElementAccumulator = float;
using ElementCompute = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// TODO tune these shapes.
using TileShape = cute::Shape<cute::_128, cute::_128, cute::Int<TileShapeK>>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
// TODO Should we use fast accum here?
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
// Might be the only epilogue schedule that supports swap + transpose.
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

// Define EVT for rowwise scaling.
using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementAccumulator,
ElementAccumulator,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
ElementC, // First stage output type.
ElementAccumulator, // First stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;

using EpilogueEVT =
cutlass::epilogue::fusion::Sm90EVT<Compute0, XScale, Accum>;
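  // Together these compute out = x_scale * acc, applying the rowwise
  // activation scale to the fp32 accumulator.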

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
EpilogueTileType,
ElementAccumulator,
ElementAccumulator,
ElementC,
typename cutlass::layout::LayoutTranspose<LayoutC>::type,
AlignmentC,
ElementC,
typename cutlass::layout::LayoutTranspose<LayoutC>::type,
AlignmentC,
EpilogueSchedule,
EpilogueEVT>::CollectiveOp;

using CollectiveMainloopShuffled =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
LayoutB_Reordered,
AlignmentB,
ElementA,
LayoutA_Transpose,
AlignmentA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>,
CollectiveMainloopShuffled,
CollectiveEpilogue>;

using GemmShuffled =
cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

using StrideC = typename GemmKernelShuffled::StrideC;

/// Initialization
auto shape_B = cute::make_shape(N, K, 1);
StrideA stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
StrideC stride_C =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
LayoutB_Reordered layout_B_reordered =
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
StrideS stride_S = cutlass::make_cute_packed_stride(
StrideS{}, cute::make_shape(N, num_groups, 1));

// Define Gemm arguments.
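  // Because of the swap + transpose above, the problem is posed as (N, M, K):
  // WQ is passed as the A operand and the output is produced transposed.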
typename GemmShuffled::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, 1},
{reinterpret_cast<ElementB*>(WQ.data_ptr()),
layout_B_reordered,
reinterpret_cast<ElementA*>(XQ.data_ptr()),
stride_A,
reinterpret_cast<cutlass::Array<ElementScale, 8>*>(w_scale.data_ptr()),
stride_S,
group_size},
{{},
reinterpret_cast<ElementC*>(Y.data_ptr()),
stride_C,
reinterpret_cast<ElementC*>(Y.data_ptr()),
stride_C}};

arguments.epilogue.thread = {
{reinterpret_cast<ElementAccumulator*>(x_scale.data_ptr())}, // x_scale
{}, // Accumulator
{}, // Multiplies
};

// Launch the workload.
GemmShuffled gemm;

  // Using the arguments, query for the extra workspace required by the GEMM.
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}

status = gemm(at::cuda::getCurrentCUDAStream());

if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run: ") +
        cutlass::cutlassGetStatusString(status));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();

return Y;
}

} // namespace fbgemm_gpu
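For reference, the computation this kernel is intended to perform, with the groupwise weight dequantization fused into the mainloop and the rowwise x_scale multiply fused into the EVT epilogue, is equivalent to the following naive PyTorch (illustrative only; the function name is hypothetical and wq holds unpacked int4 values):

    import torch

    def f8i4_reference(xq, wq, x_scale, w_scale, group_size):
        # xq: (M, K) fp8; x_scale: (M,) fp32; wq: (N, K) unpacked int4 values
        # (stored as int8); w_scale: (N, K // group_size) groupwise scales.
        w = wq.to(torch.float32) * w_scale.to(torch.float32).repeat_interleave(
            group_size, dim=1)
        # fp32 accumulation, then rowwise activation scaling, then bf16 cast.
        y = (xq.to(torch.float32) @ w.t()) * x_scale[:, None]
        return y.to(torch.bfloat16)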
@@ -0,0 +1,75 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cute/layout.hpp"
#include "cutlass/detail/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/mixed_dtype_utils.hpp"

namespace fbgemm_gpu {

std::tuple<at::Tensor, at::Tensor> preshuffle_i4(
at::Tensor WQ,
at::Tensor w_scale) {
  // Check that w_scale is the proper type; if not, downcast it.
if (w_scale.dtype() != at::kFloat8_e4m3fn) {
TORCH_WARN(
"Weight scale must be FP8 for preshuffled GEMM. Performing downcasting.");
w_scale = w_scale.to(WQ.options().dtype(at::kFloat8_e4m3fn));
}
// Start by allocating space for shuffled tensors.
at::Tensor WQ_shuffled = at::empty_like(WQ);
// Packed scale contains 8 lookup values for each original scale element.
at::Tensor w_scale_packed =
at::empty({w_scale.size(0), w_scale.size(1), 8}, w_scale.options());
  // WQ packs two int4 values into each int8 element, so the logical int4
  // count is twice numel().
size_t WQ_size = 2 * WQ.numel();
// Encode weights to enable efficient lookup.
cutlass::unified_encode_int4b(
reinterpret_cast<cutlass::int4b_t*>(WQ.data_ptr()),
reinterpret_cast<cutlass::int4b_t*>(WQ_shuffled.data_ptr()),
WQ_size);

size_t w_scale_size = w_scale.numel();
cutlass::pack_scale_fp8(
reinterpret_cast<cutlass::float_e4m3_t*>(w_scale.data_ptr()),
reinterpret_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>(
w_scale_packed.data_ptr()),
w_scale_size);

// Next we need to shuffle B. To do this, we define a few helper objects.
const int N = WQ.size(0);
const int K = 2 * WQ.size(1);
auto shape_B = cute::make_shape(N, K, 1);
using LayoutB = cutlass::layout::ColumnMajor;
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<
cutlass::float_e4m3_t>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(
LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));
  // Initialize B's stride as a packed column-major stride; a bare StrideB{}
  // would leave its dynamic modes zero and yield a degenerate layout.
  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
  auto layout_B = make_layout(shape_B, stride_B);
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

// Now we're ready to reorder the tensor into proper layout.
cutlass::reorder_tensor(
reinterpret_cast<cutlass::int4b_t*>(WQ_shuffled.data_ptr()),
layout_B,
layout_B_reordered);

// Tensors should now be preshuffled and ready for use.
return {WQ_shuffled, w_scale_packed};
}

} // namespace fbgemm_gpu
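A minimal usage sketch of the shape contract (assuming the op is exposed via torch.ops.fbgemm as in quantize_ops.py above; sizes are illustrative):

    import torch

    N, K, group_size = 8192, 5120, 128
    wq = torch.randint(-128, 128, (N, K // 2), dtype=torch.int8, device="cuda")  # packed int4
    w_scale = torch.randn(N, K // group_size, device="cuda").to(torch.float8_e4m3fn)
    wq_shuf, w_scale_packed = torch.ops.fbgemm.preshuffle_i4(wq, w_scale)
    # wq_shuf: (N, K // 2) int8; w_scale_packed: (N, K // group_size, 8) fp8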