Improve performance via batched-matmul and fused multiplies #7
Comments
Really nice implementation @Birch-san. Thanks for your detailed description. I'd appreciate it if you could make the modifications described above in this repo too and create a PR. Would that be possible for you?
I'm looking at getting this implementation added to AUTOMATIC1111/stable-diffusion-webui, but neither this package nor Birch-san's modified implementation has a license. Unfortunately, if I simply use this package unmodified I won't get these speed optimizations or Mac support, both of which are needed. So I'd like to ask both of you:
Yup, I'm okay with releasing my contribution under a permissive license. There are some bits of @AminRezaei0x443's code remaining, so it would be good to find out their licensing preferences. But I believe my implementation doesn't contain any lines of code from the other contributors to that repository (as I deleted support for masking and callbacks, as well as the device-argument fixes).
Oh, you might want to read the latest commits I added: I reduced the number of times the key needs to be transposed. This complicates the API slightly, but has a chance of improving performance.
@brkirch Thanks for your interest, and I'm glad this library is useful for you. This project is licensed under
Thanks! I've added the license to the PR. @Birch-san I did experiment with Birch-san/diffusers@9dc6822 but could only get worse performance. I've been using kulinseth/pytorch instead of PyTorch nightly builds though, as it gets significantly better performance (usually ~25% faster) for MPS than the nightly builds in my testing. Right now most AUTOMATIC1111 users will be on PyTorch 1.12.1, as it is the latest release that isn't broken for MPS, but due to the significant MPS improvements in PyTorch I'm going to recommend that users consider switching to kulinseth/pytorch for inference and only use 1.12.1 for training (which unfortunately still doesn't work for MPS in any newer PyTorch).
okay, that's interesting. I'm skeptical. I measured the impact of pre-transpose on unchunked attention a while back, and it didn't seem to change the perf for better or worse. I haven't measured the difference on chunked attention, but we'd have to agree on chunk sizes to do that.
really? I did some TI training a few weeks ago using a commit from Dec 23:
wow, I wonder which optimization that is. there's a potentially-dramatic optimization that landed in pytorch master yesterday, which was probably available in kulinseth's branch before that. but there could be downsides to early adoption; there's stuff that may not be ready:
okay yeah, got 12% faster when I updated to latest master. might be that
Yeah it is a bit weird. Originally I got significantly worse performance with AttnBlock (several times slower), so as a compromise I tried making the transpose optional (added a
With torch 2.0.0.dev20230106:
It could be that I actually need to fix something, but last time I tried changing the in-place operations that did exist it didn't help. Since you have it working I'll do more testing later to see if I can get it fixed.
Could be things like this, although I see that it finally has a PR for the main branch. Right now even comparing against the latest nightly build I'm still seeing better performance with the other fork.
Memory usage is a big issue either way. PyTorch 1.12.1 already uses a lot more memory than it should, to the point that nightly builds may be worth using simply for the lower memory usage. Unfortunately, debugging memory usage seems to be even more difficult than debugging performance.
okay, you're right: kulinseth's branch is faster than master. I measured it as 10% faster (on float16). also, the PR to which the "memory usage" issue refers… has merged to master anyway lol. as for reasons why pre-transpose could cause slowdown for chunked attention… hmm, maybe avoiding the early transpose sets you up to fail: it may force you to index into lots of locations that are far away from each other, forcing a big read.
Many thanks for providing this reference implementation.
I tried integrating this into stable-diffusion / diffusers. A fix was required to make it work on Mac (PyTorch MPS backend):
Birch-san/diffusers@0437214
Knowing that computing attention via `baddbmm()` + `bmm()` can outperform einsum by 18%, I tried to rewrite the algorithm to use those. I compared the speed of my optimized version against the implementation in this repository.
this result is for "everything fits in one chunk" perf (i.e. chunk size = max token length). I was unable to compare chunked perf: although I got chunking working in my version, I wasn't able to get it working in the version in this repository (got some unexpected-shape tensors returned).
compared to the implementation in this repository:
my optimized version achieves a 2.78x speedup in the time it took to generate a 512x512 image with stable-diffusion v2.1-base (i.e. 4096 vision tokens, 5 attention heads, batch size of 2 due to CFG).
here's my optimized implementation:
Birch-san/diffusers#1
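For readers who want to see the shape of that change, here is a minimal sketch of the two formulations side by side; shapes and variable names are illustrative (not taken from the PR), and the token counts are kept small so the comparison runs cheaply:

```python
import torch

# illustrative shapes; the real SD v2.1-base workload at 512x512 uses 4096 tokens
batch, heads, q_tokens, k_tokens, ch = 2, 5, 64, 64, 64
scale = ch ** -0.5

q = torch.randn(batch * heads, q_tokens, ch)
k = torch.randn(batch * heads, k_tokens, ch)
v = torch.randn(batch * heads, k_tokens, ch)

# einsum formulation (a common starting point)
scores = torch.einsum('bqc,bkc->bqk', q, k) * scale
out_einsum = torch.einsum('bqk,bkc->bqc', scores.softmax(dim=-1), v)

# baddbmm + bmm formulation: baddbmm fuses the 1/sqrt(channels) scale into the
# matmul via alpha, and bmm applies the attention weights to the values
scores = torch.baddbmm(
    torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
    q, k.transpose(1, 2),
    beta=0,       # the (uninitialized) input tensor is ignored
    alpha=scale,  # fused multiply by the attention scale
)
out_bmm = torch.bmm(scores.softmax(dim=-1), v)

assert torch.allclose(out_einsum, out_bmm, atol=1e-5)
```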
batched matmuls require a 3D tensor, i.e. `[batch * num_heads, tokens, channels_per_head]`.

code that currently integrates against this repository's `[batch, q_length, num_heads, qk_depth_per_head]` format can migrate those tensors to the `[batch * num_heads, q_length, channels_per_head]` format favoured by my implementation like so (see the sketch below).

the result that's returned remains in `[batch * num_heads, q_length, qk_depth_per_head]` format, and can be restored to `[batch, q_length, num_heads, qk_depth_per_head]` format like so (also in the sketch below).

I think a further speedup is possible too, by working out when chunking is not needed: we can compute whether unchunked attention would fit into memory, and prefer unchunked attention as a fast-path where possible. this will be useful in a Unet, which runs attention at various resolutions.
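The exact conversion snippets from the original comment are not included above; the following is a minimal sketch of the two reshapes described, with illustrative shapes and variable names (they may differ from the actual integration code):

```python
import torch

# illustrative shapes, not taken from the original snippet
batch, q_length, num_heads, qk_depth_per_head = 2, 4096, 5, 64

# this repository's format: [batch, q_length, num_heads, qk_depth_per_head]
q = torch.randn(batch, q_length, num_heads, qk_depth_per_head)

# migrate to the batched-matmul-friendly [batch * num_heads, q_length, qk_depth_per_head]
q_bmm = q.transpose(1, 2).reshape(batch * num_heads, q_length, qk_depth_per_head)

# ...run the baddbmm/bmm attention; its result stays in
# [batch * num_heads, q_length, qk_depth_per_head] format...
result = q_bmm  # placeholder standing in for the attention output

# restore to [batch, q_length, num_heads, qk_depth_per_head]
restored = result.reshape(batch, num_heads, q_length, qk_depth_per_head).transpose(1, 2)

assert restored.shape == q.shape
```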
EDIT:
I have now added fast-paths for (a rough sketch of the selection logic follows this list):

- `kv_chunk_size >= k_tokens`
- `q_chunk_size >= q_tokens`
- `kv_chunk_size >= k_tokens` and `q_chunk_size >= q_tokens`
- the `q@k.T` matmul requires fewer bytes than a user-provided threshold
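A rough sketch of how that fast-path selection can be expressed (function and parameter names here are illustrative, not the library's actual API, and the real commits may order the checks differently):

```python
from typing import Optional


def choose_attention_path(q_tokens: int, k_tokens: int,
                          q_chunk_size: int, kv_chunk_size: int,
                          scores_bytes: int,
                          unchunked_bytes_limit: Optional[int] = None) -> str:
    """Pick which attention code path to run, per the fast-paths listed above."""
    no_kv_chunking = kv_chunk_size >= k_tokens
    no_q_chunking = q_chunk_size >= q_tokens
    fits_in_budget = (unchunked_bytes_limit is not None
                      and scores_bytes <= unchunked_bytes_limit)

    if fits_in_budget or (no_kv_chunking and no_q_chunking):
        return "unchunked"                 # everything in one pass
    if no_kv_chunking:
        return "chunk queries only"        # keys/values fit; iterate over query chunks
    if no_q_chunking:
        return "chunk keys/values only"    # queries fit; iterate over kv chunks
    return "chunk both"


# e.g. 4096 query/key tokens, fp16 scores, batch * heads = 10:
scores_bytes = 10 * 4096 * 4096 * 2
print(choose_attention_path(4096, 4096, 1024, 4096, scores_bytes,
                            unchunked_bytes_limit=2**30))  # -> "unchunked"
```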