Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#92 from ploshkin/rm-shape-asserts
Browse files: browse the repository at this point in the history
Fix slicing dimensions in rotary embeddings
  • Loading branch information
tridao authored Dec 17, 2022
2 parents b78f5a3 + ee8984d commit dc24c22
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions flash_attn/layers/rotary.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -43,13 +43,12 @@ def forward(ctx, x, cos, sin, inplace=False):
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
- assert cos.shape == (rotary_seqlen, rotary_dim // 2)
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
out = torch.empty_like(x) if not inplace else x
o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
- rotary_emb.apply_rotary(x1, x2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
- rearrange(sin[:, :seqlen], 's d -> s 1 d'), o1, o2, False)
+ rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+ rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
Expand All @@ -66,8 +65,8 @@ def backward(ctx, do):
do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1)
dx = torch.empty_like(do) if not inplace else do
dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2)
- rotary_emb.apply_rotary(do1, do2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
- rearrange(sin[:, :seqlen], 's d -> s 1 d'), dx1, dx2, True)
+ rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+ rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None
Expand All @@ -92,14 +91,13 @@ def forward(ctx, qkv, cos, sin):
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
- assert cos.shape == (seqlen, rotary_dim // 2)
- assert sin.shape == (seqlen, rotary_dim // 2)
+ assert sin.shape == (rotary_seqlen, rotary_dim // 2)
q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
- rotary_emb.apply_rotary(q1, q2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
- rearrange(sin[:, :seqlen], 's d -> s 1 d'), q1, q2, False)
+ rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+ rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
- rotary_emb.apply_rotary(k1, k2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
- rearrange(sin[:, :seqlen], 's d -> s 1 d'), k1, k2, False)
+ rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+ rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False)
ctx.save_for_backward(cos, sin)
return qkv

Expand All @@ -110,11 +108,11 @@ def backward(ctx, dqkv):
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
- rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
- rearrange(sin[:, :seqlen], 's d -> s 1 d'), dq1, dq2, True)
+ rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+ rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
- rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
- rearrange(sin[:, :seqlen], 's d -> s 1 d'), dk1, dk2, True)
+ rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+ rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
return dqkv, None, None


Expand Down

0 comments on commit dc24c22

Please sign in to comment.