
Use sub-quadratic attention for AttnBlock
brkirch committed Dec 30, 2022
1 parent c13aac0 commit 1824569
Showing 3 changed files with 60 additions and 34 deletions.
1 change: 1 addition & 0 deletions modules/sd_hijack.py
@@ -41,6 +41,7 @@ def apply_optimizations():
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
87 changes: 53 additions & 34 deletions modules/sd_hijack_optimizations.py
@@ -235,9 +235,6 @@ def set_sub_quad_chunk_threshold():
def sub_quad_attention_forward(self, x, context=None, mask=None):
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

q_chunk_size = 1024
kv_chunk_size_min = chunk_threshold_bytes

h = self.heads

q = self.to_q(x)
@@ -252,37 +249,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

bytes_per_token = torch.finfo(q.dtype).bits//8
batch_x_heads, q_tokens, _ = q.shape
_, k_tokens, _ = k.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

kv_chunk_size = min(int(math.sqrt(k_tokens)), k_tokens)
if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

uses_chunking = q_tokens > q_chunk_size or k_tokens > kv_chunk_size

if uses_chunking and (chunk_threshold_bytes is None or qk_matmul_size_bytes > chunk_threshold_bytes):
x = efficient_dot_product_attention(
q,
k,
v,
query_chunk_size=q_chunk_size,
kv_chunk_size=kv_chunk_size,
use_checkpoint=self.training,
)
else:
# the big matmul fits into our memory limit; compute via unchunked attention (it's faster)
attention_scores = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
q,
k.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)
x = torch.bmm(attention_probs, v)
x = sub_quad_attention(q, k, v, kv_chunk_size_min=chunk_threshold_bytes, chunk_threshold_bytes=chunk_threshold_bytes, use_checkpoint=self.training)

x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)

@@ -292,6 +259,58 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):

return x

def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True):
bytes_per_token = torch.finfo(q.dtype).bits//8
batch_x_heads, q_tokens, _ = q.shape
_, k_tokens, _ = k.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens

return efficient_dot_product_attention(
q,
k,
v,
query_chunk_size=q_chunk_size,
kv_chunk_size=kv_chunk_size,
use_checkpoint=use_checkpoint,
)

def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)

# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
(q, k, v),
)
out = sub_quad_attention(q, k, v, kv_chunk_size_min=chunk_threshold_bytes, chunk_threshold_bytes=chunk_threshold_bytes, use_checkpoint=self.training)

out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
out = self.proj_out(out)
return x+out

def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
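Note (not part of the commit): the new sub_quad_attention helper chooses between the chunked and unchunked paths by estimating the size of the full QK^T matmul and comparing it against chunk_threshold_bytes. The following standalone sketch reproduces that estimate; the tensor shapes (batch*heads = 8, 4096 query/key tokens, fp16) and the 512 MiB budget are assumptions for illustration, not values taken from the diff.

import torch

# Assumed shapes: batch*heads = 8, 4096 query/key tokens, 40 channels per head, fp16.
q = torch.empty(8, 4096, 40, dtype=torch.float16)
k = torch.empty(8, 4096, 40, dtype=torch.float16)

bytes_per_token = torch.finfo(q.dtype).bits // 8   # 2 bytes for fp16
batch_x_heads, q_tokens, _ = q.shape
_, k_tokens, _ = k.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
print(qk_matmul_size_bytes)                        # 268435456 bytes = 256 MiB

chunk_threshold_bytes = 512 * 1024**2              # hypothetical 512 MiB budget
if qk_matmul_size_bytes <= chunk_threshold_bytes:
    print("fits: run attention as a single chunk (unchunked fast path)")
else:
    print("too large: fall back to chunked sub-quadratic attention")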
6 changes: 6 additions & 0 deletions modules/sub_quadratic_attention.py
@@ -132,6 +132,7 @@ def efficient_dot_product_attention(
value: Tensor,
query_chunk_size=1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
):
"""Computes efficient dot-product attention given query, key, and value.
@@ -146,6 +147,7 @@
`[batch * num_heads, tokens, channels_per_head]`.
query_chunk_size: int: query chunks size
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
Returns:
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
@@ -154,6 +156,10 @@
_, k_tokens, _ = key.shape
scale = q_channels_per_head ** -0.5

kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

def get_query_chunk(chunk_idx: int) -> Tensor:
return dynamic_slice(
query,
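Note (not part of the commit): the new kv_chunk_size_min argument clamps the chosen key/value chunk size from below, so the sqrt(key_tokens) default doesn't produce chunks that are too small (smaller chunks = more chunks = less concurrent work done). A small standalone sketch of the selection rule added above; the helper name pick_kv_chunk_size and the token counts are made up for illustration.

import math

def pick_kv_chunk_size(k_tokens, kv_chunk_size=None, kv_chunk_size_min=None):
    # Mirrors the lines added above: default to sqrt(key_tokens) capped at
    # k_tokens, then optionally clamp from below.
    size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        size = max(size, kv_chunk_size_min)
    return size

print(pick_kv_chunk_size(4096))                         # sqrt(4096) = 64
print(pick_kv_chunk_size(4096, kv_chunk_size_min=512))  # clamped up to 512
print(pick_kv_chunk_size(4096, kv_chunk_size=1024))     # explicit size is used: 1024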
