Skip to content

Commit

Permalink
remove load with session option
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Apr 11, 2024
1 parent 16935ef commit 74926e1
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions sae_lens/toolkit/pretrained_saes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from tqdm import tqdm

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
from sae_lens.training.sparse_autoencoder import SparseAutoencoder


Expand Down Expand Up @@ -57,9 +56,18 @@ def get_gpt2_res_jb_saes() -> (
def convert_connor_rob_sae_to_our_saelens_format(
state_dict: dict[str, torch.Tensor],
config: dict[str, int | str],
return_session: bool = False,
device: str = "mps",
device: str = "cpu",
):
"""
# can get session like so.
model, ae_alt, activation_store = LMSparseAutoencoderSessionloader(
cfg
).load_sae_training_group_session()
next(iter(ae_alt))[1].load_state_dict(state_dict)
return model, ae_alt, activation_store
"""

expansion_factor = int(config["dict_size"]) // int(config["act_size"])

Expand All @@ -82,17 +90,9 @@ def convert_connor_rob_sae_to_our_saelens_format(
dtype=torch.float32,
)

# if we want model + act store, do this
if return_session:
model, ae_alt, activation_store = LMSparseAutoencoderSessionloader(
cfg
).load_sae_training_group_session()
next(iter(ae_alt))[1].load_state_dict(state_dict)
return model, ae_alt, activation_store
else:
ae_alt = SparseAutoencoder(cfg)
ae_alt.load_state_dict(state_dict)
return ae_alt
ae_alt = SparseAutoencoder(cfg)
ae_alt.load_state_dict(state_dict)
return ae_alt


def get_gpt2_small_ckrk_attn_out_saes() -> dict[str, SparseAutoencoder]:
Expand Down

0 comments on commit 74926e1

Please sign in to comment.