Skip to content

Commit

Permalink
[fix] Save custom module kwargs if specified (#3112)
Browse files Browse the repository at this point in the history
* Save custom module kwargs if specified

This should have been included in the save all along

* Also try to load a 'dynamic module' if not trust-remote, but local model
  • Loading branch information
tomaarsen authored Jan 6, 2025
1 parent 5dfd360 commit 4d80c88
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,11 @@ def save(
# For other cases, we want to add the class name:
elif not class_ref.startswith("sentence_transformers."):
class_ref = f"{class_ref}.{type(module).__name__}"
modules_config.append({"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref})

module_config = {"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref}
if self.module_kwargs and name in self.module_kwargs and (module_kwargs := self.module_kwargs[name]):
module_config["kwargs"] = module_kwargs
modules_config.append(module_config)

with open(os.path.join(path, "modules.json"), "w") as fOut:
json.dump(modules_config, fOut, indent=2)
Expand Down Expand Up @@ -1556,7 +1560,7 @@ def _load_module_class_from_ref(
if class_ref.startswith("sentence_transformers."):
return import_from_string(class_ref)

if trust_remote_code:
if trust_remote_code or os.path.exists(model_name_or_path):
code_revision = model_kwargs.pop("code_revision", None) if model_kwargs else None
try:
return get_class_from_dynamic_module(
Expand Down

0 comments on commit 4d80c88

Please sign in to comment.