Skip to content

Commit

Permalink
Add: WGMMA synchronization
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Feb 10, 2025
1 parent 733cbac commit 0207843
Show file tree
Hide file tree
Showing 2 changed files with 516 additions and 349 deletions.
354 changes: 5 additions & 349 deletions less_slow_sm80.ptx
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/**
* less_slow_sm90a.ptx
* less_slow_sm80.ptx
*
* Micro-kernels for building a performance-first mindset for CUDA-capable
* GPUs using Parallel Thread eXecution (PTX) Intermediate Representation (IR)
* for Hopper-generation Nvidia GPUs and newer.
* for Ampere-generation Nvidia GPUs with Warp-level MMA (WMMA).
*
* ? You should start at `less_slow.cu` before reading this file.
* ? You should start at `less_slow_sm70.ptx` before reading this file.
Expand All @@ -13,7 +13,7 @@
* You can validate this file by asking the Nvidia PTX Assembler to compile it
* to `.cubin` for some target architecture:
*
* $ ptxas -o less_slow_from_ptx.cubin -arch=sm_90a less_slow_sm90a.ptx
* $ ptxas -o less_slow_from_ptx.cubin -arch=sm_80 less_slow_sm80.ptx
* $ cuobjdump -sass less_slow_from_ptx.cubin | grep -i mma
*
* Assuming how aggressively NVCC unrolls loops and the number of kernels in
Expand All @@ -24,8 +24,8 @@
* $ sed -r 's/^[[:space:]]+//; s/[[:space:]]+$//' | \
* $ sort -u
*/
.version 8.0 // PTX version 8.0 for Hopper GPUs
.target sm_90a // Target architecture (SM_90a - Hopper GPUs)
.version 7.0 // PTX version 7.0 for Ampere GPUs
.target sm_80 // Target architecture (SM_80 - Ampere GPUs)
.address_size 64 // 64-bit addressing

