
[CUTLASS] Add NDEBUG option to CUTLASS compile to speed up attention kernel #14798

Merged
merged 1 commit into from
May 8, 2023

Conversation

masahi (Member) commented May 7, 2023

I found that adding this option brings a non-trivial performance improvement to the attention kernel (2.4 vs 2.7 msec for the heaviest workload in SD UNet, see below). This results in a few msec of end-to-end speedup for SD UNet.
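To make the nature of the change concrete, here is a minimal sketch of what adding the flag amounts to when invoking nvcc. This is not the actual BYOC compile command: the file name and every flag other than -DNDEBUG are placeholders; the point is that -DNDEBUG disables device-side assert() checks in the CUTLASS kernel.

import subprocess

# Sketch only: placeholder file name and flags; the relevant part is -DNDEBUG,
# which removes device-side assert() overhead from the CUTLASS attention kernel.
nvcc_flags = [
    "-O3",
    "-std=c++17",
    "-gencode=arch=compute_80,code=sm_80",
    "-DNDEBUG",
]
subprocess.run(
    ["nvcc", *nvcc_flags, "-c", "attention_kernel.cu", "-o", "attention_kernel.o"],
    check=True,
)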

Before (nvprof output on test_attention_offload((2, (4096, 4096), 8, (40, 40), "float16")))

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)     GridXYZ         BlockXYZ                                                     Name                                                
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  --------------  --------------  ----------------------------------------------------------------------------------------------------
 100.0        8,325,790          3  2,775,263.3  2,773,023.0  2,768,736  2,784,031      7,889.8    64    8    2    32    4    1  void attention_kernel_batched_impl<AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, (bool)1, (…

After

   Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)     GridXYZ         BlockXYZ                                                     Name                                                
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  --------------  --------------  ----------------------------------------------------------------------------------------------------
 100.0        7,466,320          3  2,488,773.3  2,483,845.0  2,481,606  2,500,869     10,534.8    64    8    2    32    4    1  void attention_kernel_batched_impl<AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, (bool)1, (…

@cyx-6 @vinx13

tvm-bot (Collaborator) commented May 7, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: cutlass. See #10317 for details.

Generated by tvm-bot

masahi force-pushed the cutlass-ndebug branch from 3c8f899 to 3d88e4a on May 7, 2023 at 20:49
masahi (Member, Author) commented May 7, 2023

One weird thing, though: if I run the same attention workload via the cutlass example, the same kernel runs in 1.3 msec (compared to 2.4 msec from our BYOC path), see below.

$ nsys nvprof examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen --head_number=8 --batch_size=2 --head_size=40 --head_size_v=40 --seq_length=4096 --seq_length_kv=4096
                                                                                                                                          
CUTLASS Attention:
====================================================
    {seq length Q, seq length KV, head size, head size V, head number, batch size} = {4096, 4096, 40, 40, 8, 2}.

    Runtime: 1.36964 ms
    GFLOPs: 19897.7

Passed

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)     GridXYZ         BlockXYZ                                                     Name                                                
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  --------------  --------------  ----------------------------------------------------------------------------------------------------
     ...
     13.8       30,160,277         22  1,370,921.7  1,368,225.0  1,346,113  1,414,209     16,718.9    64    8    2    32    4    1  void attention_kernel_batched_impl<AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, (bool)1, (…

I've also checked Triton and FlashAttention performance on the same workload by running /~https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py#L330 and saw around 1.2 - 1.3 msec. So I'm inclined to believe that ~1.3 msec is the right ballpark for an attention kernel on this workload.

So maybe there is something off in how we use this kernel from our BYOC? I compared the generated code with the cutlass example code but didn't find any difference. There is also no difference in the nvcc options that might affect performance, other than this NDEBUG stuff (that's how I found out about it). Any thoughts? @vinx13 @cyx-6

junrushao (Member) commented

@spectrometerHBH has the same observation. Good catch!

vinx13 merged commit 6c689ee into apache:main on May 8, 2023
vinx13 (Member) commented May 8, 2023

I can confirm there's a performance difference. The profiler also shows that a different number of instructions is executed, even though it's indeed the same kernel.

masahi (Member, Author) commented May 8, 2023

I tried updating the cutlass submodule revision, but it didn't help.

vinx13 (Member) commented May 9, 2023

It turns out that the cutlass profiler has --causal defaulting to true; after adding --causal=false I can get the same numbers.
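For reference, that corresponds to rerunning the earlier example command with the flag made explicit (assuming the example binary accepts it the same way as its other options):

$ nsys nvprof examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen --head_number=8 --batch_size=2 --head_size=40 --head_size_v=40 --seq_length=4096 --seq_length_kv=4096 --causal=false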

masahi (Member, Author) commented May 9, 2023

Wow, then the Triton and FlashAttention kernels may indeed be faster than the cutlass one, given that the Triton implementation is definitely not doing the causal mask optimization.

vinx13 (Member) commented May 9, 2023

The Triton kernel is also causal attention (/~https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py#LL49C57-L49C57); it's not doing a full Q*K.T, so it has less computation.

masahi (Member, Author) commented May 9, 2023

Interesting! Yeah, I didn't understand the loop bound for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N) when I was studying this code; now it makes sense.
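A quick way to see what that bound buys, as a plain-Python sketch (not Triton; the block sizes here are placeholders, not necessarily the tutorial's): each query block only visits key blocks up to its own position, so roughly half the K/V tiles of a full pass are touched.

# Count K/V tiles visited: full attention vs. the causal loop bound.
SEQ_LEN, BLOCK_M, BLOCK_N = 4096, 128, 64

full = causal = 0
for block_m, _ in enumerate(range(0, SEQ_LEN, BLOCK_M)):
    # full attention: every key block for every query block
    full += len(range(0, SEQ_LEN, BLOCK_N))
    # causal bound from the tutorial: only key blocks at or before this query block
    causal += len(range(0, (block_m + 1) * BLOCK_M, BLOCK_N))

print(full, causal, causal / full)  # causal visits ~52% of the tiles here

That halving of the work is roughly consistent with the gap between the causal (~1.37 msec) and non-causal (~2.5 msec) timings above.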
