Skip to content

Commit

Permalink
fix load pretrained legacy with state dict change
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Apr 21, 2024
1 parent fdf7fe9 commit b5e97f8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
8 changes: 7 additions & 1 deletion sae_lens/training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,13 @@ def load_from_pretrained_legacy(cls, path: str):

# Create an instance of the class using the loaded configuration
instance = cls(cfg=state_dict["cfg"])
instance.load_state_dict(state_dict["state_dict"], strict=False)
new_state_dict = instance.state_dict()
if "scaling_factor" not in state_dict["state_dict"]:
assert isinstance(instance.cfg.d_sae, int)
state_dict["state_dict"]["scaling_factor"] = torch.ones(
instance.cfg.d_sae, dtype=instance.cfg.dtype, device=instance.cfg.device
)
instance.load_state_dict(new_state_dict, strict=True)

return instance

Expand Down
37 changes: 37 additions & 0 deletions tests/unit/training/test_sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,43 @@ def test_SparseAutoencoder_save_and_load_from_pretrained(tmp_path: Path) -> None
)


def test_SparseAutoencoder_save_and_load_from_pretrained_lacks_scaling_factor(
    tmp_path: Path,
) -> None:
    """Round-trip an SAE through save/load when the checkpoint lacks a
    ``scaling_factor`` entry.

    The loader is expected to backfill the missing ``scaling_factor`` with
    ones (the identity scaling) and to restore every other parameter
    unchanged.
    """
    cfg = build_sae_cfg(device="cpu")
    model_path = str(tmp_path)
    sparse_autoencoder = SparseAutoencoder(cfg)
    sparse_autoencoder_state_dict = sparse_autoencoder.state_dict()
    # sometimes old state dicts will be missing the scaling factor
    # NOTE(review): `nn.Module.state_dict()` returns a fresh dict on each
    # call, so deleting the key from this local copy does not obviously
    # change what `save_model` persists — confirm that `save_model` writes a
    # checkpoint that truly lacks "scaling_factor", otherwise this test is
    # exercising the ordinary (non-legacy) load path.
    del sparse_autoencoder_state_dict["scaling_factor"]
    sparse_autoencoder.save_model(model_path)

    assert os.path.exists(model_path)

    sparse_autoencoder_loaded = SparseAutoencoder.load_from_pretrained(model_path)
    sparse_autoencoder_loaded.cfg.verbose = True
    sparse_autoencoder_loaded.cfg.checkpoint_path = cfg.checkpoint_path
    sparse_autoencoder_loaded.cfg.device = "cpu"  # might autoload onto mps
    sparse_autoencoder_loaded = sparse_autoencoder_loaded.to("cpu")
    sparse_autoencoder_loaded_state_dict = sparse_autoencoder_loaded.state_dict()
    # check cfg matches the original
    assert sparse_autoencoder_loaded.cfg == cfg

    # check state_dict matches the original
    for key in sparse_autoencoder.state_dict().keys():
        if key == "scaling_factor":
            # the missing scaling factor should come back as all ones
            assert isinstance(cfg.d_sae, int)
            assert torch.allclose(
                torch.ones(cfg.d_sae, dtype=cfg.dtype, device=cfg.device),
                sparse_autoencoder_loaded_state_dict[key],
            )
        else:
            assert torch.allclose(
                sparse_autoencoder_state_dict[key],
                sparse_autoencoder_loaded_state_dict[key],
            )


def test_sparse_autoencoder_forward(sparse_autoencoder: SparseAutoencoder):
batch_size = 32
d_in = sparse_autoencoder.d_in
Expand Down

0 comments on commit b5e97f8

Please sign in to comment.