/**
Expand Down Expand Up @@ -283,347 +283,3 @@ loop_exit:
st.global.volatile.f32 [dummy_sink_f32+12], accum3;
ret;
}

/**
* The instruction syntax for Warp-Group asynchronous instructions is very
* different, as at least one of the operand matrices has to be in shared
* memory (not registers). It's documented in 2 variants:
*
* wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32
* d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b;
* wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32
* d, a, b-desc, scale-d, imm-scale-a, imm-scale-b;
*
* There is no "C" matrix involved at all, we are computing `D = A * B + D`.
* The A and B matrix descriptors describe the properties of the matrices in
* shared memory. Each one is a 64-bit value with the following layout:
*
* - 14 bits [0; 13]: start address
* - 14 bits [16; 29]: leading dimension byte offset
* - 14 bits [32; 45]: stride dimension byte offset
* - 3 bits [49; 51]: matrix base offset, valid only for "swizzling"
* - 2 bits [62; 63]: "swizzling" mode
*
* Swizzling defines the order of the elements and can have 4 possible values:
*
* 0: no "swizzling" at all
* 1: a 128-byte "swizzle" with a 1024 byte offset of a repeating pattern
* 2: a 64-byte "swizzle" with a 512 byte offset of a repeating pattern
* 3: a 32-byte "swizzle" with a 256 byte offset of a repeating pattern
*
* The list of supported shapes is exhausting:
*
* .m64n8k8, .m64n16k8, .m64n24k8, .m64n32k8,
* .m64n40k8, .m64n48k8, .m64n56k8, .m64n64k8,
* .m64n72k8, .m64n80k8, .m64n88k8, .m64n96k8,
* .m64n104k8, .m64n112k8, .m64n120k8, .m64n128k8,
* .m64n136k8, .m64n144k8, .m64n152k8, .m64n160k8,
* .m64n168k8, .m64n176k8, .m64n184k8, .m64n192k8,
* .m64n200k8, .m64n208k8, .m64n216k8, .m64n224k8,
* .m64n232k8, .m64n240k8, .m64n248k8, .m64n256k8
*
* The `scale` parameters can be used to either negate the inputs, or disable
* additive bias accumulation in the output.
*/
.visible .entry tops_tf32f32_sm90tc_m64n16k8_loop128_ptx_kernel()
{
    // Throughput probe: issues 128 asynchronous 64x16x8 warp-group MMAs with
    // TF32 inputs, all accumulating into the same 8 FP32 registers, then
    // commits and waits for them once and writes 4 accumulators to
    // `dummy_sink_f32`. The operand descriptors are zeroed placeholders, so
    // only the instruction stream matters, not the numeric result.

    // Accumulator registers used for both input and output of this MMA
    .reg .f32 accum<8>;

    // Descriptors for matrix A and matrix B operands
    .reg .b64 desc_a, desc_b;

    // General-purpose registers for loop control
    .reg .b32 loop_counter, loop_limit;

    // Predicates: loop exit condition and the `scale-d` accumulate switch
    .reg .pred exit_predicate, scale_d;

    // Set up loop counter and loop limit
    mov.u32 loop_counter, 0;
    mov.u32 loop_limit, 128;

    // Zero-initialize the accumulator registers
    mov.f32 accum0, 0.0;
    mov.f32 accum1, 0.0;
    mov.f32 accum2, 0.0;
    mov.f32 accum3, 0.0;
    mov.f32 accum4, 0.0;
    mov.f32 accum5, 0.0;
    mov.f32 accum6, 0.0;
    mov.f32 accum7, 0.0;

    // Initialize matrix descriptors with arbitrary placeholder values
    mov.u64 desc_a, 0x0000000000000000;
    mov.u64 desc_b, 0x0000000000000000;

    // `scale-d` is a predicate operand: true keeps the accumulator enabled
    // (D = A * B + D), false drops it (D = A * B). We want accumulation, so
    // every iteration depends on the previous one through `accum*`.
    mov.pred scale_d, 1;

    // Enforce ordering between the register writes above and the
    // Warp-Group instructions that follow
    wgmma.fence.sync.aligned;

    // The main loop will repeat for 128 iterations
loop_start:
    setp.ge.u32 exit_predicate, loop_counter, loop_limit;
    @exit_predicate bra loop_exit;

    // The trailing immediates are `imm-scale-a` and `imm-scale-b`: `1` keeps
    // an input as is, `-1` negates it. Negating both inputs yields the same
    // product, so we simply leave both un-negated.
    wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32
        { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7 },
        desc_a,
        desc_b,
        scale_d, 1, 1;

    // Increment the loop counter
    add.u32 loop_counter, loop_counter, 1;

    // Branch back to the beginning of the loop
    bra loop_start;

loop_exit:
    // Commit all prior uncommitted operations to the group and wait!
    wgmma.commit_group.sync.aligned;
    wgmma.wait_group.sync.aligned 0;

    // Use volatile stores to force the accumulator values to be written out.
    // This dummy write (to a global variable) makes the work observable and
    // prevents the multiplication pipeline from being optimized out.
    st.global.volatile.f32 [dummy_sink_f32], accum0;
    st.global.volatile.f32 [dummy_sink_f32+4], accum1;
    st.global.volatile.f32 [dummy_sink_f32+8], accum2;
    st.global.volatile.f32 [dummy_sink_f32+12], accum3;
    ret;
}

/**
* This results in massive performance gains on Hopper:
* - 16x16x8 MMA computed by individual warps: 74 T
* - 64x16x8 WGMMA computed by four warps together: 300 T
*
* Will it get even better with larger matrices if we scale the second
* dimension from 16 to 256? It would require 128 accumulators.
*/

