From 02078432f9061717836e7bab95ba990fc8eceb10 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Feb 2025 12:07:47 +0000 Subject: [PATCH] Add: WGMMA synchronization --- less_slow_sm80.ptx | 354 +----------------------------- less_slow_sm90a.ptx | 511 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 516 insertions(+), 349 deletions(-) create mode 100644 less_slow_sm90a.ptx diff --git a/less_slow_sm80.ptx b/less_slow_sm80.ptx index 273b2d0..051f277 100644 --- a/less_slow_sm80.ptx +++ b/less_slow_sm80.ptx @@ -1,9 +1,9 @@ /** - * less_slow_sm90a.ptx + * less_slow_sm80.ptx * * Micro-kernels for building a performance-first mindset for CUDA-capable * GPUs using Parallel Thread eXecution (PTX) Intermediate Representation (IR) - * for for Hopper-generation Nvidia GPUs and newer. + * for for Ampere-generation Nvidia GPUs with Warp-level MMA (WMMA). * * ? You should start at `less_slow.cu` before reading this file. * ? You should start at `less_slow_sm70.ptx` before reading this file. @@ -13,7 +13,7 @@ * You can validate this file by asking the Nvidia PTX Assembler to compile it * to `.cubin` for some target architecture: * - * $ ptxas -o less_slow_from_ptx.cubin -arch=sm_90a less_slow_sm90a.ptx + * $ ptxas -o less_slow_from_ptx.cubin -arch=sm_80 less_slow_sm80.ptx * $ cuobjdump -sass less_slow_from_ptx.cubin | grep -i mma * * Assuming how aggressively NVCC unrolls loops and the number of kernels in @@ -24,8 +24,8 @@ * $ sed -r 's/^[[:space:]]+//; s/[[:space:]]+$//' | \ * $ sort -u */ -.version 8.0 // PTX version 8.0 for Hopper GPUs -.target sm_90a // Target architecture (SM_90a - Hopper GPUs) +.version 7.0 // PTX version 7.0 for Ampere GPUs +.target sm_80 // Target architecture (SM_80 - Ampere GPUs) .address_size 64 // 64-bit addressing /** @@ -283,347 +283,3 @@ loop_exit: st.global.volatile.f32 [dummy_sink_f32+12], accum3; ret; } - -/** - * The instruction syntax for Warp-Group asynchronous instructions is very - * different, as at least one of the operand matrices has to be in shared - * memory (not registers). It's documented as in 2 variants: - * - * wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32 - * d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b; - * wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32 - * d, a, b-desc, scale-d, imm-scale-a, imm-scale-b; - * - * There is no "C" matrix involved at all, we are computing `D = A * B + D`. - * The A and B matrix descriptors are the properties of the matrix in shared - * memory. It is a 64-bit value contained with the following layout: - * - * - 14 bits [0; 13]: start address - * - 14 bits [16; 29]: leading dimension byte offset - * - 14 bits [32; 45]: stride dimension byte offset - * - 3 bits [49; 51]: matrix base offset, valid only for "swizzling" - * - 2 bits [62; 63]: "swizzling" mode - * - * Swizzling defines the order of the elements and can have 4 possible values: - * - * 0: no "swizzling" at all - * 1: a 128-byte "swizzle" with a 1024 byte offset of a repeating pattern - * 2: a 64-byte "swizzle" with a 512 byte offset of a repeating pattern - * 3: a 32-byte "swizzle" with a 256 byte offset of a repeating pattern - * - * The list of supported shapes is exhausting: - * - * .m64n8k8, .m64n16k8, .m64n24k8, .m64n32k8, - * .m64n40k8, .m64n48k8, .m64n56k8, .m64n64k8, - * .m64n72k8, .m64n80k8, .m64n88k8, .m64n96k8, - * .m64n104k8, .m64n112k8, .m64n120k8, .m64n128k8, - * .m64n136k8, .m64n144k8, .m64n152k8, .m64n160k8, - * .m64n168k8, .m64n176k8, .m64n184k8, .m64n192k8, - * .m64n200k8, .m64n208k8, .m64n216k8, .m64n224k8, - * .m64n232k8, .m64n240k8, .m64n248k8, .m64n256k8 - * - * The `scale` parameters can be used to either negate the inputs, or disable - * additive bias accumulation in the output. - */ -.visible .entry tops_tf32f32_sm90tc_m64n16k8_loop128_ptx_kernel() -{ - // Accumulator registers used for both input and output of this MMA - .reg .f32 accum<8>; - - // Descriptors for matrix A and matrix B operands - .reg .b64 desc_a, desc_b; - - // General-purpose registers for loop control - .reg .b32 loop_counter, loop_limit; - - // Predicate register for conditional branching (loop exit) - .reg .pred exit_predicate; - - // Set up loop counter and loop limit - mov.u32 loop_counter, 0; - mov.u32 loop_limit, 128; - - // Zero-initialize the accumulator registers - mov.f32 accum0, 0.0; - mov.f32 accum1, 0.0; - mov.f32 accum2, 0.0; - mov.f32 accum3, 0.0; - mov.f32 accum4, 0.0; - mov.f32 accum5, 0.0; - mov.f32 accum6, 0.0; - mov.f32 accum7, 0.0; - - // Initialize matrix descriptors with arbitrary placeholder values - mov.u64 desc_a, 0x0000000000000000; - mov.u64 desc_b, 0x0000000000000000; - - // Enforce the ordered for Warp-Group instructions - wgmma.fence.sync.aligned; - - // The main loop will repeat for 128 iterations -loop_start: - setp.ge.u32 exit_predicate, loop_counter, loop_limit; - @exit_predicate bra loop_exit; - - wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 - { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7 }, - desc_a, - desc_b, - 0, -1, -1; - - // Increment the loop counter - add.u32 loop_counter, loop_counter, 1; - - // Branch back to the beginning of the loop - bra loop_start; - -loop_exit: - // Commit all prior uncommitted operations to the group and wait! - wgmma.commit_group.sync.aligned; - wgmma.wait_group.sync.aligned 0; - - // Use volatile stores to force the accumulator values to be written out. - // This dummy write (to a global variable) makes the work observable and - // prevents the multiplication pipeline from being optimized out. - st.global.volatile.f32 [dummy_sink_f32], accum0; - st.global.volatile.f32 [dummy_sink_f32+4], accum1; - st.global.volatile.f32 [dummy_sink_f32+8], accum2; - st.global.volatile.f32 [dummy_sink_f32+12], accum3; - ret; -} - -/** - * This results in massive performance gains on Hopper: - * - 16x16x8 MMA computed by individual warps: 74 T - * - 64x16x8 WGMMA computed by four warps together: 300 T - * - * Will it get even better with larger matrices if we scale the second - * dimension from 16 to 256? It would require 128 accumulators. - */ - -.visible .entry tops_tf32f32_sm90tc_m64n256k8_loop128_ptx_kernel() -{ - // Accumulator registers used for both input and output of this MMA - .reg .f32 accum<128>; - - // Descriptors for matrix A and matrix B operands - .reg .b64 desc_a, desc_b; - - // General-purpose registers for loop control - .reg .b32 loop_counter, loop_limit; - - // Predicate register for conditional branching (loop exit) - .reg .pred exit_predicate; - - // Set up loop counter and loop limit to fill accumulators - mov.u32 loop_counter, 0; - mov.u32 loop_limit, 128; - - // Zero-initialize the accumulator registers: - mov.f32 accum0, 0.0; mov.f32 accum1, 0.0; mov.f32 accum2, 0.0; mov.f32 accum3, 0.0; - mov.f32 accum4, 0.0; mov.f32 accum5, 0.0; mov.f32 accum6, 0.0; mov.f32 accum7, 0.0; - mov.f32 accum8, 0.0; mov.f32 accum9, 0.0; mov.f32 accum10, 0.0; mov.f32 accum11, 0.0; - mov.f32 accum12, 0.0; mov.f32 accum13, 0.0; mov.f32 accum14, 0.0; mov.f32 accum15, 0.0; - mov.f32 accum16, 0.0; mov.f32 accum17, 0.0; mov.f32 accum18, 0.0; mov.f32 accum19, 0.0; - mov.f32 accum20, 0.0; mov.f32 accum21, 0.0; mov.f32 accum22, 0.0; mov.f32 accum23, 0.0; - mov.f32 accum24, 0.0; mov.f32 accum25, 0.0; mov.f32 accum26, 0.0; mov.f32 accum27, 0.0; - mov.f32 accum28, 0.0; mov.f32 accum29, 0.0; mov.f32 accum30, 0.0; mov.f32 accum31, 0.0; - mov.f32 accum32, 0.0; mov.f32 accum33, 0.0; mov.f32 accum34, 0.0; mov.f32 accum35, 0.0; - mov.f32 accum36, 0.0; mov.f32 accum37, 0.0; mov.f32 accum38, 0.0; mov.f32 accum39, 0.0; - mov.f32 accum40, 0.0; mov.f32 accum41, 0.0; mov.f32 accum42, 0.0; mov.f32 accum43, 0.0; - mov.f32 accum44, 0.0; mov.f32 accum45, 0.0; mov.f32 accum46, 0.0; mov.f32 accum47, 0.0; - mov.f32 accum48, 0.0; mov.f32 accum49, 0.0; mov.f32 accum50, 0.0; mov.f32 accum51, 0.0; - mov.f32 accum52, 0.0; mov.f32 accum53, 0.0; mov.f32 accum54, 0.0; mov.f32 accum55, 0.0; - mov.f32 accum56, 0.0; mov.f32 accum57, 0.0; mov.f32 accum58, 0.0; mov.f32 accum59, 0.0; - mov.f32 accum60, 0.0; mov.f32 accum61, 0.0; mov.f32 accum62, 0.0; mov.f32 accum63, 0.0; - mov.f32 accum64, 0.0; mov.f32 accum65, 0.0; mov.f32 accum66, 0.0; mov.f32 accum67, 0.0; - mov.f32 accum68, 0.0; mov.f32 accum69, 0.0; mov.f32 accum70, 0.0; mov.f32 accum71, 0.0; - mov.f32 accum72, 0.0; mov.f32 accum73, 0.0; mov.f32 accum74, 0.0; mov.f32 accum75, 0.0; - mov.f32 accum76, 0.0; mov.f32 accum77, 0.0; mov.f32 accum78, 0.0; mov.f32 accum79, 0.0; - mov.f32 accum80, 0.0; mov.f32 accum81, 0.0; mov.f32 accum82, 0.0; mov.f32 accum83, 0.0; - mov.f32 accum84, 0.0; mov.f32 accum85, 0.0; mov.f32 accum86, 0.0; mov.f32 accum87, 0.0; - mov.f32 accum88, 0.0; mov.f32 accum89, 0.0; mov.f32 accum90, 0.0; mov.f32 accum91, 0.0; - mov.f32 accum92, 0.0; mov.f32 accum93, 0.0; mov.f32 accum94, 0.0; mov.f32 accum95, 0.0; - mov.f32 accum96, 0.0; mov.f32 accum97, 0.0; mov.f32 accum98, 0.0; mov.f32 accum99, 0.0; - mov.f32 accum100, 0.0; mov.f32 accum101, 0.0; mov.f32 accum102, 0.0; mov.f32 accum103, 0.0; - mov.f32 accum104, 0.0; mov.f32 accum105, 0.0; mov.f32 accum106, 0.0; mov.f32 accum107, 0.0; - mov.f32 accum108, 0.0; mov.f32 accum109, 0.0; mov.f32 accum110, 0.0; mov.f32 accum111, 0.0; - mov.f32 accum112, 0.0; mov.f32 accum113, 0.0; mov.f32 accum114, 0.0; mov.f32 accum115, 0.0; - mov.f32 accum116, 0.0; mov.f32 accum117, 0.0; mov.f32 accum118, 0.0; mov.f32 accum119, 0.0; - mov.f32 accum120, 0.0; mov.f32 accum121, 0.0; mov.f32 accum122, 0.0; mov.f32 accum123, 0.0; - mov.f32 accum124, 0.0; mov.f32 accum125, 0.0; mov.f32 accum126, 0.0; mov.f32 accum127, 0.0; - - // Initialize matrix descriptors with arbitrary placeholder values - mov.u64 desc_a, 0x0000000000000000; - mov.u64 desc_b, 0x0000000000000000; - - // Enforce the ordered for Warp-Group instructions - wgmma.fence.sync.aligned; - - // The main loop will repeat for 128 iterations -loop_start: - setp.ge.u32 exit_predicate, loop_counter, loop_limit; - @exit_predicate bra loop_exit; - - wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 - { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7, - accum8, accum9, accum10, accum11, accum12, accum13, accum14, accum15, - accum16, accum17, accum18, accum19, accum20, accum21, accum22, accum23, - accum24, accum25, accum26, accum27, accum28, accum29, accum30, accum31, - accum32, accum33, accum34, accum35, accum36, accum37, accum38, accum39, - accum40, accum41, accum42, accum43, accum44, accum45, accum46, accum47, - accum48, accum49, accum50, accum51, accum52, accum53, accum54, accum55, - accum56, accum57, accum58, accum59, accum60, accum61, accum62, accum63, - accum64, accum65, accum66, accum67, accum68, accum69, accum70, accum71, - accum72, accum73, accum74, accum75, accum76, accum77, accum78, accum79, - accum80, accum81, accum82, accum83, accum84, accum85, accum86, accum87, - accum88, accum89, accum90, accum91, accum92, accum93, accum94, accum95, - accum96, accum97, accum98, accum99, accum100, accum101, accum102, accum103, - accum104, accum105, accum106, accum107, accum108, accum109, accum110, accum111, - accum112, accum113, accum114, accum115, accum116, accum117, accum118, accum119, - accum120, accum121, accum122, accum123, accum124, accum125, accum126, accum127 }, - desc_a, - desc_b, - 0, -1, -1; - - // Increment the loop counter - add.u32 loop_counter, loop_counter, 1; - - // Branch back to the beginning of the loop - bra loop_start; - -loop_exit: - // Commit all prior uncommitted operations to the group and wait! - wgmma.commit_group.sync.aligned; - wgmma.wait_group.sync.aligned 0; - - // Use volatile stores to force the accumulator values to be written out. - // This dummy write (to a global variable) makes the work observable and - // prevents the multiplication pipeline from being optimized out. - st.global.volatile.f32 [dummy_sink_f32], accum0; - st.global.volatile.f32 [dummy_sink_f32+4], accum1; - st.global.volatile.f32 [dummy_sink_f32+8], accum2; - st.global.volatile.f32 [dummy_sink_f32+12], accum3; - ret; -} - -/** - * This results in massive performance gains on Hopper: - * - 16x16x8 MMA computed by individual warps: 74 T - * - 64x16x8 WGMMA computed by four warps together: 300 T - * - 64x256x8 WGMMA computed by four warps together: 4.7 P ?! - * - * There are also "structured-sparse" variants of those instructions, in case - * half of our entries are zeros! Those, however, simply expand the last - * dimension by 2x, making the instructions no more usable for small matrices. - */ - -.visible .entry tops_b1i32and_sm90tc_m64n256k256_loop128_ptx_kernel() -{ - // Accumulator registers used for both input and output of the MMA operation - .reg .s32 accum<128>; - - // Descriptors for matrix A and matrix B operands (in shared memory) - .reg .b64 desc_a, desc_b; - - // General-purpose registers for loop control - .reg .b32 loop_counter, loop_limit; - - // Predicate registers for conditional branching (loop exit) and scale flag - .reg .pred exit_predicate, scale_d; - - // Set up loop counter and loop limit - mov.u32 loop_counter, 0; - mov.u32 loop_limit, 128; - - // Zero-initialize the accumulators, as registers may contain noise - mov.s32 accum0, 0; mov.s32 accum1, 0; mov.s32 accum2, 0; mov.s32 accum3, 0; - mov.s32 accum4, 0; mov.s32 accum5, 0; mov.s32 accum6, 0; mov.s32 accum7, 0; - mov.s32 accum8, 0; mov.s32 accum9, 0; mov.s32 accum10, 0; mov.s32 accum11, 0; - mov.s32 accum12, 0; mov.s32 accum13, 0; mov.s32 accum14, 0; mov.s32 accum15, 0; - mov.s32 accum16, 0; mov.s32 accum17, 0; mov.s32 accum18, 0; mov.s32 accum19, 0; - mov.s32 accum20, 0; mov.s32 accum21, 0; mov.s32 accum22, 0; mov.s32 accum23, 0; - mov.s32 accum24, 0; mov.s32 accum25, 0; mov.s32 accum26, 0; mov.s32 accum27, 0; - mov.s32 accum28, 0; mov.s32 accum29, 0; mov.s32 accum30, 0; mov.s32 accum31, 0; - mov.s32 accum32, 0; mov.s32 accum33, 0; mov.s32 accum34, 0; mov.s32 accum35, 0; - mov.s32 accum36, 0; mov.s32 accum37, 0; mov.s32 accum38, 0; mov.s32 accum39, 0; - mov.s32 accum40, 0; mov.s32 accum41, 0; mov.s32 accum42, 0; mov.s32 accum43, 0; - mov.s32 accum44, 0; mov.s32 accum45, 0; mov.s32 accum46, 0; mov.s32 accum47, 0; - mov.s32 accum48, 0; mov.s32 accum49, 0; mov.s32 accum50, 0; mov.s32 accum51, 0; - mov.s32 accum52, 0; mov.s32 accum53, 0; mov.s32 accum54, 0; mov.s32 accum55, 0; - mov.s32 accum56, 0; mov.s32 accum57, 0; mov.s32 accum58, 0; mov.s32 accum59, 0; - mov.s32 accum60, 0; mov.s32 accum61, 0; mov.s32 accum62, 0; mov.s32 accum63, 0; - mov.s32 accum64, 0; mov.s32 accum65, 0; mov.s32 accum66, 0; mov.s32 accum67, 0; - mov.s32 accum68, 0; mov.s32 accum69, 0; mov.s32 accum70, 0; mov.s32 accum71, 0; - mov.s32 accum72, 0; mov.s32 accum73, 0; mov.s32 accum74, 0; mov.s32 accum75, 0; - mov.s32 accum76, 0; mov.s32 accum77, 0; mov.s32 accum78, 0; mov.s32 accum79, 0; - mov.s32 accum80, 0; mov.s32 accum81, 0; mov.s32 accum82, 0; mov.s32 accum83, 0; - mov.s32 accum84, 0; mov.s32 accum85, 0; mov.s32 accum86, 0; mov.s32 accum87, 0; - mov.s32 accum88, 0; mov.s32 accum89, 0; mov.s32 accum90, 0; mov.s32 accum91, 0; - mov.s32 accum92, 0; mov.s32 accum93, 0; mov.s32 accum94, 0; mov.s32 accum95, 0; - mov.s32 accum96, 0; mov.s32 accum97, 0; mov.s32 accum98, 0; mov.s32 accum99, 0; - mov.s32 accum100, 0; mov.s32 accum101, 0; mov.s32 accum102, 0; mov.s32 accum103, 0; - mov.s32 accum104, 0; mov.s32 accum105, 0; mov.s32 accum106, 0; mov.s32 accum107, 0; - mov.s32 accum108, 0; mov.s32 accum109, 0; mov.s32 accum110, 0; mov.s32 accum111, 0; - mov.s32 accum112, 0; mov.s32 accum113, 0; mov.s32 accum114, 0; mov.s32 accum115, 0; - mov.s32 accum116, 0; mov.s32 accum117, 0; mov.s32 accum118, 0; mov.s32 accum119, 0; - mov.s32 accum120, 0; mov.s32 accum121, 0; mov.s32 accum122, 0; mov.s32 accum123, 0; - mov.s32 accum124, 0; mov.s32 accum125, 0; mov.s32 accum126, 0; mov.s32 accum127, 0; - - // Initialize matrix descriptors with arbitrary placeholder values. - // In practice, these would be set to point to shared-memory regions containing your matrices. - mov.u64 desc_a, 0x0000000000000000; - mov.u64 desc_b, 0x0000000000000000; - - // Initialize scale flag (controls operand scaling or additive bias behavior) - mov.pred scale_d, 1; - - // Enforce the ordered for Warp-Group instructions - wgmma.fence.sync.aligned; - - // The main loop will repeat for 128 iterations -loop_start: - setp.ge.u32 exit_predicate, loop_counter, loop_limit; - @exit_predicate bra loop_exit; - - wgmma.mma_async.sync.aligned.m64n256k256.s32.b1.b1.and.popc - { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7, - accum8, accum9, accum10, accum11, accum12, accum13, accum14, accum15, - accum16, accum17, accum18, accum19, accum20, accum21, accum22, accum23, - accum24, accum25, accum26, accum27, accum28, accum29, accum30, accum31, - accum32, accum33, accum34, accum35, accum36, accum37, accum38, accum39, - accum40, accum41, accum42, accum43, accum44, accum45, accum46, accum47, - accum48, accum49, accum50, accum51, accum52, accum53, accum54, accum55, - accum56, accum57, accum58, accum59, accum60, accum61, accum62, accum63, - accum64, accum65, accum66, accum67, accum68, accum69, accum70, accum71, - accum72, accum73, accum74, accum75, accum76, accum77, accum78, accum79, - accum80, accum81, accum82, accum83, accum84, accum85, accum86, accum87, - accum88, accum89, accum90, accum91, accum92, accum93, accum94, accum95, - accum96, accum97, accum98, accum99, accum100, accum101, accum102, accum103, - accum104, accum105, accum106, accum107, accum108, accum109, accum110, accum111, - accum112, accum113, accum114, accum115, accum116, accum117, accum118, accum119, - accum120, accum121, accum122, accum123, accum124, accum125, accum126, accum127 }, - desc_a, - desc_b, - scale_d; - - // Increment the loop counter - add.u32 loop_counter, loop_counter, 1; - - // Branch back to the beginning of the loop - bra loop_start; - -loop_exit: - // Commit all prior uncommitted operations to the group and wait! - wgmma.commit_group.sync.aligned; - wgmma.wait_group.sync.aligned 0; - - // Use volatile stores to force the accumulator values to be written out. - // This dummy write (to a global variable) makes the work observable and - // prevents the multiplication pipeline from being optimized out. - st.global.volatile.s32 [dummy_sink_s32], accum0; - st.global.volatile.s32 [dummy_sink_s32+4], accum1; - st.global.volatile.s32 [dummy_sink_s32+8], accum2; - st.global.volatile.s32 [dummy_sink_s32+12], accum3; - ret; -} diff --git a/less_slow_sm90a.ptx b/less_slow_sm90a.ptx new file mode 100644 index 0000000..91af08e --- /dev/null +++ b/less_slow_sm90a.ptx @@ -0,0 +1,511 @@ +/** + * less_slow_sm90a.ptx + * + * Micro-kernels for building a performance-first mindset for CUDA-capable + * GPUs using Parallel Thread eXecution (PTX) Intermediate Representation (IR) + * for for Hopper-generation Nvidia GPUs with Warp-Group-level MMA (WGMMA). + * + * ? You should start at `less_slow.cu` before reading this file. + * ? You should start at `less_slow_sm70.ptx` before reading this file. + * ? Also read intro to PTX: https://docs.nvidia.com/cuda/parallel-thread-execution/ + * ? Check the PTX ISA: https://docs.nvidia.com/cuda/pdf/ptx_isa_8.5.pdf + * + * You can validate this file by asking the Nvidia PTX Assembler to compile it + * to `.cubin` for some target architecture: + * + * $ ptxas -o less_slow_from_ptx.cubin -arch=sm_90a less_slow_sm90a.ptx + * $ cuobjdump -sass less_slow_from_ptx.cubin | grep -i mma + * + * Assuming how aggressively NVCC unrolls loops and the number of kernels in + * this file, you may want to deduplicate them: + * + * $ cuobjdump -sass less_slow_from_ptx.cubin | grep -i mma | \ + * $ sed -r 's/\/\*[^*]+\*\///g' | \ + * $ sed -r 's/^[[:space:]]+//; s/[[:space:]]+$//' | \ + * $ sort -u + */ +.version 8.0 // PTX version 8.0 for Hopper GPUs +.target sm_90a // Target architecture (SM_90a - Hopper GPUs) +.address_size 64 // 64-bit addressing + +/** + * Let's define some global memory buffers, visible on both device and host + * side, to output multiplication results. + */ +.visible .global .align 8 .f64 dummy_sink_f64[32]; +.visible .global .align 4 .s32 dummy_sink_s32[32]; +.visible .global .align 4 .f32 dummy_sink_f32[32]; + +/** + * The instruction syntax for Warp-Group asynchronous instructions is very + * different, as at least one of the operand matrices has to be in shared + * memory (not registers). It's documented as in 2 variants: + * + * wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32 + * d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b; + * wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32 + * d, a, b-desc, scale-d, imm-scale-a, imm-scale-b; + * + * There is no "C" matrix involved at all, we are computing `D = A * B + D`. + * The A and B matrix descriptors are the properties of the matrix in shared + * memory. It is a 64-bit value contained with the following layout: + * + * - 14 bits [0; 13]: start address + * - 14 bits [16; 29]: leading dimension byte offset + * - 14 bits [32; 45]: stride dimension byte offset + * - 3 bits [49; 51]: matrix base offset, valid only for "swizzling" + * - 2 bits [62; 63]: "swizzling" mode + * + * The matrix layout in WGMMA can be normal or transposed, but its named + * differently. Non-Transposed for A and B is called K-Major. The Transposed + * variant is called M-Major for A and N-Major for B. + * + * The matrices in the shared memory are made up of one or more "swizzle + * layout atom". The exact layout of these swizzle atoms depends on the + * swizzling mode, swizzle-atomicity, and the leading dimension. + * + * Swizzling defines the order of the elements and can have 4 possible values: + * + * 0: no "swizzling" at all + * 1: a 128-byte "swizzle" with a 1024 byte offset of a repeating pattern + * 2: a 64-byte "swizzle" with a 512 byte offset of a repeating pattern + * 3: a 32-byte "swizzle" with a 256 byte offset of a repeating pattern + * + * The list of supported shapes is exhausting: + * + * .m64n8k8, .m64n16k8, .m64n24k8, .m64n32k8, + * .m64n40k8, .m64n48k8, .m64n56k8, .m64n64k8, + * .m64n72k8, .m64n80k8, .m64n88k8, .m64n96k8, + * .m64n104k8, .m64n112k8, .m64n120k8, .m64n128k8, + * .m64n136k8, .m64n144k8, .m64n152k8, .m64n160k8, + * .m64n168k8, .m64n176k8, .m64n184k8, .m64n192k8, + * .m64n200k8, .m64n208k8, .m64n216k8, .m64n224k8, + * .m64n232k8, .m64n240k8, .m64n248k8, .m64n256k8 + * + * The `imm-scale` parameters can be used to either negate the inputs, + * or disable additive bias accumulation in the output. + */ +.visible .entry tops_tf32f32_sm90tc_m64n16k8_loop128_ptx_kernel() +{ + // Accumulator registers used for both input and output of this MMA + .reg .f32 accum<8>; + + // Descriptors for matrix A and matrix B operands + .reg .b64 desc_a, desc_b; + + // TF32 variables will be stores in B32 slots in 2D arrays + .shared .b32 tile_a[64][8]; + .shared .b32 tile_b[8][16]; + + // Define the shape of M x K matrix A + mov.u64 desc_a, 0; // Start address + or.b64 desc_a, desc_a, 32<<16; // leading dimension + or.b64 desc_a, desc_a, 4<<32; // stride dimension + or.b64 desc_a, desc_a, 1<<46; // fixed constant + + // Define the shape of K x N matrix B + mov.u64 desc_b, 2048; // First matrix takes 2 KB + or.b64 desc_b, desc_b, 64<<16; // leading dimension + or.b64 desc_b, desc_b, 4<<32; // stride dimension + or.b64 desc_b, desc_b, 1<<46; // fixed constant + + + // General-purpose registers for loop control + .reg .b32 loop_counter, loop_limit; + + // Predicate register for conditional branching (loop exit) + .reg .pred exit_predicate; + + // Set up loop counter and loop limit + mov.u32 loop_counter, 0; + mov.u32 loop_limit, 1; + + // Zero-initialize the accumulator registers + mov.f32 accum0, 0.0; + mov.f32 accum1, 0.0; + mov.f32 accum2, 0.0; + mov.f32 accum3, 0.0; + mov.f32 accum4, 0.0; + mov.f32 accum5, 0.0; + mov.f32 accum6, 0.0; + mov.f32 accum7, 0.0; + + // Enforce the ordered for Warp-Group instructions + // wgmma.fence.sync.aligned; + + // The main loop will repeat for 128 iterations +loop_start: + setp.ge.u32 exit_predicate, loop_counter, loop_limit; + @exit_predicate bra loop_exit; + + wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 + { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7 }, + desc_a, + desc_b, + 1, 1, 1; + + // Increment the loop counter + add.u32 loop_counter, loop_counter, 1; + + // Branch back to the beginning of the loop + bra loop_start; + +loop_exit: + // Commit all prior uncommitted operations to the group and wait! + // wgmma.commit_group.sync.aligned; + wgmma.wait_group.sync.aligned 0; + + // Use volatile stores to force the accumulator values to be written out. + // This dummy write (to a global variable) makes the work observable and + // prevents the multiplication pipeline from being optimized out. + st.global.volatile.f32 [dummy_sink_f32], accum0; + st.global.volatile.f32 [dummy_sink_f32+4], accum1; + st.global.volatile.f32 [dummy_sink_f32+8], accum2; + st.global.volatile.f32 [dummy_sink_f32+12], accum3; + ret; +} + + +/** + * The instruction syntax for Warp-Group asynchronous instructions is very + * different, as at least one of the operand matrices has to be in shared + * memory (not registers). It's documented as in 2 variants: + * + * wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32 + * d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b; + * wgmma.mma_async.sync.aligned.shape.dtype.tf32.tf32 + * d, a, b-desc, scale-d, imm-scale-a, imm-scale-b; + * + * There is no "C" matrix involved at all, we are computing `D = A * B + D`. + * The A and B matrix descriptors are the properties of the matrix in shared + * memory. It is a 64-bit value contained with the following layout: + * + * - 14 bits [0; 13]: start address + * - 14 bits [16; 29]: leading dimension byte offset + * - 14 bits [32; 45]: stride dimension byte offset + * - 3 bits [49; 51]: matrix base offset, valid only for "swizzling" + * - 2 bits [62; 63]: "swizzling" mode + * + * Swizzling defines the order of the elements and can have 4 possible values: + * + * 0: no "swizzling" at all + * 1: a 128-byte "swizzle" with a 1024 byte offset of a repeating pattern + * 2: a 64-byte "swizzle" with a 512 byte offset of a repeating pattern + * 3: a 32-byte "swizzle" with a 256 byte offset of a repeating pattern + * + * The list of supported shapes is exhausting: + * + * .m64n8k8, .m64n16k8, .m64n24k8, .m64n32k8, + * .m64n40k8, .m64n48k8, .m64n56k8, .m64n64k8, + * .m64n72k8, .m64n80k8, .m64n88k8, .m64n96k8, + * .m64n104k8, .m64n112k8, .m64n120k8, .m64n128k8, + * .m64n136k8, .m64n144k8, .m64n152k8, .m64n160k8, + * .m64n168k8, .m64n176k8, .m64n184k8, .m64n192k8, + * .m64n200k8, .m64n208k8, .m64n216k8, .m64n224k8, + * .m64n232k8, .m64n240k8, .m64n248k8, .m64n256k8 + * + * The `scale` parameters can be used to either negate the inputs, or disable + * additive bias accumulation in the output. + */ +.visible .entry tops_tf32f32_sm90tc_m64n16k8_loop128_ptx_kernel() +{ + // Accumulator registers used for both input and output of this MMA + .reg .f32 accum<8>; + + // Descriptors for matrix A and matrix B operands + .reg .b64 desc_a, desc_b; + + // General-purpose registers for loop control + .reg .b32 loop_counter, loop_limit; + + // Predicate register for conditional branching (loop exit) + .reg .pred exit_predicate; + + // Set up loop counter and loop limit + mov.u32 loop_counter, 0; + mov.u32 loop_limit, 128; + + // Zero-initialize the accumulator registers + mov.f32 accum0, 0.0; + mov.f32 accum1, 0.0; + mov.f32 accum2, 0.0; + mov.f32 accum3, 0.0; + mov.f32 accum4, 0.0; + mov.f32 accum5, 0.0; + mov.f32 accum6, 0.0; + mov.f32 accum7, 0.0; + + // Initialize matrix descriptors with arbitrary placeholder values + mov.u64 desc_a, 0x0000000000000000; + mov.u64 desc_b, 0x0000000000000000; + + // Enforce the ordered for Warp-Group instructions + wgmma.fence.sync.aligned; + + // The main loop will repeat for 128 iterations +loop_start: + setp.ge.u32 exit_predicate, loop_counter, loop_limit; + @exit_predicate bra loop_exit; + + wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 + { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7 }, + desc_a, + desc_b, + 0, -1, -1; + + // Increment the loop counter + add.u32 loop_counter, loop_counter, 1; + + // Branch back to the beginning of the loop + bra loop_start; + +loop_exit: + // Commit all prior uncommitted operations to the group and wait! + wgmma.commit_group.sync.aligned; + wgmma.wait_group.sync.aligned 0; + + // Use volatile stores to force the accumulator values to be written out. + // This dummy write (to a global variable) makes the work observable and + // prevents the multiplication pipeline from being optimized out. + st.global.volatile.f32 [dummy_sink_f32], accum0; + st.global.volatile.f32 [dummy_sink_f32+4], accum1; + st.global.volatile.f32 [dummy_sink_f32+8], accum2; + st.global.volatile.f32 [dummy_sink_f32+12], accum3; + ret; +} + +/** + * This results in massive performance gains on Hopper: + * - 16x16x8 MMA computed by individual warps: 74 T + * - 64x16x8 WGMMA computed by four warps together: 300 T + * + * Will it get even better with larger matrices if we scale the second + * dimension from 16 to 256? It would require 128 accumulators. + */ + +.visible .entry tops_tf32f32_sm90tc_m64n256k8_loop128_ptx_kernel() +{ + // Accumulator registers used for both input and output of this MMA + .reg .f32 accum<128>; + + // Descriptors for matrix A and matrix B operands + .reg .b64 desc_a, desc_b; + + // General-purpose registers for loop control + .reg .b32 loop_counter, loop_limit; + + // Predicate register for conditional branching (loop exit) + .reg .pred exit_predicate; + + // Set up loop counter and loop limit to fill accumulators + mov.u32 loop_counter, 0; + mov.u32 loop_limit, 128; + + // Zero-initialize the accumulator registers: + mov.f32 accum0, 0.0; mov.f32 accum1, 0.0; mov.f32 accum2, 0.0; mov.f32 accum3, 0.0; + mov.f32 accum4, 0.0; mov.f32 accum5, 0.0; mov.f32 accum6, 0.0; mov.f32 accum7, 0.0; + mov.f32 accum8, 0.0; mov.f32 accum9, 0.0; mov.f32 accum10, 0.0; mov.f32 accum11, 0.0; + mov.f32 accum12, 0.0; mov.f32 accum13, 0.0; mov.f32 accum14, 0.0; mov.f32 accum15, 0.0; + mov.f32 accum16, 0.0; mov.f32 accum17, 0.0; mov.f32 accum18, 0.0; mov.f32 accum19, 0.0; + mov.f32 accum20, 0.0; mov.f32 accum21, 0.0; mov.f32 accum22, 0.0; mov.f32 accum23, 0.0; + mov.f32 accum24, 0.0; mov.f32 accum25, 0.0; mov.f32 accum26, 0.0; mov.f32 accum27, 0.0; + mov.f32 accum28, 0.0; mov.f32 accum29, 0.0; mov.f32 accum30, 0.0; mov.f32 accum31, 0.0; + mov.f32 accum32, 0.0; mov.f32 accum33, 0.0; mov.f32 accum34, 0.0; mov.f32 accum35, 0.0; + mov.f32 accum36, 0.0; mov.f32 accum37, 0.0; mov.f32 accum38, 0.0; mov.f32 accum39, 0.0; + mov.f32 accum40, 0.0; mov.f32 accum41, 0.0; mov.f32 accum42, 0.0; mov.f32 accum43, 0.0; + mov.f32 accum44, 0.0; mov.f32 accum45, 0.0; mov.f32 accum46, 0.0; mov.f32 accum47, 0.0; + mov.f32 accum48, 0.0; mov.f32 accum49, 0.0; mov.f32 accum50, 0.0; mov.f32 accum51, 0.0; + mov.f32 accum52, 0.0; mov.f32 accum53, 0.0; mov.f32 accum54, 0.0; mov.f32 accum55, 0.0; + mov.f32 accum56, 0.0; mov.f32 accum57, 0.0; mov.f32 accum58, 0.0; mov.f32 accum59, 0.0; + mov.f32 accum60, 0.0; mov.f32 accum61, 0.0; mov.f32 accum62, 0.0; mov.f32 accum63, 0.0; + mov.f32 accum64, 0.0; mov.f32 accum65, 0.0; mov.f32 accum66, 0.0; mov.f32 accum67, 0.0; + mov.f32 accum68, 0.0; mov.f32 accum69, 0.0; mov.f32 accum70, 0.0; mov.f32 accum71, 0.0; + mov.f32 accum72, 0.0; mov.f32 accum73, 0.0; mov.f32 accum74, 0.0; mov.f32 accum75, 0.0; + mov.f32 accum76, 0.0; mov.f32 accum77, 0.0; mov.f32 accum78, 0.0; mov.f32 accum79, 0.0; + mov.f32 accum80, 0.0; mov.f32 accum81, 0.0; mov.f32 accum82, 0.0; mov.f32 accum83, 0.0; + mov.f32 accum84, 0.0; mov.f32 accum85, 0.0; mov.f32 accum86, 0.0; mov.f32 accum87, 0.0; + mov.f32 accum88, 0.0; mov.f32 accum89, 0.0; mov.f32 accum90, 0.0; mov.f32 accum91, 0.0; + mov.f32 accum92, 0.0; mov.f32 accum93, 0.0; mov.f32 accum94, 0.0; mov.f32 accum95, 0.0; + mov.f32 accum96, 0.0; mov.f32 accum97, 0.0; mov.f32 accum98, 0.0; mov.f32 accum99, 0.0; + mov.f32 accum100, 0.0; mov.f32 accum101, 0.0; mov.f32 accum102, 0.0; mov.f32 accum103, 0.0; + mov.f32 accum104, 0.0; mov.f32 accum105, 0.0; mov.f32 accum106, 0.0; mov.f32 accum107, 0.0; + mov.f32 accum108, 0.0; mov.f32 accum109, 0.0; mov.f32 accum110, 0.0; mov.f32 accum111, 0.0; + mov.f32 accum112, 0.0; mov.f32 accum113, 0.0; mov.f32 accum114, 0.0; mov.f32 accum115, 0.0; + mov.f32 accum116, 0.0; mov.f32 accum117, 0.0; mov.f32 accum118, 0.0; mov.f32 accum119, 0.0; + mov.f32 accum120, 0.0; mov.f32 accum121, 0.0; mov.f32 accum122, 0.0; mov.f32 accum123, 0.0; + mov.f32 accum124, 0.0; mov.f32 accum125, 0.0; mov.f32 accum126, 0.0; mov.f32 accum127, 0.0; + + // Initialize matrix descriptors with arbitrary placeholder values + mov.u64 desc_a, 0x0000000000000000; + mov.u64 desc_b, 0x0000000000000000; + + // Enforce the ordered for Warp-Group instructions + wgmma.fence.sync.aligned; + + // The main loop will repeat for 128 iterations +loop_start: + setp.ge.u32 exit_predicate, loop_counter, loop_limit; + @exit_predicate bra loop_exit; + + wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 + { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7, + accum8, accum9, accum10, accum11, accum12, accum13, accum14, accum15, + accum16, accum17, accum18, accum19, accum20, accum21, accum22, accum23, + accum24, accum25, accum26, accum27, accum28, accum29, accum30, accum31, + accum32, accum33, accum34, accum35, accum36, accum37, accum38, accum39, + accum40, accum41, accum42, accum43, accum44, accum45, accum46, accum47, + accum48, accum49, accum50, accum51, accum52, accum53, accum54, accum55, + accum56, accum57, accum58, accum59, accum60, accum61, accum62, accum63, + accum64, accum65, accum66, accum67, accum68, accum69, accum70, accum71, + accum72, accum73, accum74, accum75, accum76, accum77, accum78, accum79, + accum80, accum81, accum82, accum83, accum84, accum85, accum86, accum87, + accum88, accum89, accum90, accum91, accum92, accum93, accum94, accum95, + accum96, accum97, accum98, accum99, accum100, accum101, accum102, accum103, + accum104, accum105, accum106, accum107, accum108, accum109, accum110, accum111, + accum112, accum113, accum114, accum115, accum116, accum117, accum118, accum119, + accum120, accum121, accum122, accum123, accum124, accum125, accum126, accum127 }, + desc_a, + desc_b, + 0, -1, -1; + + // Increment the loop counter + add.u32 loop_counter, loop_counter, 1; + + // Branch back to the beginning of the loop + bra loop_start; + +loop_exit: + // Commit all prior uncommitted operations to the group and wait! + wgmma.commit_group.sync.aligned; + wgmma.wait_group.sync.aligned 0; + + // Use volatile stores to force the accumulator values to be written out. + // This dummy write (to a global variable) makes the work observable and + // prevents the multiplication pipeline from being optimized out. + st.global.volatile.f32 [dummy_sink_f32], accum0; + st.global.volatile.f32 [dummy_sink_f32+4], accum1; + st.global.volatile.f32 [dummy_sink_f32+8], accum2; + st.global.volatile.f32 [dummy_sink_f32+12], accum3; + ret; +} + +/** + * This results in massive performance gains on Hopper: + * - 16x16x8 MMA computed by individual warps: 74 T + * - 64x16x8 WGMMA computed by four warps together: 300 T + * - 64x256x8 WGMMA computed by four warps together: 4.7 P ?! + * + * There are also "structured-sparse" variants of those instructions, in case + * half of our entries are zeros! Those, however, simply expand the last + * dimension by 2x, making the instructions no more usable for small matrices. + */ + +.visible .entry tops_b1i32and_sm90tc_m64n256k256_loop128_ptx_kernel() +{ + // Accumulator registers used for both input and output of the MMA operation + .reg .s32 accum<128>; + + // Descriptors for matrix A and matrix B operands (in shared memory) + .reg .b64 desc_a, desc_b; + + // General-purpose registers for loop control + .reg .b32 loop_counter, loop_limit; + + // Predicate registers for conditional branching (loop exit) and scale flag + .reg .pred exit_predicate, scale_d; + + // Set up loop counter and loop limit + mov.u32 loop_counter, 0; + mov.u32 loop_limit, 128; + + // Zero-initialize the accumulators, as registers may contain noise + mov.s32 accum0, 0; mov.s32 accum1, 0; mov.s32 accum2, 0; mov.s32 accum3, 0; + mov.s32 accum4, 0; mov.s32 accum5, 0; mov.s32 accum6, 0; mov.s32 accum7, 0; + mov.s32 accum8, 0; mov.s32 accum9, 0; mov.s32 accum10, 0; mov.s32 accum11, 0; + mov.s32 accum12, 0; mov.s32 accum13, 0; mov.s32 accum14, 0; mov.s32 accum15, 0; + mov.s32 accum16, 0; mov.s32 accum17, 0; mov.s32 accum18, 0; mov.s32 accum19, 0; + mov.s32 accum20, 0; mov.s32 accum21, 0; mov.s32 accum22, 0; mov.s32 accum23, 0; + mov.s32 accum24, 0; mov.s32 accum25, 0; mov.s32 accum26, 0; mov.s32 accum27, 0; + mov.s32 accum28, 0; mov.s32 accum29, 0; mov.s32 accum30, 0; mov.s32 accum31, 0; + mov.s32 accum32, 0; mov.s32 accum33, 0; mov.s32 accum34, 0; mov.s32 accum35, 0; + mov.s32 accum36, 0; mov.s32 accum37, 0; mov.s32 accum38, 0; mov.s32 accum39, 0; + mov.s32 accum40, 0; mov.s32 accum41, 0; mov.s32 accum42, 0; mov.s32 accum43, 0; + mov.s32 accum44, 0; mov.s32 accum45, 0; mov.s32 accum46, 0; mov.s32 accum47, 0; + mov.s32 accum48, 0; mov.s32 accum49, 0; mov.s32 accum50, 0; mov.s32 accum51, 0; + mov.s32 accum52, 0; mov.s32 accum53, 0; mov.s32 accum54, 0; mov.s32 accum55, 0; + mov.s32 accum56, 0; mov.s32 accum57, 0; mov.s32 accum58, 0; mov.s32 accum59, 0; + mov.s32 accum60, 0; mov.s32 accum61, 0; mov.s32 accum62, 0; mov.s32 accum63, 0; + mov.s32 accum64, 0; mov.s32 accum65, 0; mov.s32 accum66, 0; mov.s32 accum67, 0; + mov.s32 accum68, 0; mov.s32 accum69, 0; mov.s32 accum70, 0; mov.s32 accum71, 0; + mov.s32 accum72, 0; mov.s32 accum73, 0; mov.s32 accum74, 0; mov.s32 accum75, 0; + mov.s32 accum76, 0; mov.s32 accum77, 0; mov.s32 accum78, 0; mov.s32 accum79, 0; + mov.s32 accum80, 0; mov.s32 accum81, 0; mov.s32 accum82, 0; mov.s32 accum83, 0; + mov.s32 accum84, 0; mov.s32 accum85, 0; mov.s32 accum86, 0; mov.s32 accum87, 0; + mov.s32 accum88, 0; mov.s32 accum89, 0; mov.s32 accum90, 0; mov.s32 accum91, 0; + mov.s32 accum92, 0; mov.s32 accum93, 0; mov.s32 accum94, 0; mov.s32 accum95, 0; + mov.s32 accum96, 0; mov.s32 accum97, 0; mov.s32 accum98, 0; mov.s32 accum99, 0; + mov.s32 accum100, 0; mov.s32 accum101, 0; mov.s32 accum102, 0; mov.s32 accum103, 0; + mov.s32 accum104, 0; mov.s32 accum105, 0; mov.s32 accum106, 0; mov.s32 accum107, 0; + mov.s32 accum108, 0; mov.s32 accum109, 0; mov.s32 accum110, 0; mov.s32 accum111, 0; + mov.s32 accum112, 0; mov.s32 accum113, 0; mov.s32 accum114, 0; mov.s32 accum115, 0; + mov.s32 accum116, 0; mov.s32 accum117, 0; mov.s32 accum118, 0; mov.s32 accum119, 0; + mov.s32 accum120, 0; mov.s32 accum121, 0; mov.s32 accum122, 0; mov.s32 accum123, 0; + mov.s32 accum124, 0; mov.s32 accum125, 0; mov.s32 accum126, 0; mov.s32 accum127, 0; + + // Initialize matrix descriptors with arbitrary placeholder values. + // In practice, these would be set to point to shared-memory regions containing your matrices. + mov.u64 desc_a, 0x0000000000000000; + mov.u64 desc_b, 0x0000000000000000; + + // Initialize scale flag (controls operand scaling or additive bias behavior) + mov.pred scale_d, 1; + + // Enforce the ordered for Warp-Group instructions + wgmma.fence.sync.aligned; + + // The main loop will repeat for 128 iterations +loop_start: + setp.ge.u32 exit_predicate, loop_counter, loop_limit; + @exit_predicate bra loop_exit; + + wgmma.mma_async.sync.aligned.m64n256k256.s32.b1.b1.and.popc + { accum0, accum1, accum2, accum3, accum4, accum5, accum6, accum7, + accum8, accum9, accum10, accum11, accum12, accum13, accum14, accum15, + accum16, accum17, accum18, accum19, accum20, accum21, accum22, accum23, + accum24, accum25, accum26, accum27, accum28, accum29, accum30, accum31, + accum32, accum33, accum34, accum35, accum36, accum37, accum38, accum39, + accum40, accum41, accum42, accum43, accum44, accum45, accum46, accum47, + accum48, accum49, accum50, accum51, accum52, accum53, accum54, accum55, + accum56, accum57, accum58, accum59, accum60, accum61, accum62, accum63, + accum64, accum65, accum66, accum67, accum68, accum69, accum70, accum71, + accum72, accum73, accum74, accum75, accum76, accum77, accum78, accum79, + accum80, accum81, accum82, accum83, accum84, accum85, accum86, accum87, + accum88, accum89, accum90, accum91, accum92, accum93, accum94, accum95, + accum96, accum97, accum98, accum99, accum100, accum101, accum102, accum103, + accum104, accum105, accum106, accum107, accum108, accum109, accum110, accum111, + accum112, accum113, accum114, accum115, accum116, accum117, accum118, accum119, + accum120, accum121, accum122, accum123, accum124, accum125, accum126, accum127 }, + desc_a, + desc_b, + scale_d; + + // Increment the loop counter + add.u32 loop_counter, loop_counter, 1; + + // Branch back to the beginning of the loop + bra loop_start; + +loop_exit: + // Commit all prior uncommitted operations to the group and wait! + wgmma.commit_group.sync.aligned; + wgmma.wait_group.sync.aligned 0; + + // Use volatile stores to force the accumulator values to be written out. + // This dummy write (to a global variable) makes the work observable and + // prevents the multiplication pipeline from being optimized out. + st.global.volatile.s32 [dummy_sink_s32], accum0; + st.global.volatile.s32 [dummy_sink_s32+4], accum1; + st.global.volatile.s32 [dummy_sink_s32+8], accum2; + st.global.volatile.s32 [dummy_sink_s32+12], accum3; + ret; +}