Commit

remove einsum usage from create_alibi_bias function in AbstractAttention (#781)

Co-authored-by: Bryce Meyer <bryce13950@gmail.com>
Co-authored-by: Fabian Degen <fabian.degen@mytum.de>
3 people authored Nov 25, 2024
1 parent b7c4dbd commit 623407f
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions transformer_lens/components/abstract_attention.py
@@ -692,7 +692,11 @@ def create_alibi_bias(
             n_heads, device
         )
 
-        # The ALiBi bias is then m * slope_matrix
-        alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)
+        # Add singleton dimensions to make shapes compatible for broadcasting:
+        slope = einops.rearrange(slope, "query key -> 1 query key")
+        multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1")
+
+        # Element-wise multiplication of the slope and multipliers
+        alibi_bias = multipliers * slope
 
         return alibi_bias
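Since the commit replaces a torch.einsum call with a broadcasted element-wise product, it is worth checking that the two formulations agree. The snippet below is a minimal sketch, not part of the commit: the tensor sizes are hypothetical stand-ins for the (query, key) slope matrix and the (n_heads,) multiplier vector that create_alibi_bias actually uses.

import torch
import einops

# Hypothetical stand-in sizes: (query_pos, key_pos) slope matrix and
# (n_heads,) vector of per-head ALiBi multipliers m.
slope = torch.randn(5, 5)
multipliers = torch.randn(12)

# Old formulation: einsum over the head and position indices.
old_bias = torch.einsum("ij,k->kij", slope, multipliers)

# New formulation: insert singleton axes, then broadcast-multiply.
slope_b = einops.rearrange(slope, "query key -> 1 query key")
multipliers_b = einops.rearrange(multipliers, "head_idx -> head_idx 1 1")
new_bias = multipliers_b * slope_b

# Both produce the same (n_heads, query, key) bias tensor.
assert torch.allclose(old_bias, new_bias)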