.visible .entry tops_tf32f32_sm90tc_m64n256k8_loop128_ptx_kernel()
{
    // Throughput probe: issues 128 asynchronous 64x256x8 warp-group MMAs with
    // TF32 inputs, all accumulating into the same 128 FP32 registers, then
    // commits and waits for them once and writes 4 accumulators to
    // `dummy_sink_f32`. The operand descriptors are zeroed placeholders, so
    // only the instruction stream matters, not the numeric result.

    // Accumulator registers used for both input and output of this MMA
    .reg .f32 accum<128>;

    // Descriptors for matrix A and matrix B operands
    .reg .b64 desc_a, desc_b;

    // General-purpose registers for loop control
    .reg .b32 loop_counter, loop_limit;

    // Predicates: loop exit condition and the `scale-d` accumulate switch
    .reg .pred exit_predicate, scale_d;

    // Set up loop counter and loop limit to fill accumulators
    mov.u32 loop_counter, 0;
    mov.u32 loop_limit, 128;

    // Zero-initialize the accumulator registers:
    mov.f32 accum0, 0.0; mov.f32 accum1, 0.0; mov.f32 accum2, 0.0; mov.f32 accum3, 0.0;
    mov.f32 accum4, 0.0; mov.f32 accum5, 0.0; mov.f32 accum6, 0.0; mov.f32 accum7, 0.0;
    mov.f32 accum8, 0.0; mov.f32 accum9, 0.0; mov.f32 accum10, 0.0; mov.f32 accum11, 0.0;
    mov.f32 accum12, 0.0; mov.f32 accum13, 0.0; mov.f32 accum14, 0.0; mov.f32 accum15, 0.0;
    mov.f32 accum16, 0.0; mov.f32 accum17, 0.0; mov.f32 accum18, 0.0; mov.f32 accum19, 0.0;
    mov.f32 accum20, 0.0; mov.f32 accum21, 0.0; mov.f32 accum22, 0.0; mov.f32 accum23, 0.0;
    mov.f32 accum24, 0.0; mov.f32 accum25, 0.0; mov.f32 accum26, 0.0; mov.f32 accum27, 0.0;
    mov.f32 accum28, 0.0; mov.f32 accum29, 0.0; mov.f32 accum30, 0.0; mov.f32 accum31, 0.0;
    mov.f32 accum32, 0.0; mov.f32 accum33, 0.0; mov.f32 accum34, 0.0; mov.f32 accum35, 0.0;
    mov.f32 accum36, 0.0; mov.f32 accum37, 0.0; mov.f32 accum38, 0.0; mov.f32 accum39, 0.0;
    mov.f32 accum40, 0.0; mov.f32 accum41, 0.0; mov.f32 accum42, 0.0; mov.f32 accum43, 0.0;
    mov.f32 accum44, 0.0; mov.f32 accum45, 0.0; mov.f32 accum46, 0.0; mov.f32 accum47, 0.0;
    mov.f32 accum48, 0.0; mov.f32 accum49, 0.0; mov.f32 accum50, 0.0; mov.f32 accum51, 0.0;
    mov.f32 accum52, 0.0; mov.f32 accum53, 0.0; mov.f32 accum54, 0.0; mov.f32 accum55, 0.0;
    mov.f32 accum56, 0.0; mov.f32 accum57, 0.0; mov.f32 accum58, 0.0; mov.f32 accum59, 0.0;
    mov.f32 accum60, 0.0; mov.f32 accum61, 0.0; mov.f32 accum62, 0.0; mov.f32 accum63, 0.0;
    mov.f32 accum64, 0.0; mov.f32 accum65, 0.0; mov.f32 accum66, 0.0; mov.f32 accum67, 0.0;
    mov.f32 accum68, 0.0; mov.f32 accum69, 0.0; mov.f32 accum70, 0.0; mov.f32 accum71, 0.0;
    mov.f32 accum72, 0.0; mov.f32 accum73, 0.0; mov.f32 accum74, 0.0; mov.f32 accum75, 0.0;
    mov.f32 accum76, 0.0; mov.f32 accum77, 0.0; mov.f32 accum78, 0.0; mov.f32 accum79, 0.0;
    mov.f32 accum80, 0.0; mov.f32 accum81, 0.0; mov.f32 accum82, 0.0; mov.f32 accum83, 0.0;
    mov.f32 accum84, 0.0; mov.f32 accum85, 0.0; mov.f32 accum86, 0.0; mov.f32 accum87, 0.0;
    mov.f32 accum88, 0.0; mov.f32 accum89, 0.0; mov.f32 accum90, 0.0; mov.f32 accum91, 0.0;
    mov.f32 accum92, 0.0; mov.f32 accum93, 0.0; mov.f32 accum94, 0.0; mov.f32 accum95, 0.0;
    mov.f32 accum96, 0.0; mov.f32 accum97, 0.0; mov.f32 accum98, 0.0; mov.f32 accum99, 0.0;
    mov.f32 accum100, 0.0; mov.f32 accum101, 0.0; mov.f32 accum102, 0.0; mov.f32 accum103, 0.0;
    mov.f32 accum104, 0.0; mov.f32 accum105, 0.0; mov.f32 accum106, 0.0; mov.f32 accum107, 0.0;
    mov.f32 accum108, 0.0; mov.f32 accum109, 0.0; mov.f32 accum110, 0.0; mov.f32 accum111, 0.0;
    mov.f32 accum112, 0.0; mov.f32 accum113, 0.0; mov.f32 accum114, 0.0; mov.f32 accum115, 0.0;
    mov.f32 accum116, 0.0; mov.f32 accum117, 0.0; mov.f32 accum118, 0.0; mov.f32 accum119, 0.0;
    mov.f32 accum120, 0.0; mov.f32 accum121, 0.0; mov.f32 accum122, 0.0; mov.f32 accum123, 0.0;
    mov.f32 accum124, 0.0; mov.f32 accum125, 0.0; mov.f32 accum126, 0.0; mov.f32 accum127, 0.0;

    // Initialize matrix descriptors with arbitrary placeholder values
    mov.u64 desc_a, 0x0000000000000000;
    mov.u64 desc_b, 0x0000000000000000;

    // `scale-d` is a predicate operand: true keeps the accumulator enabled
    // (D = A * B + D), false drops it (D = A * B). We want accumulation, so
    // every iteration depends on the previous one through `accum*`.
    mov.pred scale_d, 1;

    // Enforce ordering between the register writes above and the
    // Warp-Group instructions that follow
    wgmma.fence.sync.aligned;

    // The main loop will repeat for 128 iterations
loop_start:
    setp.ge.u32 exit_predicate, loop_counter, loop_limit;
    @exit_predicate bra loop_exit;

    // The trailing immediates are `imm-scale-a` and `imm-scale-b`: `1` keeps
    // an input as is, `-1` negates it. Negating both inputs yields the same
    // product, so we simply leave both un-negated.
    wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32
        { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7,
          accum8, accum9, accum10, accum11, accum12, accum13, accum14, accum15,
          accum16, accum17, accum18, accum19, accum20, accum21, accum22, accum23,
          accum24, accum25, accum26, accum27, accum28, accum29, accum30, accum31,
          accum32, accum33, accum34, accum35, accum36, accum37, accum38, accum39,
          accum40, accum41, accum42, accum43, accum44, accum45, accum46, accum47,
          accum48, accum49, accum50, accum51, accum52, accum53, accum54, accum55,
          accum56, accum57, accum58, accum59, accum60, accum61, accum62, accum63,
          accum64, accum65, accum66, accum67, accum68, accum69, accum70, accum71,
          accum72, accum73, accum74, accum75, accum76, accum77, accum78, accum79,
          accum80, accum81, accum82, accum83, accum84, accum85, accum86, accum87,
          accum88, accum89, accum90, accum91, accum92, accum93, accum94, accum95,
          accum96, accum97, accum98, accum99, accum100, accum101, accum102, accum103,
          accum104, accum105, accum106, accum107, accum108, accum109, accum110, accum111,
          accum112, accum113, accum114, accum115, accum116, accum117, accum118, accum119,
          accum120, accum121, accum122, accum123, accum124, accum125, accum126, accum127 },
        desc_a,
        desc_b,
        scale_d, 1, 1;

    // Increment the loop counter
    add.u32 loop_counter, loop_counter, 1;

    // Branch back to the beginning of the loop
    bra loop_start;

loop_exit:
    // Commit all prior uncommitted operations to the group and wait!
    wgmma.commit_group.sync.aligned;
    wgmma.wait_group.sync.aligned 0;

    // Use volatile stores to force the accumulator values to be written out.
    // This dummy write (to a global variable) makes the work observable and
    // prevents the multiplication pipeline from being optimized out.
    st.global.volatile.f32 [dummy_sink_f32], accum0;
    st.global.volatile.f32 [dummy_sink_f32+4], accum1;
    st.global.volatile.f32 [dummy_sink_f32+8], accum2;
    st.global.volatile.f32 [dummy_sink_f32+12], accum3;
    ret;
}

