Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Throw error when using attn_in with grouped query attention #810

Conversation

degenfabian
Copy link
Contributor

@degenfabian degenfabian commented Dec 11, 2024

Description

When using attn_in with models that use GroupedQueryAttention, TransformerLens crashes because use_attn_in does not account for the different number of query and key/value heads when using GQA. For models with GQA use_split_qkv_input should be used instead, because it implements hooks for query, key and value heads and therefore can account for the different number of heads for each of them. This PR implements a more meaningful error message that informs the user to use split_qkv_input when working with models with GQA instead of use_attn_in.

This PR is not linked to a specific issue.

After adding a test case, it failed because of a beartype error that stated that rotary_base needs to be an integer instead of a float. I adjusted this accordingly in the configuration of google/gemma-2b

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@degenfabian degenfabian changed the title Throw error when using attn in with grouped query attention Throw error when using attn_in with grouped query attention Dec 11, 2024
@bryce13950 bryce13950 merged commit d0d0750 into TransformerLensOrg:dev Dec 28, 2024
13 checks passed
@degenfabian degenfabian deleted the throw_error_when_using_attn_in_with_grouped_query_attention branch December 28, 2024 16:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants