Skip to content

Commit

Permalink
[Safetensors] Make sure metadata is saved (huggingface#2506)
Browse files Browse the repository at this point in the history
* [Safetensors] Make sure metadata is saved

* make style
  • Loading branch information
patrickvonplaten authored and mengfei25 committed Mar 27, 2023
1 parent 9e2aa65 commit 568a1c9
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,6 @@ def save_pretrained(
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

if save_function is None:
save_function = safetensors.torch.save_file if safe_serialization else torch.save

os.makedirs(save_directory, exist_ok=True)

model_to_save = self
Expand All @@ -310,7 +307,12 @@ def save_pretrained(
weights_name = _add_variant(weights_name, variant)

# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
if safe_serialization:
safetensors.torch.save_file(
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(save_directory, weights_name))

logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

Expand Down

0 comments on commit 568a1c9

Please sign in to comment.