Improve performance via batched-matmul and fused multiplies #7

Open · Birch-san opened this issue Dec 27, 2022 · 11 comments

Birch-san commented Dec 27, 2022

Many thanks for providing this reference implementation.

I tried integrating this into stable-diffusion / diffusers. A fix was required to make it work on Mac (PyTorch MPS backend):
Birch-san/diffusers@0437214

Knowing that computing attention via baddbmm()+bmm() can outperform einsum by 18%, I tried rewriting the algorithm to use those operations.
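
As a minimal sketch of the idea (not the exact code from my branch), attention via baddbmm()+bmm() looks roughly like this for q/k/v already in [batch * num_heads, tokens, channels_per_head] layout:

```python
import torch

# rough sketch, not the exact code from my branch: scaled-dot-product attention
# via baddbmm()+bmm(), for q/k/v in [batch * num_heads, tokens, channels_per_head]
def baddbmm_attention(q, k, v, scale):
    # baddbmm fuses the scaling (alpha) into the batched q @ k.T matmul;
    # with beta=0 the "bias" tensor is ignored, so an uninitialized buffer is fine
    bias = torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device)
    scores = torch.baddbmm(bias, q, k.transpose(1, 2), beta=0, alpha=scale)
    probs = scores.softmax(dim=-1)
    return torch.bmm(probs, v)  # equivalent to einsum('b i j, b j d -> b i d', probs, v)
```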

I compared the speed of my optimized version against the implementation in this repository.
this result is for "everything fits in one chunk" perf (i.e. chunk size = max token length). I was unable to compare chunked perf: although I got chunking working in my version, I wasn't able to get it working in the version in this repository (it returned some unexpectedly-shaped tensors).

compared to the implementation in this repository, my optimized version achieves a 2.78x speedup in the time it takes to generate a 512x512 image with stable-diffusion v2.1-base (i.e. 4096 vision tokens, 5 attention heads, batch size of 2 due to CFG).

here's my optimized implementation:
Birch-san/diffusers#1

batched matmuls require a 3D tensor, i.e. [batch * num_heads, tokens, channels_per_head].

code that currently integrates against this repository's [batch, q_length, num_heads, qk_depth_per_head] format can migrate those tensors to the [batch * num_heads, q_length, channels_per_head] format favoured by my implementation like so:

query = query.transpose(1,2).flatten(end_dim=1)
key = key.transpose(1,2).flatten(end_dim=1)
value = value.transpose(1,2).flatten(end_dim=1)

the result that's returned remains in [batch * num_heads, q_length, qk_depth_per_head] format, and can be restored to [batch, q_length, num_heads, qk_depth_per_head] format like so:

result.unflatten(0, (-1, attn.heads)).transpose(1,2)
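
here's a small self-contained sketch of that round trip (shape numbers are illustrative; num_heads stands in for attn.heads):

```python
import torch

# illustrative sizes only
batch, q_length, num_heads, qk_depth_per_head = 2, 4096, 5, 64
query = torch.randn(batch, q_length, num_heads, qk_depth_per_head)

# [batch, q_length, num_heads, depth] -> [batch * num_heads, q_length, depth]
q_3d = query.transpose(1, 2).flatten(end_dim=1)
assert q_3d.shape == (batch * num_heads, q_length, qk_depth_per_head)

# (attention runs here and returns a tensor in the same 3D layout)
result = q_3d

# [batch * num_heads, q_length, depth] -> [batch, q_length, num_heads, depth]
restored = result.unflatten(0, (-1, num_heads)).transpose(1, 2)
assert restored.shape == query.shape
```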

I think a further speedup is possible too, by working out when chunking is not needed: we can compute whether unchunked attention would fit into memory, and prefer unchunked attention as a fast path where possible. this will be useful in a Unet, which runs attention at various resolutions.

EDIT:
I have now added fast-paths (sketched after this list) for:

  • skipping kv-chunking when kv_chunk_size >= k_tokens
    • this turns the algorithm into "attention slicing"
  • skipping q-chunking when q_chunk_size >= q_tokens
  • skipping all chunking when the kv_chunk_size >= k_tokens and q_chunk_size >= q_tokens
  • skipping all chunking when the q@k.T matmul requires fewer bytes than a user-provided threshold
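
roughly, the combined dispatch looks like this (a sketch only: unchunked_attn / chunked_attn and the bytes_limit argument are stand-ins, not the real function names or signature):

```python
# sketch of the fast-path dispatch; unchunked_attn / chunked_attn are whatever
# attention callables you already have, and bytes_limit is a hypothetical
# user-provided threshold, not an argument of either real implementation
def dispatch_attention(q, k, v, scale, q_chunk_size, kv_chunk_size,
                       unchunked_attn, chunked_attn, bytes_limit=None):
    _, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape

    # bytes the full q @ k.T scores tensor would occupy
    scores_bytes = q.shape[0] * q_tokens * k_tokens * q.element_size()

    skip_q_chunking = q_chunk_size >= q_tokens
    skip_kv_chunking = kv_chunk_size >= k_tokens   # on its own, this degenerates into "attention slicing"
    fits_in_budget = bytes_limit is not None and scores_bytes <= bytes_limit

    if (skip_q_chunking and skip_kv_chunking) or fits_in_budget:
        return unchunked_attn(q, k, v, scale)      # fast path: no chunking at all
    return chunked_attn(q, k, v, scale,
                        q_chunk_size=min(q_chunk_size, q_tokens),
                        kv_chunk_size=min(kv_chunk_size, k_tokens))
```
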
AminRezaei0x443 (Owner) commented

Really nice implementation @Birch-san. Thanks for your detailed description.

I'd appreciate it if you could make the modifications described above in this repo too and create a PR. Would that be possible for you?


brkirch commented Jan 6, 2023

I'm looking at getting this implementation added to AUTOMATIC1111/stable-diffusion-webui, but neither this package nor Birch-san's modified implementation has a license. Unfortunately, if I simply use this package unmodified, I won't get these speed optimizations or Mac support, both of which are needed.

So I'd like to ask both of you:
@Birch-san @AminRezaei0x443 Are you okay with this being added to AUTOMATIC1111/stable-diffusion-webui? You would also be credited on the main project page (see the Credits section at the bottom).

Birch-san (Author) commented

Yup, I'm okay with releasing my contribution under a permissive license. There are some bits of @AminRezaei0x443's code remaining, so it would be good to find out their licensing preferences. But I believe my implementation doesn't contain any lines of code from the other contributors to that repository (as I deleted the masking and callback support and the device-argument fixes).

Birch-san (Author) commented

Oh, you might want to read the latest commits I added:
Birch-san/diffusers#1

I reduced the number of times the key needs to be transposed. This complicates the API slightly, but has a chance of improving performance.
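
roughly, the change means the key can arrive pre-transposed, so each kv chunk can be fed straight into baddbmm without its own transpose. a sketch of the idea (illustrative names, not the exact signature in my branch):

```python
import torch

# illustrative sketch: the caller transposes the key once, up front,
# to [batch * num_heads, channels_per_head, k_tokens]
def scores_from_pretransposed_key(q, k_t, scale):
    bias = torch.empty(q.shape[0], q.shape[1], k_t.shape[2], dtype=q.dtype, device=q.device)
    return torch.baddbmm(bias, q, k_t, beta=0, alpha=scale)

# k_t = key.transpose(1, 2)       # done once
# chunk = k_t[:, :, start:stop]   # kv chunks are sliced from the pre-transposed key
```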

AminRezaei0x443 (Owner) commented

@brkirch Thanks for your interest; I'm glad this library is useful for you. This project is licensed under MIT; I've updated the repository too.


brkirch commented Jan 6, 2023

Thanks! I've added the license to the PR.

@Birch-san I did experiment with Birch-san/diffusers@9dc6822 but could only get worse performance. I've been using kulinseth/pytorch instead of PyTorch nightly builds though, as it gets significantly better performance (usually ~25% faster) for MPS than the nightly builds in my testing. Right now most AUTOMATIC1111 users will be on PyTorch 1.12.1, as it is the latest release that isn't broken for MPS, but due to the significant MPS improvements in PyTorch I'm going to recommend that users consider switching to kulinseth/pytorch for inference and only use 1.12.1 for training (which unfortunately still doesn't work on MPS in any newer PyTorch).

Birch-san (Author) commented

> could only get worse performance

okay, that's interesting. I'm skeptical. I measured the impact of pre-transpose on unchunked attention a while back, and it didn't seem to change the perf for better or worse. I haven't measured the difference on chunked attention, but we'd have to agree on chunk sizes to do that.
benchmarking is pretty hard on MPS because there's no profiler and because GPU temperature seems to influence the results a lot (you may find that the first test has a bit of an edge).
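
a rough way to get comparable wall-clock numbers without a profiler looks something like this (a sketch; it assumes a build where torch.mps.synchronize() is available, and warm-up passes to reduce the "first test has an edge" effect):

```python
import time
import torch

# rough wall-clock timing in the absence of an MPS profiler; assumes a build
# that provides torch.mps.synchronize() (present in recent nightlies)
def time_on_mps(fn, warmup=3, iters=10):
    for _ in range(warmup):       # warm-up passes reduce the first-run edge
        fn()
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.mps.synchronize()       # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - start) / iters
```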

> broken for MPS, […] only use 1.12.1 for training

really? I did some TI training a few weeks ago using a commit from Dec 23:
pytorch/pytorch@789b143
the key thing for training is the group_norm fix that landed on Dec 22:
pytorch/pytorch@fd3a726
the only workaround I needed was running cumsum on CPU:
pytorch/pytorch#89784 (comment)

> kulinseth/pytorch usually ~25% faster

wow, I wonder which optimization that is. there's a potentially-dramatic optimization that landed in pytorch master yesterday, which probably was available in kulinseth's branch before that:
pytorch/pytorch#91114

but there could be downsides to early-adoption. there's stuff that may not be ready:
pytorch/pytorch#90464

Birch-san (Author) commented

okay yeah, I got 12% faster when I updated to latest master. might be that torch.linear() optimization.
https://twitter.com/Birchlabs/status/1611507344302608388


brkirch commented Jan 7, 2023

> okay, that's interesting. I'm skeptical. I measured the impact of pre-transpose on unchunked attention a while back, and it didn't seem to change the perf for better or worse. I haven't measured the difference on chunked attention, but we'd have to agree on chunk sizes to do that. benchmarking is pretty hard on MPS because there's no profiler and because GPU temperature seems to influence the results a lot (you may find that the first test has a bit of an edge).

Yeah, it is a bit weird. Originally I got significantly worse performance with AttnBlock (several times slower), so as a compromise I tried making the transpose optional (I added a key_needs_transpose argument for efficient_dot_product_attention). Applying the pre-transpose only for cross attention gave a ~20% drop in performance for chunked attention, and a negligible difference for unchunked.

>> broken for MPS, […] only use 1.12.1 for training

> really? I did some TI training a few weeks ago using a commit from Dec 23: pytorch/pytorch@789b143 the key thing for training is the group_norm fix that landed on Dec 22: pytorch/pytorch@fd3a726 the only workaround I needed was running cumsum on CPU: pytorch/pytorch#89784 (comment)

With torch 2.0.0.dev20230106:

Traceback (most recent call last):
  File "/Users/brkirch/stable-diffusion-webui/modules/textual_inversion/textual_inversion.py", line 395, in train_embedding
    scaler.scale(loss).backward()
  File "/Users/brkirch/stable-diffusion-webui/venv-torch-2.0-alpha/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/Users/brkirch/stable-diffusion-webui/venv-torch-2.0-alpha/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [MPSFloatType [1, 77, 768]], which is output 0 of MulBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

It could be that I actually need to fix something, but last time I tried changing the in-place operations that did exist, it didn't help. Since you have it working, I'll do more testing later to see if I can get it fixed.
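
As the error hint suggests, torch.autograd.set_detect_anomaly(True) should at least point at the forward op that produced the offending tensor. A tiny synthetic repro of the same class of error (not webui code), just to show what it surfaces:

```python
import torch

# synthetic repro: exp()'s backward needs its own output, so modifying that
# output in place bumps its version counter and breaks the backward pass
x = torch.randn(3, requires_grad=True)
with torch.autograd.set_detect_anomaly(True):
    y = x.exp()
    y.mul_(2)            # in-place edit of a tensor saved for backward
    y.sum().backward()   # raises; the anomaly trace points back at `y = x.exp()`
```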

>> kulinseth/pytorch usually ~25% faster

> wow, I wonder which optimization that is. there's a potentially-dramatic optimization that landed in pytorch master yesterday, which probably was available in kulinseth's branch before that: pytorch/pytorch#91114

Could be things like this, although I see that it finally has a PR for the main branch. Right now, even comparing against the latest nightly build, I'm still seeing better performance with the other fork.

> but there could be downsides to early-adoption. there's stuff that may not be ready: pytorch/pytorch#90464

Memory usage is a big issue either way. PyTorch 1.12.1 already uses a lot more memory than it should, to the point that nightly builds may be worth using simply for lower memory usage. Unfortunately, debugging memory usage seems to be even more difficult than debugging performance.

Birch-san (Author) commented

okay, you're right: kulinseth's branch is faster than master. I measured it as 10% faster (on float16).
https://twitter.com/Birchlabs/status/1611525288642613248

also, the PR to which the "memory usage" issue refers… has been merged to master anyway lol.

as for reasons why pre-transpose could cause slowdown for chunked attention… hmm, maybe doing the early transpose sets you up to fail: maybe it forces you to index into lots of locations that are far away from each other, forcing a big read.
maybe the reason it's negligible for unchunked attention is because you're gonna read the whole thing anyway.
perhaps adding an early transpose is basically free due to the other reads we're doing in that area around that time.
perhaps doing lots of late transposes is basically free because it's transposing a small chunk that you're planning to read all of anyway.
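
to make the "locations far away from each other" idea concrete: a transpose in PyTorch is only a stride swap, so any cost shows up in how the transposed view is read later. e.g.:

```python
import torch

k = torch.randn(10, 4096, 64)   # [batch*heads, k_tokens, channels]
k_t = k.transpose(1, 2)         # [batch*heads, channels, k_tokens]: same storage, swapped strides
print(k.stride())               # (262144, 64, 1)
print(k_t.stride())             # (262144, 1, 64)
print(k_t.is_contiguous())      # False

# slicing a kv chunk from the pre-transposed key, e.g. k_t[:, :, 0:1024], yields
# a view whose token dimension has stride 64: consecutive tokens in the chunk are
# 64 floats apart in storage, so a kernel walking that axis does strided reads.
```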
