Skip to content

Commit

Permalink
Update the apply_interpolation function
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 2, 2024
1 parent 50b93ae commit 0f4ef37
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,13 @@ def forward(self, input_ids):

return embeddings

def apply_interpolation(self, pos_embed, lambda_factors, n_hat):
def apply_interpolation(self, pos_embed, context_length):
"""Apply non-uniform interpolation to position embeddings."""
return non_uniform_interpolation(
pos_embed, self.extension_ratio, lambda_factors, n_hat
pos_embed,
self.extension_ratio,
self.lambda_factors[context_length],
self.n_hat[context_length],
)

def extend_context(
Expand Down

0 comments on commit 0f4ef37

Please sign in to comment.