Support Mistral 7B #443

Merged: 27 commits, Jan 22, 2024
3fe1e47
add mistral model name and alias
Felhof Oct 24, 2023
4c342c0
add code for converting mistral config to hooked transformer config
Felhof Oct 24, 2023
f6bfba6
add function for converting mistral weights
Felhof Oct 24, 2023
9de861b
add GroupedQueryAttention
Felhof Oct 24, 2023
d953e70
add abstract base class for attention
Oct 26, 2023
d12de7f
adapt keyvaluecache if grouped query attention is used
Felhof Oct 26, 2023
914aa56
fix fold_value_biases when using grouped query attention
Felhof Oct 26, 2023
01d7d4d
Add unit test for grouped query attention
Felhof Oct 27, 2023
0ae41aa
Add demo notebook for Mistral
Felhof Oct 27, 2023
38f7607
merge from main and solve conflicts
Felhof Oct 27, 2023
e766313
fix formatting
Felhof Oct 27, 2023
7f734d5
add documentation for grouped query attention
Felhof Oct 27, 2023
be70ed5
update lock file
Felhof Oct 27, 2023
ae27a64
use Union instead of | for union types
Felhof Oct 27, 2023
2322bd8
hardcode mistral config so building docs works with older versions of…
Felhof Oct 27, 2023
0088c02
don't set final_rms in Mistral config
Felhof Nov 3, 2023
473da7a
make Mistral-7b's alias name consistent with other models
Felhof Nov 3, 2023
5b9a3fa
merge and fix conflicts
Felhof Nov 3, 2023
cfef128
update main demo notebook
Felhof Nov 3, 2023
cb83b3e
merge from main and fix conflicts
Felhof Nov 18, 2023
2afad28
merge from main, fix conflict in poetry.lock
Felhof Dec 1, 2023
cff0a86
require transformers>=3.34
Felhof Dec 1, 2023
9ede6f9
merge and fix conflicts
Felhof Jan 10, 2024
2203dff
improve docstrings and clarify test name for grouped query attention
Felhof Jan 21, 2024
88dc810
remove Mistral demo
Felhof Jan 21, 2024
f6c4939
merge from main and fix conflict in components.py
Felhof Jan 21, 2024
383a031
fix docstring format
Felhof Jan 21, 2024
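
Several commits above concern grouped query attention, where a small number of key/value heads is shared across a larger number of query heads. The sketch below illustrates only the head-sharing idea in plain Python. The function names `repeat_kv_heads` and `kv_head_index` are hypothetical and do not come from this PR, and the interleaved ordering (consecutive query heads share one key/value head) is an assumption, not something this page confirms.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def repeat_kv_heads(kv_heads: Sequence[T], n_heads: int) -> List[T]:
    """Expand per-KV-head items so there is one entry per query head.

    Hypothetical helper: assumes interleaved repetition, i.e.
    [kv0, kv0, ..., kv1, kv1, ...].
    """
    n_kv_heads = len(kv_heads)
    if n_kv_heads == 0 or n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a positive multiple of n_kv_heads")
    repeats = n_heads // n_kv_heads
    # Each KV head is repeated `repeats` times in a row.
    return [kv for kv in kv_heads for _ in range(repeats)]


def kv_head_index(query_head: int, n_heads: int, n_kv_heads: int) -> int:
    """Return the KV head that a given query head reads from (same assumption)."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    return query_head // (n_heads // n_kv_heads)


if __name__ == "__main__":
    # Two KV heads shared across eight query heads.
    print(repeat_kv_heads(["kv0", "kv1"], n_heads=8))
    # With 32 query heads and 8 KV heads, query head 5 uses KV head 1.
    print(kv_head_index(5, n_heads=32, n_kv_heads=8))
```

In a tensor implementation the same expansion would typically be done along the head dimension of the key and value tensors before the attention scores are computed, so that the rest of the attention code can treat them as if there were one key/value head per query head.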
46 changes: 21 additions & 25 deletions demos/Main_Demo.ipynb
@@ -45,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 292,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -80,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": 293,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -103,32 +103,28 @@
},
{
"cell_type": "code",
"execution_count": 294,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div id=\"circuits-vis-7e4c8a75-1335\" style=\"margin: 15px 0;\"/>\n",
"<div id=\"circuits-vis-1f2a8687-9cd7\" style=\"margin: 15px 0;\"/>\n",
" <script crossorigin type=\"module\">\n",
" import { render, Hello } from \"https://unpkg.com/circuitsvis@1.43.0/dist/cdn/esm.js\";\n",
" import { render, Hello } from \"https://unpkg.com/circuitsvis@1.43.2/dist/cdn/esm.js\";\n",
" render(\n",
" \"circuits-vis-7e4c8a75-1335\",\n",
" \"circuits-vis-1f2a8687-9cd7\",\n",
" Hello,\n",
" {\"name\": \"Neel\"}\n",
" )\n",
" </script>"
],
"text/plain": [
"<circuitsvis.utils.render.RenderedHTML at 0xffff10cc9f10>"
"<circuitsvis.utils.render.RenderedHTML at 0x7f21437f1c30>"
]
},
"execution_count": 294,
"metadata": {
"text/html": {
"Content-Type": "text/html"
}
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
@@ -140,7 +136,7 @@
},
{
"cell_type": "code",
"execution_count": 295,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -158,7 +154,7 @@
},
{
"cell_type": "code",
"execution_count": 296,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -179,16 +175,16 @@
},
{
"cell_type": "code",
"execution_count": 297,
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.autograd.grad_mode.set_grad_enabled at 0xffff425948e0>"
"<torch.autograd.grad_mode.set_grad_enabled at 0x7f213de735e0>"
]
},
"execution_count": 297,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -254,7 +250,7 @@
},
{
"cell_type": "code",
"execution_count": 299,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -263,7 +259,7 @@
},
{
"cell_type": "code",
"execution_count": 300,
"execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -1210,21 +1206,21 @@
},
{
"cell_type": "code",
"execution_count": 314,
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"blocks.0.attn.W_Q torch.Size([12, 768, 64])\n",
"blocks.0.attn.W_K torch.Size([12, 768, 64])\n",
"blocks.0.attn.W_V torch.Size([12, 768, 64])\n",
"blocks.0.attn.W_O torch.Size([12, 64, 768])\n",
"blocks.0.attn.b_Q torch.Size([12, 64])\n",
"blocks.0.attn.b_O torch.Size([768])\n",
"blocks.0.attn.W_K torch.Size([12, 768, 64])\n",
"blocks.0.attn.W_V torch.Size([12, 768, 64])\n",
"blocks.0.attn.b_K torch.Size([12, 64])\n",
"blocks.0.attn.b_V torch.Size([12, 64])\n",
"blocks.0.attn.b_O torch.Size([768])\n",
"blocks.0.mlp.W_in torch.Size([768, 3072])\n",
"blocks.0.mlp.b_in torch.Size([3072])\n",
"blocks.0.mlp.W_out torch.Size([3072, 768])\n",
@@ -1247,7 +1243,7 @@
},
{
"cell_type": "code",
"execution_count": 315,
"execution_count": 20,
"metadata": {},
"outputs": [
{