Skip to content

Commit

Permalink
fix accidental bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Apr 21, 2024
1 parent b5e97f8 commit c22fbbd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
3 changes: 1 addition & 2 deletions sae_lens/training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,12 @@ def load_from_pretrained_legacy(cls, path: str):

# Create an instance of the class using the loaded configuration
instance = cls(cfg=state_dict["cfg"])
new_state_dict = instance.state_dict()
if "scaling_factor" not in state_dict["state_dict"]:
assert isinstance(instance.cfg.d_sae, int)
state_dict["state_dict"]["scaling_factor"] = torch.ones(

Check warning on line 317 in sae_lens/training/sparse_autoencoder.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/training/sparse_autoencoder.py#L316-L317

Added lines #L316 - L317 were not covered by tests
instance.cfg.d_sae, dtype=instance.cfg.dtype, device=instance.cfg.device
)
instance.load_state_dict(new_state_dict, strict=True)
instance.load_state_dict(state_dict["state_dict"], strict=True)

return instance

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/toolkit/test_pretrained_saes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_convert_old_to_modern_saelens_format():

# convert file format
pretrained_saes.convert_old_to_modern_saelens_format(
legacy_out_file, new_out_folder
legacy_out_file, new_out_folder, force=True
)

# Load from new converted file
Expand Down

0 comments on commit c22fbbd

Please sign in to comment.