Support mistral 7 b #443
Conversation
Thanks a lot for doing this! I left some small comments, but this isn't a full review (I haven't looked at exactly how you implemented grouped query attention), so I'll leave that for someone else to finish off.
Just to flag: I'll review this once the previous PR on attention goes through, as we'll need to adjust for a few conflicts with it.
Hey @Felhof, nice work on this, and the abstract attention class approach is great (we're planning on doing this for a bunch of the components). Would you be able to fix the merge conflicts first, and then I can do a full review? There are a few things to adjust, as the last PR to go through had some attention component changes.
Hey! I'm curious what the status of this PR is? A few of my MATS scholars want to use Mistral. Can they just check out this PR?
That worked for me, modulo the comment I made above about using the right version of transformers (and the tokenizer being disgusting).
@alan-cooney @neelnanda-io @Felhof I want to run an activation patching experiment on llama-70b, so I'm going to check out this branch. Edit: I couldn't get it to work with some other dependencies I had, but I'll try again later this week.
@alan-cooney @Felhof What's the status of this PR? I'm not sure which of you it's blocked on. Either way, I hear that people have been able to check out this branch and get Mistral working, so thanks a lot for the work up to that point!
The branch has been working for a while but needs approval to be merged into main :) I made sure it's up-to-date with main again.
@alan-cooney Ping on review :)
Sorry folks!! I'm on it now
Nice! Thanks for adding this; the approach is great.
A general point (no need to change anything now, but worth being aware of) is that we're trying to type all methods and also add docstrings to them (Google style). If you have a chance it would be good to do this here as well, e.g. to explain the interleave approach that is used.
Thanks for the review Alan! I have added better typing and documentation and removed the demo.
Thanks! And thanks for adding this!
Woot! Really glad this got merged in! Thanks for adding it @Felhof and sorry for the long delay
Description
This PR closes #387 by adding support for Mistral and implementing Grouped Query Attention.
Mistral 7B
The weights from Huggingface's Mistral-7B-v0.1 can now be loaded into a HookedTransformer (a rough loading sketch is shown below).
Note that Mistral is only supported in transformers >= 4.34, and hence Python >= 3.8 is required to use it.
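For illustration, loading might look like the following minimal sketch. The model identifier "mistral-7b" is an assumption on my part; check the names registered in loading_from_pretrained for the exact identifier added by this PR.

```python
# Minimal sketch; the model name "mistral-7b" is assumed, not confirmed by this PR text.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("mistral-7b")

# Run a prompt through the hooked model and get logits.
logits = model("The quick brown fox jumps over the lazy dog", return_type="logits")
print(logits.shape)  # [batch, position, d_vocab]
```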
The demo notebook Mistral.ipynb features a comparison of Huggingface's Mistral with the HookedTransformer implementation. I tested it on the same prompts that were used in the Llama demo. The differences in the resulting logits were around 0.01, which is slightly higher than what Llama2 gets. This may be due to a known issue with rotary embeddings which also affects Llama2 and Pythia.
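For context, a comparison of that kind can be sketched roughly as below. This is not the notebook's exact code: the prompt is a placeholder and the TransformerLens model name is assumed.

```python
# Rough sketch of a Huggingface vs. HookedTransformer logit comparison;
# prompt and model names are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

prompt = "The capital of Germany is"
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
hf_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
tl_model = HookedTransformer.from_pretrained("mistral-7b")  # assumed name

tokens = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tl_logits = tl_model(tokens, return_type="logits")

# Per the description, the maximum difference was around 0.01.
print((hf_logits - tl_logits).abs().max())
```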
Grouped Query Attention
Mistral utilizes Grouped Query Attention (GQA), which is not used by any other model supported by TransformerLens and had to be implemented. I added a new class GroupedQueryAttention and an abstract base class AbstractAttention, which now contains the functionality common to both attention classes. The difference between Attention and GroupedQueryAttention is how they handle the key and value projections, as in GQA groups of queries share the same keys and values (see image below). This mostly affects the internal workings of the class. To avoid confusion, to not break existing code interacting with attention, and to make the design of future code easier, public attributes such as W_K and W_V have the same shape for both classes. This is achieved by hiding the underlying GQA parameters behind a property that expands them using torch.repeat_interleave. A GQA block should behave the same as a regular Attention block whose weights are the result of applying torch.repeat_interleave to the GQA block's weights. There is a unit test that confirms this.
Type of change
Please delete options that are not relevant.
Checklist: