[CUTLASS] Add NDEBUG option to CUTLASS compile to speed up attention kernel #14798
Conversation
But a weird thing is, if I run the same attention workload via the cutlass example, the same kernel runs in 1.3 msec, see below (compared to our BYOC result, 2.4 msec).
I've also checked Triton and Flash Attention perf on the same workload by running /~https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py#L330 and saw around 1.2 - 1.3 msec. So I want to believe that ~1.3 msec is the right result for an attention kernel on this workload. Maybe there is something off in how we use this kernel from our BYOC? I compared the generated code and the cutlass example code but didn't find any difference, and there is no difference in the nvcc options that might affect performance other than this.
@spectrometerHBH has the same observation. Good catch!
I can confirm there's a performance difference; the profiler also shows a different number of instructions executed, even though it's indeed the same kernel.
I tried updating the cutlass submodule revision, but it didn't help. |
It turns out that the cutlass profiler has ...
Wow, then the Triton and Flash Attention kernels may indeed be faster than the cutlass one, given that the triton implementation is definitely not doing the causal mask optimization.
The triton kernel is also causal attention (/~https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py#LL49C57-L49C57); it's not doing a full Q*K.T, so it has less computation.
Interesting! Yeah, I didn't understand this loop bound.
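For reference, here is a minimal NumPy sketch of the causal loop bound being discussed (not the actual Triton or CUTLASS code; the function name and blocking scheme are made up for illustration). Each query block only loops over key blocks at or before its own position, so only about half of the Q @ K.T tiles are computed:

```python
import numpy as np

# A sketch of block-wise causal attention. The point is the inner loop bound:
# key blocks are visited only up to the current query block, so roughly half
# of the Q @ K.T tiles are ever computed compared to a non-causal kernel.
def causal_attention_blocked(q, k, v, block=128):
    seq, dim = q.shape
    out = np.zeros_like(q)
    tiles = 0
    for qs in range(0, seq, block):                      # query blocks
        qe = min(qs + block, seq)
        scores = np.full((qe - qs, seq), -np.inf)
        for ks in range(0, qe, block):                   # causal bound: ks < qe
            ke = min(ks + block, seq)
            scores[:, ks:ke] = q[qs:qe] @ k[ks:ke].T / np.sqrt(dim)
            tiles += 1
        # mask the upper-triangular corner of the diagonal block
        rows = np.arange(qs, qe)[:, None]
        cols = np.arange(seq)[None, :]
        scores[cols > rows] = -np.inf
        probs = np.exp(scores - scores.max(axis=1, keepdims=True))
        probs /= probs.sum(axis=1, keepdims=True)
        out[qs:qe] = probs @ v
    return out, tiles
```

With a sequence length of 4096 and a 128-wide block this touches 32 * 33 / 2 = 528 tiles instead of the 32 * 32 = 1024 a full Q @ K.T would need, so a causal kernel does roughly half the matmul work of a non-causal one.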
I found that adding this option brings a non-trivial perf improvement to the attention kernel (2.4 vs 2.7 msec for the heaviest workload in SD UNet, see below). This results in a few msec speedup for SD UNet e2e.
Before (nvprof output on test_attention_offload((2, (4096, 4096), 8, (40, 40), "float16"))):

After:
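For context on what the flag changes: defining NDEBUG compiles out the standard assert() calls in the CUTLASS headers, which lines up with the different instruction counts seen in the profiler. Below is a minimal sketch of comparing the two builds; it assumes the generated kernel source is compiled through tvm.contrib.nvcc.compile_cuda, and the file and include paths are illustrative, not the actual BYOC compile path this PR touches.

```python
# Sketch only: compile a generated CUTLASS attention kernel with and without
# -DNDEBUG to compare the resulting binaries. "attention_kernel.cu" and the
# include paths are placeholders, not the real BYOC output.
from tvm.contrib import nvcc

with open("attention_kernel.cu") as f:
    kernel_src = f.read()

common = ["-O3", "-std=c++17", "-Icutlass/include", "-Icutlass/tools/util/include"]

# Without NDEBUG: device-side assert() calls in the CUTLASS headers stay in the kernel.
fatbin_default = nvcc.compile_cuda(kernel_src, target_format="fatbin", options=common)

# With the option added by this PR: the asserts are compiled out.
fatbin_ndebug = nvcc.compile_cuda(
    kernel_src, target_format="fatbin", options=common + ["-DNDEBUG"]
)
```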
@cyx-6 @vinx13