reload existing llama checkpoints #305
Is this issue related to loading pretrained Llama2/Llama3 weights and using them as a checkpoint? I was going to start a separate issue asking for some docs that explain how to convert pretrained weights from HF to …
DCP has a format util to help with the conversion. However, HF conversion should not live in the PyTorch code base.
@lessw2020 will connect with HF to see if they can support weight conversion from HF to PyTorch. After that, we may import it in the code or update the tutorial.
I have a straightforward script for converting from HF to a DCP checkpoint, if that helps. Most of the script already exists in gpt-fast.
Alright, so this is the script I'm using for HF->DCP. It uses the safetensors weights (but can easily be converted to load a torch.save instead), which only exist in https://huggingface.co/meta-llama/Meta-Llama-3-8B/tree/main in the root, and not under …

```python
import json
import re
import sys
from pathlib import Path

from safetensors import safe_open
import torch
import torch.distributed.checkpoint as DCP

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from maester.models import models_config


@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path,
    output_dir: Path,
) -> None:
    # Load the json file containing the weight mapping
    model_map_json = checkpoint_dir / "model.safetensors.index.json"
    assert model_map_json.is_file()
    with open(model_map_json, 'r') as json_map:
        bin_index = json.load(json_map)
    weight_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

    # Merge all safetensors shards into a single state dict on CPU
    merged_result = {}
    for file in sorted(bin_files):
        with safe_open(file, framework="pt", device="cpu") as f:
            for k in f.keys():
                merged_result[k] = f.get_tensor(k)

    # Rename HF keys to the target model's key names
    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r'(\d+)', '{}', key)
            layer_num = re.search(r'\d+', key).group(0)
            new_key = weight_map[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = weight_map[key]
        final_result[new_key] = value

    output_dir.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(output_dir)
    DCP.save({"model": final_result},
             storage_writer=storage_writer)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
    parser.add_argument('--checkpoint', type=Path, required=True)
    parser.add_argument('--output', type=Path, required=True)
    args = parser.parse_args()
    convert_hf_checkpoint(
        checkpoint_dir=args.checkpoint,
        output_dir=args.output,
    )
```
Thanks for sharing.
Is there a conversion in the other direction, meaning converting a DCP checkpoint to an HF model? I found a util `dcp_to_torch_save` but am not sure how to go from there to an HF model.
@tianyu-l Thanks for the comment. Unfortunately, that script is for converting a llama model in the format it was first uploaded in by the llama team. The script thus requires input files like params.json and tokenizer.model, and torchtitan doesn't generate these. What I would like to know is how to go from torchtitan output weights to an HF model. Thank you.
An example of how to reload the pretrained weights would be nice once we have the weights in DCP format (e.g. for continued pretraining).
cc: @wz337
Save the DCP checkpoint as step-0 and it will be loaded at the beginning of training.
@rlrs Thank you so much for the script for conversion. But I'm slightly confused about one thing, i.e. the need for permutation on my side …
@soumik-kanad you will have to permute the weights back to their original format if you want to use the current implementation. I would appreciate it if the TorchTitan team could show what they think is the best non-hacky way to do continued pretraining. Ideally, you should just be able to load the original llama torch weights.
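For reference, the permutation in question interleaves the rotary halves of the q/k projection rows per head. The round trip can be sketched as follows (based on the `permute` in HF's llama conversion script and its inverse in gpt-fast; the function names are mine, and this assumes square per-layer q/k weights as in the 1-D case):

```python
import torch

def permute_to_hf(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Original llama layout -> HF layout (interleaves rotary dims per head).
    dim1, dim2 = w.shape
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def permute_from_hf(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # HF layout -> original llama layout (the exact inverse permutation).
    dim1, dim2 = w.shape
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

# Round trip recovers the original weight exactly.
wq = torch.randn(8, 8)
assert torch.equal(permute_from_hf(permute_to_hf(wq, n_heads=2), n_heads=2), wq)
```

So a script loading HF weights into an implementation that expects the original layout would apply `permute_from_hf` to each `wq`/`wk` after renaming the keys.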
Definitely something we should support. There have been a lot of asks for it, but we are still trying to find the bandwidth to work on it. Alternatively, please feel free to make PRs for it; we can help review.
You either use the … It should be relatively easy to change the script I posted to load from the original llama weights. I can do it soon if no one else manages before I get around to it. @tianyu-l or @casper-hansen, how would you want this supported in the repo? To me it already seems pretty straightforward; unsure what else is needed.
@rlrs It would be great if you could help contribute such a script to convert checkpoints from the original llama weights to DCP format. I think we can put it under a … Besides, we need to create a (short) tutorial (maybe just in /~https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md) to illustrate how to convert and load, and possibly add unit tests under the …
Happy to help on this if needed. Also confirming small deltas on the K/Q weights between original<->HF.
Closes pytorch#305. Just wanted to get this out quickly. The script is very simple since the weights are already in exactly the right format, names and everything. All of the complexity is avoided by not using HF, and so I believe that any functionality relating to HF should live on their side. However, I do have a DCP -> HF export script which might be useful for some people, in case HF does not have/add one. I'll be happy to add any needed documentation or tests.