
reload existing llama checkpoints #305

Closed
tianyu-l opened this issue May 3, 2024 · 19 comments · Fixed by #634
Labels: enhancement (New feature or request), release_blocking (Issues that are blocking the milestone / release completion)

Comments

tianyu-l (Contributor) commented May 3, 2024

No description provided.

tianyu-l added the enhancement label on May 3, 2024
Lauler commented May 11, 2024

Is this issue related to loading pretrained Llama2/Llama3 weights and using them as a checkpoint?

I was going to start a separate issue asking for some docs that explain how to convert pretrained weights from HF to torchtitan in order to do continued pretraining. Is that already possible or on the roadmap?

fegin (Contributor) commented May 13, 2024

DCP has format utils to help with the conversion. However, HF conversion should not live in the PyTorch code base.
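
For reference, these format utilities live in torch.distributed.checkpoint.format_utils in recent PyTorch releases. A minimal sketch with illustrative paths (note this only converts the serialization format; it does not rename HF parameter keys to torchtitan's naming):

# Hedged sketch: round-trip between a torch.save file and a DCP checkpoint directory.
# The paths are illustrative.
from torch.distributed.checkpoint.format_utils import (
    dcp_to_torch_save,
    torch_save_to_dcp,
)

# DCP checkpoint directory -> single torch.save file
dcp_to_torch_save("outputs/checkpoint/step-1000", "checkpoint.pt")

# single torch.save file -> DCP checkpoint directory
torch_save_to_dcp("checkpoint.pt", "outputs/checkpoint/step-0")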

tianyu-l (Contributor, Author) commented:

@lessw2020 will connect with HF to see if they can support weight conversion from HF to PyTorch. After that, we may import that into the code or update the tutorial.

rlrs (Contributor) commented May 21, 2024

I have a straightforward script for converting from HF to a DCP checkpoint, if that helps. Most of the script already exists in gpt-fast.

tianyu-l (Contributor, Author) commented:

@rlrs Thanks, please feel free to share it here!

As far as we know, HF is also working on such a script to convert from HF to DCP. As discussed in #335, we should include a script to convert from llama raw weights into DCP (similar to the one here), and it probably should still sit in pytorch/pytorch.

rlrs (Contributor) commented May 24, 2024

Alright, so this is the script I'm using for HF -> DCP. It uses the safetensors weights (but can easily be adapted to load a torch.save instead), which only exist at the root of https://huggingface.co/meta-llama/Meta-Llama-3-8B/tree/main and not under original/. So, as we discussed in #335, some of the weights are permuted compared to the original.
I've been using it just to create a step-0 checkpoint that torchtitan is already set up to start from.

import argparse
import json
import re
from pathlib import Path

import torch
import torch.distributed.checkpoint as DCP
from safetensors import safe_open


@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path,
    output_dir: Path,
) -> None:
    # Load the index json that maps each weight name to its safetensors shard.
    model_map_json = checkpoint_dir / "model.safetensors.index.json"
    assert model_map_json.is_file()

    with open(model_map_json, "r") as json_map:
        bin_index = json.load(json_map)

    # Mapping from HF parameter names to torchtitan parameter names.
    # A value of None means the weight is dropped (rotary inv_freq is recomputed at runtime).
    weight_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    bin_files = {checkpoint_dir / name for name in bin_index["weight_map"].values()}

    # Read every shard into a single CPU state dict.
    merged_result = {}
    for file in sorted(bin_files):
        with safe_open(file, framework="pt", device="cpu") as f:
            for k in f.keys():
                merged_result[k] = f.get_tensor(k)

    # Rename HF keys to torchtitan keys, substituting the layer number back in.
    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r"(\d+)", "{}", key)
            layer_num = re.search(r"\d+", key).group(0)
            new_key = weight_map[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = weight_map[key]

        final_result[new_key] = value

    # Write the renamed state dict as a DCP checkpoint under the top-level "model" key.
    output_dir.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(output_dir)
    DCP.save({"model": final_result}, storage_writer=storage_writer)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert a HuggingFace llama checkpoint to DCP format.")
    parser.add_argument("--checkpoint", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True)

    args = parser.parse_args()
    convert_hf_checkpoint(
        checkpoint_dir=args.checkpoint,
        output_dir=args.output,
    )
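
For anyone trying this, a hypothetical invocation (script name and paths are illustrative, not part of torchtitan): save it as convert_hf_to_dcp.py and run `python convert_hf_to_dcp.py --checkpoint <local Meta-Llama-3-8B snapshot dir> --output <job dump_folder>/checkpoint/step-0`, so that the trainer finds the converted weights as its latest checkpoint.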

kxgong commented Jun 9, 2024

> Alright, so this is the script I'm using for HF -> DCP. [...]

Thanks for sharing.

bkchang commented Jun 20, 2024

Is there a conversion in the other direction, i.e. converting a DCP checkpoint to an HF model? I found the util dcp_to_torch_save but am not sure how to go from there to an HF model.

tianyu-l (Contributor, Author) commented:

@bkchang From the HF website, there's a script to convert llama weights to HF format.

bkchang commented Jun 24, 2024

@tianyu-l Thanks for the comment. Unfortunately, that script is for converting a llama model in the format in which it was first uploaded by the llama team. The script thus requires input files like params.json and tokenizer.model, which torchtitan doesn't generate. What I would like to know is how to go from torchtitan output weights to an HF model. Thank you.
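
For reference, a rough, untested sketch of the reverse path: dump the DCP checkpoint to a torch.save file, rename keys back to HF names by inverting the weight_map from the conversion script above, and write a safetensors file. The config.json and tokenizer files still have to come from the base HF repo, and the rope-permutation caveat from #335 applies depending on which layout the weights are in. Paths and filenames below are illustrative:

# Hedged, illustrative sketch (not an official torchtitan utility): DCP -> HF-style safetensors.
# Assumes the "model" entry of the checkpoint holds torchtitan-named weights.
import re

import torch
from safetensors.torch import save_file
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Inverse of the HF -> torchtitan name mapping used in the conversion script above.
to_hf_map = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
    "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
    "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
    "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
}

# Collapse the DCP checkpoint directory into a single torch.save file, then load it.
dcp_to_torch_save("outputs/checkpoint/step-1000", "titan_checkpoint.pt")
state = torch.load("titan_checkpoint.pt", map_location="cpu", weights_only=False)["model"]

hf_state = {}
for key, value in state.items():
    if "layers" in key:
        abstract_key = re.sub(r"(\d+)", "{}", key)
        layer_num = re.search(r"\d+", key).group(0)
        if abstract_key not in to_hf_map:  # skip non-parameter entries
            continue
        new_key = to_hf_map[abstract_key].format(layer_num)
    else:
        if key not in to_hf_map:  # e.g. a precomputed rope buffer
            continue
        new_key = to_hf_map[key]
    hf_state[new_key] = value.contiguous()

# Single-file output; larger models would need sharding plus an index json.
# Weights only: config.json / tokenizer files must still come from the base HF repo.
save_file(hf_state, "model.safetensors", metadata={"format": "pt"})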

casper-hansen (Contributor) commented:

An example of how to reload the pretrained weights would be nice once we have the weights in DCP format (e.g. for continued pretraining).

tianyu-l (Contributor, Author) commented:

> An example of how to reload the pretrained weights would be nice once we have the weights in DCP format (e.g. for continued pretraining).

cc: @wz337

rlrs (Contributor) commented Oct 17, 2024

> An example of how to reload the pretrained weights would be nice once we have the weights in DCP format (e.g. for continued pretraining).

Save the DCP checkpoint as step-0 and it will be loaded at the beginning of training.
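
Concretely (hedging on exact config defaults): torchtitan looks for checkpoints under <job.dump_folder>/<checkpoint.folder>/step-<N> when checkpointing is enabled in the training config, so a converted checkpoint written into a step-0 directory there is treated as the latest checkpoint and loaded before training starts.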

soumik-kanad commented Oct 17, 2024

@rlrs Thank you so much for the script for conversion. But I'm slightly confused about one thing, i.e. the need for permutation on my side:

  1. When using your script here (which uses the root weights and not original/), do we still need to update apply_rotary_emb() according to this post or not?
  2. Or do we need to download the original llama weights and use the default apply_rotary_emb() function as written in this repo?

casper-hansen (Contributor) commented:

@soumik-kanad you will have to permute the weights to their original format if you want to use the current implementation.

I would appreciate it if the torchtitan team could show what they think is the best way to do continued pretraining that's not hacky. Ideally, you should just be able to load in the original llama torch weights.

@tianyu-l
Copy link
Contributor Author

@casper-hansen

> Ideally, you should just be able to load in the original llama torch weights.

This is definitely something we should support. There have been a lot of asks for it, but we are still trying to find the bandwidth to work on it. Alternatively, please feel free to make PRs for it; we can help review.

cc: @wz337 @fegin @wconstab @gnadathur

rlrs (Contributor) commented Oct 17, 2024

> @rlrs Thank you so much for the script for conversion. But I'm slightly confused about one thing, i.e. the need for permutation on my side:
>
> 1. When using your [script here](/~https://github.com/pytorch/torchtitan/issues/305#issuecomment-2129251951) (which uses the root and not the `original`) do we still need to update the `apply_rotary_emb()` according to [this post](/~https://github.com/pytorch/torchtitan/issues/335#issue-2298324053) or not?
> 2. Or do we need to download the original llama weights and use the default `apply_rotary_emb()` function as written in this repo?

You either use the original weights with the rope implementation as it's implemented in the original llama code (and here in torchtitan), or you use the converted HF weights and the HF rope implementation that you also link to.

It should be relatively easy to change the script I posted to load from the original llama weights. I can do it soon if no one else manages to before I get around to it. @tianyu-l or @casper-hansen, how would you want this supported in the repo? To me it already seems pretty straightforward; unsure what else is needed.
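
For anyone adapting the script in the meantime, here is a hedged sketch of undoing the q/k permutation applied by transformers' convert_llama_weights_to_hf.py, so that HF-format weights line up with the original llama / torchtitan rope implementation. The function name is made up for illustration, and n_heads / n_kv_heads are assumed to come from the model config:

import torch

def hf_to_llama_permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Inverse of the permutation used when Meta weights were converted to HF format:
    # de-interleave the two rotary halves back into the original ordering.
    dim1, dim2 = w.shape
    return (
        w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )

# Illustrative use inside the renaming loop of the conversion script above:
# if "wq" in new_key:
#     value = hf_to_llama_permute(value, n_heads)      # e.g. 32 for Llama 3 8B
# elif "wk" in new_key:
#     value = hf_to_llama_permute(value, n_kv_heads)   # e.g. 8 for Llama 3 8B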

tianyu-l (Contributor, Author) commented:

@rlrs It would be great if you could help contribute such a script, to convert checkpoints from the original llama weights to DCP format.

I think we can put it under a scripts folder, with input / output directories as args. It might be fine to also support converting from the HF Transformers format if the majority of the code can be shared; if that's the case we can add a configurable option to choose between original and HF.

Besides, we need to create a (short) tutorial (maybe just in /~https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md) to illustrate how to convert and load, and possibly add unit tests under the test folder.

tianyu-l added this to the torchtitan release 1.0 milestone on Oct 18, 2024
jaysonfrancis (Contributor) commented:

Happy to help on this if needed. Also confirming small deltas on the K/Q weights between original <-> HF.

gnadathur added the release_blocking label on Oct 22, 2024
mori360 pushed a commit to mori360/torchtitan that referenced this issue Nov 26, 2024
Closes pytorch#305.

Just wanted to get this out here quickly. The script is very simple since the weights are already in the completely correct format, names and everything. All of the complexity is avoided by not using HF, so I believe that any functionality relating to HF should live on their side. However, I do have a DCP -> HF export script which might be useful for some people, in case HF does not have/add one.

I'll be happy to add any needed documentation or tests.