Merge pull request #87 from evanhanders/old_to_new
Adds function that converts old .pt pretrained SAEs to new folder format
jbloomAus authored Apr 20, 2024
2 parents 87be422 + 94f1fc1 commit 1cb1725
Showing 3 changed files with 69 additions and 2 deletions.
30 changes: 30 additions & 0 deletions sae_lens/toolkit/pretrained_saes.py
@@ -1,5 +1,6 @@
import json
import os
import pathlib
from typing import Optional, Tuple

import torch
@@ -135,6 +136,35 @@ def convert_connor_rob_sae_to_our_saelens_format(
    return ae_alt


def convert_old_to_modern_saelens_format(
    pytorch_file: str, out_folder: str = "", force: bool = False
):
    """
    Reads a pretrained SAE from the old pickle-style SAELens .pt format, then saves it as a modern-format SAELens SAE.

    Arguments:
    ----------
    pytorch_file: str
        Path of the old-format .pt file to open.
    out_folder: str, optional
        Path where the new SAE will be stored; if empty (the default), out_folder is pytorch_file with the '.pt' extension removed.
    force: bool, optional
        If out_folder already exists, this function will not save unless force=True.
    """
    file_path = pathlib.Path(pytorch_file)
    if out_folder == "":
        out_f = file_path.parent / file_path.stem
    else:
        out_f = pathlib.Path(out_folder)
    if (not force) and out_f.exists():
        raise FileExistsError(f"{out_f} already exists and force=False")
    out_f.mkdir(exist_ok=True, parents=True)

    # Load the legacy model and save it in the new folder format.
    autoencoder = SparseAutoencoder.load_from_pretrained_legacy(str(file_path))
    autoencoder.save_model(str(out_f))


def get_gpt2_small_ckrk_attn_out_saes() -> dict[str, SparseAutoencoder]:

    REPO_ID = "ckkissane/attn-saes-gpt2-small-all-layers"
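For context, a minimal usage sketch of the new converter; the checkpoint paths below are hypothetical:

from sae_lens.toolkit.pretrained_saes import convert_old_to_modern_saelens_format

# Convert a legacy .pt checkpoint; with out_folder left empty, the output
# folder defaults to "checkpoints/my_sae" (the input path minus ".pt").
convert_old_to_modern_saelens_format("checkpoints/my_sae.pt")

# Or target an explicit folder, overwriting it if it already exists.
convert_old_to_modern_saelens_format(
    "checkpoints/my_sae.pt", out_folder="converted/my_sae", force=True
)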
5 changes: 3 additions & 2 deletions sae_lens/training/sparse_autoencoder.py
@@ -280,10 +280,11 @@ def load_from_pretrained_legacy(cls, path: str):
                    )
                    state_dict["cfg"].device = "mps"
                else:
-                    state_dict = torch.load(path)
+                    state_dict = torch.load(
+                        path, pickle_module=BackwardsCompatiblePickleClass
+                    )
            except Exception as e:
                raise IOError(f"Error loading the state dictionary from .pt file: {e}")

        elif path.endswith(".pkl.gz"):
            try:
                with gzip.open(path, "rb") as f:
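BackwardsCompatiblePickleClass itself is defined outside this diff. A plausible sketch of the pattern, with the module remapping assumed for illustration: a custom pickle.Unpickler whose find_class rewrites old module paths, wrapped in a thin namespace object exposing the Unpickler attribute that torch.load's pickle_module argument looks up:

import pickle


class BackwardsCompatibleUnpickler(pickle.Unpickler):
    # Remap legacy module paths so old pickled objects resolve to the
    # classes' current homes (the sae_training -> sae_lens.training
    # mapping here is an assumption, not taken from this diff).
    def find_class(self, module: str, name: str):
        module = module.replace("sae_training", "sae_lens.training")
        return super().find_class(module, name)


class BackwardsCompatiblePickleClass:
    # torch.load(..., pickle_module=...) resolves .Unpickler on the object
    # it is handed, so a thin wrapper class suffices.
    Unpickler = BackwardsCompatibleUnpickler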
36 changes: 36 additions & 0 deletions tests/unit/toolkit/test_pretrained_saes.py
@@ -0,0 +1,36 @@
import pathlib
import shutil

import torch

from sae_lens.toolkit import pretrained_saes
from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.sparse_autoencoder import SparseAutoencoder


def test_convert_old_to_modern_saelens_format():
    out_dir = pathlib.Path("unit_test_tmp")
    out_dir.mkdir(exist_ok=True)
    legacy_out_file = str(out_dir / "test.pt")
    new_out_folder = str(out_dir / "test")

    # Make an SAE, save old version
    cfg = LanguageModelSAERunnerConfig(
        dtype=torch.float32,
        hook_point="blocks.0.hook_mlp_out",
    )
    old_sae = SparseAutoencoder(cfg)
    old_sae.save_model_legacy(legacy_out_file)

    # convert file format
    pretrained_saes.convert_old_to_modern_saelens_format(
        legacy_out_file, new_out_folder
    )

    # Load from new converted file
    new_sae = SparseAutoencoder.load_from_pretrained(new_out_folder)
    shutil.rmtree(out_dir)  # cleanup

    # Test similarity
    assert torch.allclose(new_sae.W_enc, old_sae.W_enc)
    assert torch.allclose(new_sae.W_dec, old_sae.W_dec)
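Assuming a standard pytest setup for this repository, the new test can be run in isolation with:

pytest tests/unit/toolkit/test_pretrained_saes.py -k test_convert_old_to_modern_saelens_format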
