Improper ComputeCapability check for cudnn_dot_product_attention. #22546

Closed

MasterSkepticista opened this issue Jul 21, 2024 · 12 comments
Labels
bug Something isn't working

Comments

@MasterSkepticista

Description

We check for CC here:
/~https://github.com/google/jax/blob/9632a2d1a86496cb1bca7bacdacef3bf554b5153/jax/_src/cudnn/fused_attention_stablehlo.py#L990

But the check (L316) fails if compute_cap is an integer strictly between 80 and 90 (e.g. 86).
/~https://github.com/google/jax/blob/9632a2d1a86496cb1bca7bacdacef3bf554b5153/jax/_src/cudnn/fused_attention_stablehlo.py#L315-L317

The intention is to allow all GPUs with compute capability within the range.

if compute_cap not in cc:  # 86 not in (80, 90) is True, so this raises for sm86. It shouldn't.
  raise RuntimeError(...)

I disabled the CC check on my platform (it is 86), and cudnn gives a speedup as expected.
A correct check could be:

assert len(cc) == 2, "Provide a (low, high) range"
lo, hi = cc
if compute_cap not in range(lo, hi + 1):
  raise RuntimeError(...)
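
For illustration, the difference between tuple membership and a range check (Python REPL):

>>> cc = (80, 90)
>>> 86 in cc                 # tuple membership only matches 80 and 90 exactly
False
>>> 86 in range(80, 90 + 1)  # inclusive range check
True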

Happy to do a PR.

System info (python version, jaxlib version, accelerator, etc.)

System information

jax:    0.4.31.dev20240720
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='sagittarius', release='6.5.0-35-generic', version='#35~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue May  7 09:00:52 UTC 2', machine='x86_64')

<truncated>
GPU: RTX-A6000 (Ampere)
Compute Capability: 8.6
Driver: 555.42.06
CUDA: 12.5
cuDNN: 9.2
MasterSkepticista added the bug label on Jul 21, 2024
@monatis

monatis commented Jul 21, 2024

Hi @MasterSkepticista, I find this quite interesting. To my knowledge, cuDNN's FMHA implementation is only supported on CC 8.0 and 9.0 (I confirmed this with a developer from NVIDIA). Was support for CC 8.6 added in a new version?
How much speedup did you gain with this, and how did you test it?

@MasterSkepticista
Author

Flash attention is supported on all Ampere and Hopper GPUs, I think. The PyTorch version also runs faster on CC 8.6 (it does not fail their internal check).

Benchmark:

import functools
import jax
import jax.numpy as jnp

b, t, h, d = 1, 1024, 12, 64
key = jax.random.key(42)
q = k = v = jax.random.uniform(key, (b, t, h, d), jnp.bfloat16)
xla_attn = jax.jit(functools.partial(jax.nn.dot_product_attention, is_causal=True, implementation="xla"))
flash_attn = jax.jit(functools.partial(jax.nn.dot_product_attention, is_causal=True, implementation="cudnn"))

# Warmup.
out_xla = xla_attn(q, k, v)
out_flash = flash_attn(q, k, v)
assert jnp.allclose(out_xla, out_flash, atol=1e-2)
print(jnp.abs(out_xla - out_flash).max())
# 0.00390625 (bfloat16 epsilon?)

%timeit -n 100 xla_attn(q, k, v).block_until_ready()
# 422 μs ± 5.16 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit -n 100 flash_attn(q, k, v).block_until_ready()
# 101 μs ± 4.56 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@hawkinsp
Collaborator

From @kaixih, my understanding is that the check is correct. 8.6 is not supported.

@MasterSkepticista
Author

MasterSkepticista commented Jul 23, 2024

Hi @hawkinsp, I believe sm86 and sm89 are supported under certain size constraints (for ref: Dao-AILab/flash-attention#138 (comment), sdp_utils.cpp#L268). Flash attention on sm86 works in PyTorch without any warnings.

Perhaps we can put a check on num_heads or head_dim in the code and allow both sm86 and sm89 to benefit from this speedup.

I would be happy to contribute a PR.

@monatis

monatis commented Jul 23, 2024

I reproduced the speedup you posted, but the cuDNN implementation is independent of Dao-AILab's implementation. I also confirmed with an NVIDIA engineer that it's very unlikely to support sm86 and sm89 in the future, because compute capabilities with a non-zero minor version have less shared memory, and cuDNN's implementation is optimized for the larger shared memory on sm80 and sm90. Something else must be happening here.

@monatis

monatis commented Jul 23, 2024

I figured out that PyTorch also indicates (on line 465) that the cuDNN FMHA implementation is supported only on sm80 and sm90. However, they also pack Dao-AILab's implementation as part of PyTorch and have a fallback chain: cuDNN FMHA -> Dao-AILab's flash attention v2 -> memory-efficient attention implementation.

@MasterSkepticista
Author

I can check which kernel runs in PyTorch on sm86.

But I think there is merit in allowing JAX cudnn SDPA to run on sm86 and sm89 with a warning. In the worst case, the smaller shared memory leads to cache misses (which doesn't actually seem to be the case, judging from the microbenchmarks and the fact that training converges to the same loss values as the XLA implementation).

We could also limit the head_dim and num_heads, as PyTorch does here for sm86 and sm89.
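
A rough, hypothetical sketch of what that gate could look like (the head_dim threshold below is assumed for illustration, not taken from cuDNN docs):

import warnings

def check_compute_capability(compute_cap: int, head_dim: int) -> None:
  # Hypothetical gate: sm80/sm90 pass unconditionally; sm86/sm89 are allowed
  # under an assumed head_dim limit, with a warning. A num_heads limit could
  # be added the same way if cuDNN needs one.
  if compute_cap in (80, 90):
    return
  if compute_cap in (86, 89):
    if head_dim > 128:  # assumed threshold, for illustration only
      raise NotImplementedError(
          f"cudnn fused attention on sm{compute_cap} needs head_dim <= 128, "
          f"got {head_dim}")
    warnings.warn(f"cudnn fused attention on sm{compute_cap} may run slower "
                  "due to the smaller shared memory.")
    return
  raise RuntimeError(f"Unsupported compute capability: {compute_cap}")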

What do you think?

@monatis

monatis commented Jul 23, 2024

I checked the lowering for flash_attn with flash_attn.lower(q, k, v).compile().as_text() and the custom call target is really __cudnn$fmhaSoftmax. Pretty interesting.
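
Spelled out (continuing from the benchmark snippet above, which defined flash_attn, q, k, v):

hlo = flash_attn.lower(q, k, v).compile().as_text()
print("__cudnn$fmhaSoftmax" in hlo)
# True on this machine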

@kaixih
Contributor

kaixih commented Jul 23, 2024

@Cjkkkk Can you comment on the version restrictions of cudnn flash attn?

@Cjkkkk
Contributor

Cjkkkk commented Jul 23, 2024

Thanks for bringing this up. I think the constraint excluding sm86/sm89 was for the non-flash attention path, which has now been removed from the JAX cudnn SDPA API. I'm confirming with the NVIDIA cuDNN team to see whether we can relax the constraint for sm86/sm89 for flash attention.

@Cjkkkk
Contributor

Cjkkkk commented Jul 23, 2024

Created a PR to include sm86/sm89 as well: #22605

@MasterSkepticista
Author

Will close, now that #22605 is merged.
