-
Notifications
You must be signed in to change notification settings - Fork 545
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable preshuffled mixed dtype Cutlass Gemm (#3722)
Summary: WIP to enable new optimized preshuffled fp8xint4 gemm. While the example compiles and runs, it runs into a variety of problems. The outputs are either completely incorrect, contain NaNs, or the kernel hits an Illegal Memory Access. I'm not yet sure why. Differential Revision: D69955197
- Loading branch information
1 parent
3209bb4
commit bbca782
Showing
4 changed files
with
371 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
233 changes: 233 additions & 0 deletions
233
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
|
||
#include "cutlass/cutlass.h" | ||
|
||
#include "cute/tensor.hpp" | ||
#include "cutlass/epilogue/collective/collective_builder.hpp" | ||
#include "cutlass/epilogue/collective/default_epilogue.hpp" | ||
#include "cutlass/epilogue/thread/linear_combination.h" | ||
#include "cutlass/gemm/collective/collective_builder.hpp" | ||
#include "cutlass/gemm/device/gemm_universal_adapter.h" | ||
#include "cutlass/gemm/dispatch_policy.hpp" | ||
#include "cutlass/gemm/kernel/gemm_universal.hpp" | ||
#include "cutlass/tensor_ref.h" | ||
|
||
#include "cutlass/util/mixed_dtype_utils.hpp" | ||
#include "cutlass/util/packed_stride.hpp" | ||
|
||
namespace fbgemm_gpu { | ||
|
||
// GEMM of FP8 (e4m3) activations against preshuffled packed-int4 weights with
// groupwise FP8 weight scales, producing a BF16 output of shape [M, N].
// Weights and scales must first be prepacked with preshuffle_i4 (enforced
// below via the w_scale.size(2) == 8 check).
//
// NOTE(review): per the commit description this kernel is WIP — outputs may be
// incorrect / NaN or the kernel may hit an illegal memory access.
//
// Args:
//   XQ:      [M, K] FP8 e4m3 activation tensor.
//   WQ:      [N, ...] packed int4 weight tensor, preshuffled (second dim not
//            validated here — presumably K/2 packed bytes; confirm in caller).
//   x_scale: per-row activation scales, read as float in the epilogue.
//   w_scale: [N, num_groups, 8] packed FP8 groupwise weight scales.
// Returns: BF16 tensor Y = scale(XQ @ WQ^T) of shape [M, N].
// Throws: std::runtime_error if CUTLASS cannot implement/initialize/run.
at::Tensor f8i4bf16_shuffled(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale) {
  // Get shape information from input tensors.
  int M = XQ.size(0);
  int K = XQ.size(1);
  int N = WQ.size(0);
  // Make sure w_scale is in proper format: the mainloop consumes scales as
  // Array<fp8, 8> lookup groups, so the innermost dim must be exactly 8.
  TORCH_CHECK(
      w_scale.size(2) == 8,
      "Weights and scales must be prepacked with preshuffle_i4.");
  int num_groups = w_scale.size(1);
  // K rows of weights share one scale per group.
  int group_size = K / num_groups;
  // Allocate output.
  at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

  // Define input types. The MMA runs in FP8; int4 weights are dequantized to
  // MmaType in the mainloop.
  using MmaType = cutlass::float_e4m3_t;
  using QuantType = cutlass::int4b_t;
  // K extent of the tile: 128 bytes' worth of MmaType elements.
  constexpr int TileShapeK = 128 * 8 / cute::sizeof_bits<MmaType>::value;

  // A Matrix configuration.
  using ElementA = MmaType;
  using LayoutA = cutlass::layout::RowMajor;
  // Alignments are in elements: 128 bits per access.
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B Matrix Configuration.
  using ElementB = QuantType;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // We need to manually swap and transpose inputs (B is fed as the "A"
  // operand below so the narrow dtype sits on the wide side of the MMA).
  // Unclear how required this is though.
  using LayoutA_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  // NOTE(review): LayoutB_Transpose is declared but never used below —
  // the mainloop takes LayoutB_Reordered instead. Possibly dead code.
  using LayoutB_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutB>::type;

  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

  // Define layout for shuffled weight tensor. Must match the reordering
  // applied on the host side by preshuffle_i4 (same atom, same tiling).
  using LayoutAtomQuant =
      decltype(cutlass::compute_memory_reordering_atom<MmaType>());
  using LayoutB_Reordered = decltype(cute::tile_to_shape(
      LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));

  // Weight scales are stored as FP8, same as the MMA input type.
  using ElementScale = MmaType;

  // Output Matrix configuration.
  using ElementC = cutlass::bfloat16_t;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  // Core kernel configurations
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  // TODO tune these shapes.
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::Int<TileShapeK>>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  // TODO Should we use fast accum here?
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
  // Might be the only epilogue schedule that supports swap + transpose.
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  // Define EVT for rowwise scaling: Y = x_scale * accum, fused in epilogue.
  // NOTE(review): with operands swapped the logical output is N x M, so
  // whether Sm90RowBroadcast with this stride actually broadcasts x_scale
  // along the intended (M) dimension needs confirming — a wrong broadcast
  // axis here would explain incorrect outputs.
  using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementAccumulator,
      ElementAccumulator,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  // Fetches the raw accumulator value.
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  // Elementwise multiply of scale and accumulator, rounding to BF16.
  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementC, // First stage output type.
      ElementAccumulator, // First stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;

  // Visitor tree: Compute0(XScale, Accum).
  using EpilogueEVT =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, XScale, Accum>;

  // Output layout is transposed to undo the operand swap.
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90,
          cutlass::arch::OpClassTensorOp,
          TileShape,
          ClusterShape,
          EpilogueTileType,
          ElementAccumulator,
          ElementAccumulator,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          EpilogueSchedule,
          EpilogueEVT>::CollectiveOp;

  // Mainloop with swapped operands: quantized B (plus packed scale groups of
  // 8) in the "A" slot, FP8 activations in the "B" slot.
  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
          LayoutB_Reordered,
          AlignmentB,
          ElementA,
          LayoutA_Transpose,
          AlignmentA,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          // Carve epilogue shared-storage out of the stage budget.
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>,
      CollectiveMainloopShuffled,
      CollectiveEpilogue>;

  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelShuffled::StrideC;

  /// Initialization of strides/layouts. Note the problem is posed as
  /// (N, M, K) because of the operand swap.
  auto shape_B = cute::make_shape(N, K, 1);
  StrideA stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
  StrideC stride_C =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
  // Runtime instance of the reordered weight layout (must match packing).
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
  using StrideS = typename CollectiveMainloopShuffled::StrideScale;
  // One scale group per (N, num_groups) entry.
  StrideS stride_S = cutlass::make_cute_packed_stride(
      StrideS{}, cute::make_shape(N, num_groups, 1));

  // Define Gemm arguments. Mainloop pointers follow the swapped order:
  // (B, layout_B, A, stride_A, scales, stride_S, group_size).
  typename GemmShuffled::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {N, M, K, 1},
      {reinterpret_cast<ElementB*>(WQ.data_ptr()),
       layout_B_reordered,
       reinterpret_cast<ElementA*>(XQ.data_ptr()),
       stride_A,
       reinterpret_cast<cutlass::Array<ElementScale, 8>*>(w_scale.data_ptr()),
       stride_S,
       group_size},
      // Epilogue: C and D both alias Y; C is unused (EVT ignores the source).
      {{},
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C,
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C}};

  // Patch the EVT leaf arguments in tree order after construction.
  arguments.epilogue.thread = {
      {reinterpret_cast<ElementAccumulator*>(x_scale.data_ptr())}, // x_scale
      {}, // Accumulator
      {}, // Multiplies
  };

  // Launch the workload.
  GemmShuffled gemm;

  // Using the arguments, query for extra workspace required for matrix
  // multiplication computation
  size_t workspace_size = GemmShuffled::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check the problem size is supported or not
  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  // Run on the current torch CUDA stream so ordering with surrounding ops
  // is preserved.
  status = gemm(at::cuda::getCurrentCUDAStream());

  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run") +
        cutlass::cutlassGetStatusString(status));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return Y;
}
|
||
} // namespace fbgemm_gpu |
75 changes: 75 additions & 0 deletions
75
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/mixed_dtype_utils.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
|
||
#include "cute/layout.hpp" | ||
#include "cutlass/detail/layout.hpp" | ||
#include "cutlass/layout/matrix.h" | ||
#include "cutlass/util/mixed_dtype_utils.hpp" | ||
|
||
namespace fbgemm_gpu { | ||
|
||
// Prepacks int4 weights and their groupwise scales for the preshuffled
// FP8 x INT4 CUTLASS GEMM (f8i4bf16_shuffled).
//
// Two transformations are applied:
//   1. Weights are re-encoded (unified_encode_int4b) and reordered on device
//      into the memory layout the SM90 mixed-dtype mainloop expects. The
//      layout atom here must stay in sync with the one used by the GEMM.
//   2. Each FP8 scale element is expanded into an 8-entry lookup array
//      (pack_scale_fp8), enabling fast int4 -> fp8 dequantization.
//
// Args:
//   WQ:      [N, K/2] int8 tensor holding two packed int4 values per byte.
//   w_scale: [N, num_groups] groupwise weight scales; downcast to FP8 e4m3
//            with a warning if provided in another dtype.
// Returns: tuple of (shuffled weights [N, K/2], packed scales
//          [N, num_groups, 8]) ready for f8i4bf16_shuffled.
std::tuple<at::Tensor, at::Tensor> preshuffle_i4(
    at::Tensor WQ,
    at::Tensor w_scale) {
  // Check that w_scale is proper type. If not, quantize it.
  if (w_scale.dtype() != at::kFloat8_e4m3fn) {
    TORCH_WARN(
        "Weight scale must be FP8 for preshuffled GEMM. Performing downcasting.");
    w_scale = w_scale.to(WQ.options().dtype(at::kFloat8_e4m3fn));
  }
  // Start by allocating space for shuffled tensors.
  at::Tensor WQ_shuffled = at::empty_like(WQ);
  // Packed scale contains 8 lookup values for each original scale element.
  at::Tensor w_scale_packed =
      at::empty({w_scale.size(0), w_scale.size(1), 8}, w_scale.options());
  // WQ has two int4 values packed into each int8 dtype, so the logical
  // element count is twice numel().
  size_t WQ_size = 2 * WQ.numel();
  // Encode weights to enable efficient lookup-based dequantization.
  cutlass::unified_encode_int4b(
      reinterpret_cast<cutlass::int4b_t*>(WQ.data_ptr()),
      reinterpret_cast<cutlass::int4b_t*>(WQ_shuffled.data_ptr()),
      WQ_size);

  // Expand each scale into its 8-value int4 lookup table.
  size_t w_scale_size = w_scale.numel();
  cutlass::pack_scale_fp8(
      reinterpret_cast<cutlass::float_e4m3_t*>(w_scale.data_ptr()),
      reinterpret_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>(
          w_scale_packed.data_ptr()),
      w_scale_size);

  // Next we need to shuffle B. To do this, we define a few helper objects.
  const int N = WQ.size(0);
  // Unpacked K: two int4 elements per stored byte.
  const int K = 2 * WQ.size(1);
  auto shape_B = cute::make_shape(N, K, 1);
  using LayoutB = cutlass::layout::ColumnMajor;
  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
  // The reordering atom is computed for the MMA dtype (fp8); this must match
  // the LayoutAtomQuant used by the consuming GEMM kernel.
  using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<
                                   cutlass::float_e4m3_t>());
  using LayoutB_Reordered = decltype(cute::tile_to_shape(
      LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));
  StrideB stride_B;
  // Source (column-major) and destination (tiled/reordered) layouts.
  auto layout_B = make_layout(shape_B, stride_B);
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

  // Now we're ready to reorder the tensor into proper layout.
  cutlass::reorder_tensor(
      reinterpret_cast<cutlass::int4b_t*>(WQ_shuffled.data_ptr()),
      layout_B,
      layout_B_reordered);

  // Tensors should now be preshuffled and ready for use.
  return {WQ_shuffled, w_scale_packed};
}
|
||
} // namespace fbgemm_gpu |
Oops, something went wrong.