Improper ComputeCapability check for cudnn_dot_product_attention #22546
Comments
Hi @MasterSkepticista, I find this quite interesting. To my knowledge, CuDNN's FMHA implementation is only supported on CC 8.0 and 9.0 (I confirmed this with a developer from NVIDIA). Was support for CC 8.6 added in a newer version?
Flash attention is supported on all Ampere and Hopper GPUs, I think. The PyTorch version also runs faster on CC 8.6 (it does not fail their internal check). Benchmark:

import functools
import jax
import jax.numpy as jnp

b, t, h, d = 1, 1024, 12, 64
key = jax.random.key(42)
q = k = v = jax.random.uniform(key, (b, t, h, d), jnp.bfloat16)

xla_attn = jax.jit(functools.partial(jax.nn.dot_product_attention, is_causal=True, implementation="xla"))
flash_attn = jax.jit(functools.partial(jax.nn.dot_product_attention, is_causal=True, implementation="cudnn"))

# Warmup.
out_xla = xla_attn(q, k, v)
out_flash = flash_attn(q, k, v)

assert jnp.allclose(out_xla, out_flash, atol=1e-2)
print(jnp.abs(out_xla - out_flash).max())
# 0.00390625 (bfloat16 epsilon?)

# Timings below use IPython's %timeit.
%timeit -n 100 xla_attn(q, k, v).block_until_ready()
# 422 μs ± 5.16 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit -n 100 flash_attn(q, k, v).block_until_ready()
# 101 μs ± 4.56 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
From @kaixih, my understanding is that the check is correct. 8.6 is not supported.
Hi @hawkinsp, I believe sm86 and sm89 are supported under certain size constraints (for reference: Dao-AILab/flash-attention#138 (comment), sdp_utils.cpp#L268). Flash attention on those GPUs appears to work as long as the size limits are respected. Perhaps we can put a check on those constraints instead of rejecting the compute capability outright; a sketch follows below. I would be happy to contribute a PR.
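For illustration only, a size-constrained gate along those lines might look like the sketch below; the helper name and the head-dimension limit are assumptions made for the example, not values taken from the JAX, cuDNN, or PyTorch sources.

```python
def flash_attention_supported(compute_cap: str, head_dim: int) -> bool:
    # Sketch of a size-constrained support check. The head_dim <= 192 limit for
    # sm8.6/sm8.9 is an illustrative assumption, not a value from the real sources.
    major, minor = (int(x) for x in compute_cap.split("."))
    if (major, minor) < (8, 0):
        return False  # pre-Ampere: unsupported
    if (major, minor) in ((8, 6), (8, 9)):
        return head_dim <= 192  # smaller shared memory constrains the head dimension
    return True  # sm8.0 / sm9.0: no extra size constraint in this sketch
```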
I reproduced the speedup you posted, but the CuDNN implementation is independent of DaoAILab's implementation. I also confirmed this with an NVIDIA engineer, who said it is very unlikely that sm86 and sm89 will be supported in the future, because compute capabilities with a non-zero minor version have smaller shared memory and CuDNN's implementation is optimized for the larger shared memory of sm80 and sm90. There must be something else happening there.
I found that PyTorch also indicates (on line 465) that the CuDNN FMHA implementation is only supported on sm80 and sm90. However, PyTorch also packages DaoAILab's implementation and has a fallback mechanism: CuDNN FMHA -> DaoAILab's flash attention v2 -> memory-efficient attention.
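As a rough illustration of that fallback chain (a sketch only; the backend names and support checks are placeholders mirroring the comment, not PyTorch's actual dispatcher):

```python
from typing import Callable, Sequence

def pick_attention_backend(candidates: Sequence[tuple[str, Callable[[], bool]]]) -> str:
    # Try backends in priority order and return the first whose support check passes.
    for name, is_supported in candidates:
        if is_supported():
            return name
    raise RuntimeError("No attention backend is supported on this device.")

# Order described in the comment:
# cuDNN FMHA -> DaoAILab flash attention v2 -> memory-efficient attention.
backend = pick_attention_backend([
    ("cudnn_fmha", lambda: False),         # e.g. rejected on sm8.6 in this scenario
    ("flash_attention_v2", lambda: True),  # supported under size constraints
    ("memory_efficient", lambda: True),
])
print(backend)  # -> "flash_attention_v2"
```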
I can check which kernel runs in PyTorch on sm86. But I think there is merit in allowing JAX to use the cuDNN implementation on these GPUs. We could also limit the allowed sizes to match the constraints above. What do you think?
I checked the lowering for |
@Cjkkkk Can you comment on the version restrictions of cudnn flash attn?
Thanks for bringing this up. I think the sm86/sm89 constraint was for the non-flash version of attention, which has now been removed from the JAX cuDNN SDPA API. I am confirming with the NVIDIA cuDNN team to see if we can relax the constraint and allow sm86/sm89 for flash attention.
Created a PR to include sm86/sm89 as well: #22605
Will close, now that #22605 is merged.
Description
We check for CC here:
/~https://github.com/google/jax/blob/9632a2d1a86496cb1bca7bacdacef3bf554b5153/jax/_src/cudnn/fused_attention_stablehlo.py#L990
But the check (L316) fails if compute_cap is an integer between (80, 90):
/~https://github.com/google/jax/blob/9632a2d1a86496cb1bca7bacdacef3bf554b5153/jax/_src/cudnn/fused_attention_stablehlo.py#L315-L317
The intention is to allow all GPUs with compute capability within the range.
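To illustrate the failure mode (a minimal sketch; the variable name and integer representation only approximate the linked code):

```python
compute_cap = 86  # e.g. sm8.6 (RTX 30-series); name and representation are illustrative

# Exact-membership test against the endpoints only: 86 is rejected.
print(compute_cap in (80, 90))      # False -> the check raises on sm8.6/sm8.9

# Range comparison: any capability within [80, 90] is accepted.
print(80 <= compute_cap <= 90)      # True
```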
I disabled the CC check on my platform (it is 8.6), and cudnn gives a speedup as expected. A correct way could be a range check rather than an exact-value check, as sketched below.
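A minimal sketch of such a range check, assuming an integer compute-capability representation; the helper name, bounds, and error message are illustrative and not the code from #22605:

```python
def check_compute_capability(compute_cap: int, lower: int = 80, upper: int = 90) -> None:
    # Accept any capability within the inclusive range, so that e.g. 86 and 89 pass
    # alongside 80 and 90, instead of testing membership against the endpoints only.
    if not lower <= compute_cap <= upper:
        raise RuntimeError(
            f"cudnn fused attention requires compute capability between "
            f"{lower} and {upper}, got {compute_cap}"
        )

check_compute_capability(86)  # passes
```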
Happy to do a PR.
System info (python version, jaxlib version, accelerator, etc.)