Model init with HuggingFace model #743
cc: @weifengpy @mori360
👋 Gentle bump on this, mainly to see if there is a workaround for the above issue 👀
It depends on where the peak memory occurs. If it is during fully_shard, then the full_state_dict is sharded into a local_state_dict while the full copy is still resident, so memory usage goes up (full_state_dict + local_state_dict > full_state_dict).
Could you give more details on the safetensors file, as I could repro the huge memory cost.
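To put rough numbers on the "full_state_dict + local_state_dict" point, here is a back-of-the-envelope sketch. It only counts model weights (no gradients, optimizer state, or activations), and the parameter count, dtype size, and world size in the example are illustrative assumptions, not measurements from this issue.

```python
def per_rank_weight_bytes(
    num_params: int,
    bytes_per_param: int,
    world_size: int,
    full_copy_on_every_rank: bool,
) -> int:
    """Rough per-rank peak bytes for model weights while sharding them.

    Ignores gradients, optimizer state, activations and allocator overhead.
    """
    full_bytes = num_params * bytes_per_param
    # Ceiling division: each rank's shard holds at most this many bytes.
    shard_bytes = -(-full_bytes // world_size)
    if full_copy_on_every_rank:
        # The full_state_dict stays resident while the local shard is carved out of it.
        return full_bytes + shard_bytes
    # Weights streamed in param by param: only the local shard stays resident
    # (the transient copy of the parameter currently being loaded is ignored).
    return shard_bytes


if __name__ == "__main__":
    # Illustrative numbers only: 1e9 parameters in bf16 (2 bytes each), 2 GPUs.
    n = 1_000_000_000
    print(per_rank_weight_bytes(n, 2, 2, True) / 1e9)   # 3.0 (GB, full copy per rank)
    print(per_rank_weight_bytes(n, 2, 2, False) / 1e9)  # 1.0 (GB, sharded loading)
```

With a full copy on every rank, adding GPUs barely lowers the per-rank peak, which is consistent with FSDP showing about the same peak memory as no FSDP.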
I see. Ideally, I am looking for an approach that allows me to load the sharded model on each GPU without loading the
I downloaded the model.safetensors for the
I am trying to mimic TorchTitan's implementation, but with a HuggingFace model.
This is a simple repro of my implementation which can be run using:
The flow is very similar to TorchTitan's, except that TorchTitan makes an explicit call to re-initialise the weights after materialising them. Since I wish to load weights from a pretrained HF model, this is a bit challenging. The above code throws an error where I call
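A minimal single-process sketch of the meta-device flow, where the re-initialisation step is replaced by copying pretrained tensors in one parameter at a time. It assumes no FSDP is applied, and `TinyModel` and the `get_tensor` hook are hypothetical stand-ins; in practice `get_tensor` could wrap a safetensors reader (`safe_open(path, framework="pt")` and its `get_tensor(name)` method), possibly with a name mapping if the HF keys differ.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for an HF model, used only to illustrate the flow."""

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(16, 8)
        self.proj = nn.Linear(8, 4)


def build_on_meta() -> TinyModel:
    # Parameters are created on the meta device: shapes only, no storage.
    with torch.device("meta"):
        return TinyModel()


def materialize_from_pretrained(model: nn.Module, get_tensor, device: str = "cpu") -> nn.Module:
    """Allocate real storage, then fill it from a pretrained checkpoint.

    get_tensor(name) is a hypothetical hook that returns the pretrained
    tensor for a parameter name (e.g. a dict lookup or a safetensors reader).
    """
    # to_empty allocates uninitialized memory; values are garbage until copied.
    model.to_empty(device=device)
    with torch.no_grad():
        for name, param in model.named_parameters():
            # One checkpoint tensor is fetched and copied at a time, in place
            # of a random init_weights() step.
            param.copy_(get_tensor(name))
    return model


if __name__ == "__main__":
    reference = TinyModel()
    pretrained = {name: t.clone() for name, t in reference.state_dict().items()}
    model = materialize_from_pretrained(build_on_meta(), pretrained.__getitem__)
```

If fully_shard is applied before to_empty, each parameter becomes a DTensor shard, so a full tensor from the checkpoint would have to be sharded to the same layout before copying; this sketch does not cover that case.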
However, @fegin, please correct me if I'm wrong. Also, should we update torchtitan's flow, which currently goes from model.init_weight() to checkpoint.load(), so that weights are initialized param by param?
Yes, @mori360, since you have implemented this feature, the OOM should be avoidable with
Hi, any progress here? What is the best practice for continuing pretraining of an HF model with torchtitan?
@neeldani Regarding your original issue, for now the easiest approach would be to:
Does this make sense? @mori360 @fegin @tianyu-l @huyiwen, please correct me if I missed anything.
Thanks @yzhangcs
I think the key step is to convert an HF checkpoint into a DCP checkpoint, as this script does: #305 (comment). I heard that DCP is going to support the HF checkpoint format, but it may take some time.
@tianyu-l I just wrote one for medium/small-sized models: https://github.com/fla-org/flame/blob/main/convert_hf_to_dcp.py
I am writing a simple script to run FSDP2 (fully_shard) on the pythia-1b model available on HuggingFace. I am currently running the model on 1 node with 2 devices. I was following the meta-device initialisation from the FSDP2 docs. However, I think there is something wrong with my implementation, since the peak memory usage with FSDP is the same as without FSDP (~1GB). Further, I get an OOM on my device when I try the pythia-2.8b model. Following is a snippet of how I am initialising the model on a meta device using HuggingFace APIs:
This is not very straightforward, since the shards expect DTensors when the weights are being loaded via load_checkpoint_and_dispatch. I am looking for some suggestions on what would be a good way to make FSDP2 work with HuggingFace models. I don't think accelerate supports FSDP2 yet.