
Support Mistral 7B #443

Merged: 27 commits merged into main from support-mistral-7B on Jan 22, 2024

Conversation

@Felhof (Contributor) commented Oct 27, 2023

Description

This PR closes #387 by adding support for Mistral and implementing Grouped Query Attention.

Mistral 7B

The weights from Hugging Face's Mistral-7B-v0.1 can now be loaded into a HookedTransformer using

# These parameters are necessary as Mistral uses RMSNorm rather than LayerNorm.
tl_mistral = HookedTransformer.from_pretrained(
    "mistral-7B",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False
)

Note that Mistral is only supported in transformers >= 4.34 and hence Python >= 3.8 is required to use it.
The demo notebook Mistral.ipynb features a comparison of Hugging Face's Mistral with the HookedTransformer implementation. I tested it on the same prompts that were used in the Llama demo. The differences in the resulting logits were around 0.01, which is slightly higher than what Llama 2 gets. This may be due to a known issue with rotary embeddings which also affects Llama 2 and Pythia.
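
For reference, the comparison is roughly of the following shape. This is a minimal sketch rather than the notebook itself: the prompt and the mistralai/Mistral-7B-v0.1 checkpoint name are illustrative, and running it requires enough memory for two copies of the model.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

prompt = "The capital of Germany is"  # illustrative prompt, not from the demo

# Hugging Face reference model and tokenizer (checkpoint name assumed).
hf_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
hf_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

# TransformerLens model, loaded as described above.
tl_mistral = HookedTransformer.from_pretrained(
    "mistral-7B",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

# Use the same token ids for both models to rule out tokenizer differences.
tokens = hf_tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    hf_logits = hf_model(tokens).logits                   # [batch, pos, d_vocab]
    tl_logits = tl_mistral(tokens, return_type="logits")  # [batch, pos, d_vocab]

# Expect the maximum absolute difference to be on the order of 0.01.
print((hf_logits - tl_logits).abs().max().item())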

Grouped Query Attention

Mistral utilizes Grouped Query Attention (GQA), which is not used by any other model supported by TransformerLens and therefore had to be implemented. I added a new class GroupedQueryAttention and an abstract base class AbstractAttention, which now holds the functionality common to both attention classes. The difference between Attention and GroupedQueryAttention is how they handle the key and value projections: in GQA, groups of queries share the same keys and values (see image below). This mostly affects the internal workings of the class. To avoid confusion, to avoid breaking existing code that interacts with attention, and to make future code easier to design, public attributes such as W_K and W_V have the same shape for both classes: in the GQA case, the underlying grouped parameters are hidden behind a property that expands them using torch.repeat_interleave. A GQA block should therefore behave the same as a regular Attention block whose weights are the result of applying torch.repeat_interleave to the GQA block's weights. There is a unit test that confirms this.

[Image: grouped_query_attention diagram]
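
The weight-sharing equivalence can be seen in a small self-contained sketch (toy dimensions, not Mistral's real sizes, and not the actual class internals): expanding the grouped key weights with torch.repeat_interleave and then projecting gives the same keys as projecting per group and repeating the result.

import torch

# Toy dimensions (illustrative only).
n_heads, n_kv_heads, d_model, d_head, pos = 8, 2, 16, 4, 5
repeat = n_heads // n_kv_heads  # number of query heads per key/value group

# Underlying grouped parameter, as a GQA block would store it: one K head per group.
_W_K = torch.randn(n_kv_heads, d_model, d_head)

# A public W_K property can expand it to the same shape as a regular Attention
# block's W_K, i.e. [n_heads, d_model, d_head].
W_K = torch.repeat_interleave(_W_K, repeat, dim=0)

x = torch.randn(pos, d_model)

# Keys computed from the expanded weights...
k_expanded = torch.einsum("pm,hmd->hpd", x, W_K)

# ...match keys computed per group and then repeated across each group's queries.
k_grouped = torch.repeat_interleave(torch.einsum("pm,gmd->gpd", x, _W_K), repeat, dim=0)

assert torch.allclose(k_expanded, k_grouped)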

Type of change


  • New feature (non-breaking change which adds functionality)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@Felhof force-pushed the support-mistral-7B branch from 7da7b4f to 2322bd8 on October 27, 2023 14:35
@neelnanda-io (Collaborator) left a comment

Thanks a lot for doing this! I left some small comments, but this isn't a full review (I haven't looked at exactly how you implemented grouped query attention), so I'll leave that for someone else to finish off.

Review comment on transformer_lens/loading_from_pretrained.py (outdated, resolved)
@alan-cooney self-requested a review on October 28, 2023 13:10
@Felhof force-pushed the support-mistral-7B branch 3 times, most recently from cf3245b to 473da7a on November 3, 2023 11:36
@Felhof force-pushed the support-mistral-7B branch from d35a5a2 to 5b9a3fa on November 3, 2023 11:56
@alan-cooney (Collaborator)

Just to flag: I'll review this once the previous PR on attention goes through, as we'll need to adjust for a few conflicts in it.

@alan-cooney (Collaborator) commented Nov 11, 2023

Hey @Felhof nice work on this, and the abstract attention class approach is great (we're planning on doing this for a bunch of the components).

Would you be able to fix the merge conflicts first, and then I can do a full review? There are a few things there, as the last PR to go through had some attention component changes.

Review comment on pyproject.toml (outdated, resolved)
@neelnanda-io (Collaborator)

Hey! I'm curious what the status of this PR is? A few of my MATS scholars want to use Mistral. Can they just check out this PR?

@ojh31 (Contributor) commented Nov 28, 2023

> Hey! I'm curious what the status of this PR is? A few of my MATS scholars want to use Mistral. Can they just check out this PR?

That worked for me, modulo the comment I made above about using the right version of transformers (and the tokenizer being disgusting)

@Felhof force-pushed the support-mistral-7B branch from 8f0b06f to cff0a86 on December 1, 2023 11:03
@abdurraheemali commented Dec 22, 2023

@alan-cooney @neelnanda-io @Felhof I want to run an activation patching experiment on llama-70b, so I'm going to check out the support-mistral-7b branch and report (in an hour-ish?) whether the grouped query attention implementation works for me or not

Edit: I couldn't get it to work with some other dependencies I had, but I'll try again later this week.

@neelnanda-io (Collaborator)

@alan-cooney @Felhof What's the status of this PR? I'm not sure which of you it's blocked on. Either way, I hear that people have been able to check out this branch and get Mistral working, so thanks a lot for the work up to that point!

@Felhof force-pushed the support-mistral-7B branch from f27d05d to 9ede6f9 on January 10, 2024 01:26
@Felhof (Contributor, Author) commented Jan 10, 2024

The branch has been working for a while but needs approval to be merged into main :) I made sure it's up-to-date with main again.

@wesg52 (Contributor) commented Jan 16, 2024

@alan-cooney Ping on review :)

@alan-cooney (Collaborator)

Sorry folks!! I'm on it now

@alan-cooney (Collaborator) left a comment

Nice! Thanks for adding this; the approach is great.

A general point (no need to change anything, just something to be aware of) is that we're trying to type all methods and also add docstrings to them (Google style). If you have a chance, it would be good to do this as well, e.g. to explain the interleave approach that is used.

Review comment on demos/Mistral.ipynb (outdated, resolved)
Review comment on tests/unit/test_grouped_query_attention.py (outdated, resolved)
@Felhof force-pushed the support-mistral-7B branch from e8ea289 to 383a031 on January 21, 2024 05:57
@Felhof (Contributor, Author) commented Jan 21, 2024

Thanks for the review, Alan! I have added better typing and documentation, and removed the demo.

@alan-cooney merged commit 11edb28 into TransformerLensOrg:main on Jan 22, 2024 (8 checks passed).
@alan-cooney (Collaborator)

Thanks! And thanks for adding this!

@neelnanda-io (Collaborator)

Woot! Really glad this got merged in! Thanks for adding it, @Felhof, and sorry for the long delay.

@collingray mentioned this pull request on Jan 25, 2024
Development
Linked issue: [Proposal] Support Mistral 7B model (#387)
6 participants