/**
* This results in massive performance gains on Hopper:
* - 16x16x8 MMA computed by individual warps: 74 T
* - 64x16x8 WGMMA computed by four warps together: 300 T
* - 64x256x8 WGMMA computed by four warps together: 4.7 P ?!
*
* There are also "structured-sparse" variants of those instructions, in case
* half of our entries are zeros! Those, however, simply expand the last
* dimension by 2x, making the instructions no more usable for small matrices.
*/

.visible .entry tops_b1i32and_sm90tc_m64n256k256_loop128_ptx_kernel()
{
    // Single-bit AND + population-count warp-group MMA probe: 128 chained
    // 64x256x256 operations feeding the same 128 signed 32-bit accumulators,
    // followed by one commit/wait and four stores into `dummy_sink_s32`.

    // Predicates: "still iterating" flag and the `scale-d` operand of the MMA
    .reg .pred keep_going, accumulate;

    // Iteration counter and its upper bound
    .reg .b32 iter, iter_end;

    // 64-bit descriptors of the A and B operand tiles
    .reg .b64 tile_a_desc, tile_b_desc;

    // Accumulators, read and written by every MMA in the loop
    .reg .s32 acc<128>;

    // Run the loop body exactly 128 times
    mov.u32 iter_end, 128;
    mov.u32 iter, 0;

    // Pass "true" as the scale flag of the MMA
    mov.pred accumulate, 1;

    // Placeholder descriptors; a real kernel would describe the
    // shared-memory regions holding its matrices here.
    mov.u64 tile_a_desc, 0;
    mov.u64 tile_b_desc, 0;

    // Registers may hold garbage on entry, so clear every accumulator
    mov.s32 acc0, 0; mov.s32 acc1, 0; mov.s32 acc2, 0; mov.s32 acc3, 0;
    mov.s32 acc4, 0; mov.s32 acc5, 0; mov.s32 acc6, 0; mov.s32 acc7, 0;
    mov.s32 acc8, 0; mov.s32 acc9, 0; mov.s32 acc10, 0; mov.s32 acc11, 0;
    mov.s32 acc12, 0; mov.s32 acc13, 0; mov.s32 acc14, 0; mov.s32 acc15, 0;
    mov.s32 acc16, 0; mov.s32 acc17, 0; mov.s32 acc18, 0; mov.s32 acc19, 0;
    mov.s32 acc20, 0; mov.s32 acc21, 0; mov.s32 acc22, 0; mov.s32 acc23, 0;
    mov.s32 acc24, 0; mov.s32 acc25, 0; mov.s32 acc26, 0; mov.s32 acc27, 0;
    mov.s32 acc28, 0; mov.s32 acc29, 0; mov.s32 acc30, 0; mov.s32 acc31, 0;
    mov.s32 acc32, 0; mov.s32 acc33, 0; mov.s32 acc34, 0; mov.s32 acc35, 0;
    mov.s32 acc36, 0; mov.s32 acc37, 0; mov.s32 acc38, 0; mov.s32 acc39, 0;
    mov.s32 acc40, 0; mov.s32 acc41, 0; mov.s32 acc42, 0; mov.s32 acc43, 0;
    mov.s32 acc44, 0; mov.s32 acc45, 0; mov.s32 acc46, 0; mov.s32 acc47, 0;
    mov.s32 acc48, 0; mov.s32 acc49, 0; mov.s32 acc50, 0; mov.s32 acc51, 0;
    mov.s32 acc52, 0; mov.s32 acc53, 0; mov.s32 acc54, 0; mov.s32 acc55, 0;
    mov.s32 acc56, 0; mov.s32 acc57, 0; mov.s32 acc58, 0; mov.s32 acc59, 0;
    mov.s32 acc60, 0; mov.s32 acc61, 0; mov.s32 acc62, 0; mov.s32 acc63, 0;
    mov.s32 acc64, 0; mov.s32 acc65, 0; mov.s32 acc66, 0; mov.s32 acc67, 0;
    mov.s32 acc68, 0; mov.s32 acc69, 0; mov.s32 acc70, 0; mov.s32 acc71, 0;
    mov.s32 acc72, 0; mov.s32 acc73, 0; mov.s32 acc74, 0; mov.s32 acc75, 0;
    mov.s32 acc76, 0; mov.s32 acc77, 0; mov.s32 acc78, 0; mov.s32 acc79, 0;
    mov.s32 acc80, 0; mov.s32 acc81, 0; mov.s32 acc82, 0; mov.s32 acc83, 0;
    mov.s32 acc84, 0; mov.s32 acc85, 0; mov.s32 acc86, 0; mov.s32 acc87, 0;
    mov.s32 acc88, 0; mov.s32 acc89, 0; mov.s32 acc90, 0; mov.s32 acc91, 0;
    mov.s32 acc92, 0; mov.s32 acc93, 0; mov.s32 acc94, 0; mov.s32 acc95, 0;
    mov.s32 acc96, 0; mov.s32 acc97, 0; mov.s32 acc98, 0; mov.s32 acc99, 0;
    mov.s32 acc100, 0; mov.s32 acc101, 0; mov.s32 acc102, 0; mov.s32 acc103, 0;
    mov.s32 acc104, 0; mov.s32 acc105, 0; mov.s32 acc106, 0; mov.s32 acc107, 0;
    mov.s32 acc108, 0; mov.s32 acc109, 0; mov.s32 acc110, 0; mov.s32 acc111, 0;
    mov.s32 acc112, 0; mov.s32 acc113, 0; mov.s32 acc114, 0; mov.s32 acc115, 0;
    mov.s32 acc116, 0; mov.s32 acc117, 0; mov.s32 acc118, 0; mov.s32 acc119, 0;
    mov.s32 acc120, 0; mov.s32 acc121, 0; mov.s32 acc122, 0; mov.s32 acc123, 0;
    mov.s32 acc124, 0; mov.s32 acc125, 0; mov.s32 acc126, 0; mov.s32 acc127, 0;

    // Fence the register writes above against the Warp-Group MMAs below
    wgmma.fence.sync.aligned;

issue_next:
    // Leave the loop once `iter` is no longer below `iter_end`
    setp.lt.u32 keep_going, iter, iter_end;
    @!keep_going bra drain;

    wgmma.mma_async.sync.aligned.m64n256k256.s32.b1.b1.and.popc
        { acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7,
          acc8, acc9, acc10, acc11, acc12, acc13, acc14, acc15,
          acc16, acc17, acc18, acc19, acc20, acc21, acc22, acc23,
          acc24, acc25, acc26, acc27, acc28, acc29, acc30, acc31,
          acc32, acc33, acc34, acc35, acc36, acc37, acc38, acc39,
          acc40, acc41, acc42, acc43, acc44, acc45, acc46, acc47,
          acc48, acc49, acc50, acc51, acc52, acc53, acc54, acc55,
          acc56, acc57, acc58, acc59, acc60, acc61, acc62, acc63,
          acc64, acc65, acc66, acc67, acc68, acc69, acc70, acc71,
          acc72, acc73, acc74, acc75, acc76, acc77, acc78, acc79,
          acc80, acc81, acc82, acc83, acc84, acc85, acc86, acc87,
          acc88, acc89, acc90, acc91, acc92, acc93, acc94, acc95,
          acc96, acc97, acc98, acc99, acc100, acc101, acc102, acc103,
          acc104, acc105, acc106, acc107, acc108, acc109, acc110, acc111,
          acc112, acc113, acc114, acc115, acc116, acc117, acc118, acc119,
          acc120, acc121, acc122, acc123, acc124, acc125, acc126, acc127 },
        tile_a_desc,
        tile_b_desc,
        accumulate;

    // Advance to the next iteration
    add.u32 iter, iter, 1;
    bra issue_next;

drain:
    // Close the group of outstanding MMAs and block until none are pending
    wgmma.commit_group.sync.aligned;
    wgmma.wait_group.sync.aligned 0;

    // Volatile stores to a global sink keep the results observable, so the
    // MMA chain above cannot be discarded as dead code.
    st.global.volatile.s32 [dummy_sink_s32], acc0;
    st.global.volatile.s32 [dummy_sink_s32+4], acc1;
    st.global.volatile.s32 [dummy_sink_s32+8], acc2;
    st.global.volatile.s32 [dummy_sink_s32+12], acc3;
    ret;
}
Loading

0 comments on commit 0207843

Please sign in to comment.