Skip to content

Commit

Permalink
[Safetensors] Make sure metadata is saved (huggingface#2506)
Browse files Browse the repository at this point in the history
* [Safetensors] Make sure metadata is saved

* make style
  • Loading branch information
patrickvonplaten authored and Jimmy committed Apr 26, 2024
1 parent 9d823a1 commit 04e64eb
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,6 @@ def save_pretrained(
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

if save_function is None:
save_function = safetensors.torch.save_file if safe_serialization else torch.save

os.makedirs(save_directory, exist_ok=True)

model_to_save = self
Expand All @@ -310,7 +307,12 @@ def save_pretrained(
weights_name = _add_variant(weights_name, variant)

# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
if safe_serialization:
safetensors.torch.save_file(
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(save_directory, weights_name))

logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

Expand Down

0 comments on commit 04e64eb

Please sign in to comment.