
Commit

hardcode mistral config so building docs works with older transformer versions
Felhof committed Oct 27, 2023
1 parent ae27a64 commit 7da7b4f
Showing 1 changed file with 19 additions and 15 deletions.
transformer_lens/loading_from_pretrained.py (19 additions, 15 deletions)
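Some context for the change, sketched here rather than taken from the commit itself: support for Mistral's config class only landed in transformers around version 4.34, so on the older versions used in the docs build, asking AutoConfig about a Mistral checkpoint fails. The snippet below (the model name mistralai/Mistral-7B-v0.1 and the broad exception catch are assumptions, since the exact error varies by version) illustrates the failure mode that the hardcoding works around.

# Sketch of the failure on a transformers install that predates Mistral support.
from transformers import AutoConfig

try:
    hf_config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
    # On transformers >= ~4.34 this succeeds and reports 4096 / 32.
    print(hf_config.hidden_size, hf_config.num_hidden_layers)
except (KeyError, ValueError) as err:
    # Older transformers do not recognize the "mistral" model_type, so any code
    # path that relies on AutoConfig for this model breaks here.
    print(f"AutoConfig cannot resolve Mistral on this transformers version: {err}")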
@@ -546,11 +546,13 @@ def convert_hf_model_config(model_name: str, **kwargs):
     # In case the user passed in an alias
     official_model_name = get_official_model_name(model_name)
     # Load HuggingFace model config
-    if "llama" not in official_model_name.lower():
+    if "llama" in official_model_name.lower():
+        architecture = "LlamaForCausalLM"
+    elif "mistral" in official_model_name.lower():
+        architecture = "MistralForCausalLM"
+    else:
         hf_config = AutoConfig.from_pretrained(official_model_name, **kwargs)
         architecture = hf_config.architectures[0]
-    else:
-        architecture = "LlamaForCausalLM"
     if official_model_name.startswith(
         ("llama-7b", "Llama-2-7b")
     ):  # same architecture for LLaMA and Llama-2
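The effect of the rewritten branch above, as a tiny standalone illustration (the helper pick_architecture is made up for this sketch and is not part of transformer_lens): any official model name containing "llama" or "mistral" now resolves its architecture by string matching alone, and only other models still go through AutoConfig.

# Illustration only: name-based architecture dispatch as introduced above.
def pick_architecture(official_model_name: str) -> str:
    name = official_model_name.lower()
    if "llama" in name:
        return "LlamaForCausalLM"
    if "mistral" in name:
        return "MistralForCausalLM"
    return "resolved via AutoConfig"

assert pick_architecture("mistralai/Mistral-7B-v0.1") == "MistralForCausalLM"
assert pick_architecture("meta-llama/Llama-2-7b-hf") == "LlamaForCausalLM"
assert pick_architecture("gpt2") == "resolved via AutoConfig"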
@@ -727,20 +729,20 @@ def convert_hf_model_config(model_name: str, **kwargs):
         }
     elif architecture == "MistralForCausalLM":
         cfg_dict = {
-            "d_model": hf_config.hidden_size,
-            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
-            "n_heads": hf_config.num_attention_heads,
-            "d_mlp": hf_config.intermediate_size,
-            "n_layers": hf_config.num_hidden_layers,
-            "n_ctx": hf_config.max_position_embeddings,
-            "d_vocab": hf_config.vocab_size,
-            "act_fn": hf_config.hidden_act,
+            "d_model": 4096,
+            "d_head": 4096 // 32,
+            "n_heads": 32,
+            "d_mlp": 14336,
+            "n_layers": 32,
+            "n_ctx": 32768,
+            "d_vocab": 32000,
+            "act_fn": "silu",
             "normalization_type": "RMS",
             "positional_embedding_type": "rotary",
-            "window_size": hf_config.sliding_window,
-            "attn_types": ["local"] * hf_config.num_hidden_layers,
-            "eps": hf_config.rms_norm_eps,
-            "n_key_value_heads": hf_config.num_key_value_heads,
+            "window_size": 4096,
+            "attn_types": ["local"] * 32,
+            "eps": 1e-05,
+            "n_key_value_heads": 8,
             "final_rms": True,
             "gated_mlp": True,
             "use_local_attn": True,
@@ -842,6 +844,8 @@ def get_pretrained_model_config(
     """
     official_model_name = get_official_model_name(model_name)
+    print(f"Official name: {official_model_name}")
+    print()
     if (
         official_model_name.startswith("NeelNanda")
         or official_model_name.startswith("ArthurConmy")
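A hypothetical end-to-end check of what the commit is after (it assumes the alias table already knows mistralai/Mistral-7B-v0.1 and that the rest of convert_hf_model_config needs no HuggingFace download for this model): with the values hardcoded, the config dict can be built even when the installed transformers has never heard of Mistral.

# Usage sketch: building the TransformerLens config without touching AutoConfig.
from transformer_lens import loading_from_pretrained as loading

cfg_dict = loading.convert_hf_model_config("mistralai/Mistral-7B-v0.1")
print(cfg_dict["n_layers"], cfg_dict["d_model"], cfg_dict["window_size"])  # expect 32 4096 4096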
