Skip to content

Commit

Permalink
Handle SAEs saved before groups
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus committed Mar 20, 2024
1 parent fa6cc49 commit 5acd89b
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 6 deletions.
1 change: 0 additions & 1 deletion sae_training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, Optional, cast

import torch

import wandb


Expand Down
2 changes: 1 addition & 1 deletion sae_training/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import pandas as pd
import torch
import wandb
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name

import wandb
from sae_training.activations_store import ActivationsStore
from sae_training.sparse_autoencoder import SparseAutoencoder

Expand Down
1 change: 1 addition & 0 deletions sae_training/lm_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, cast

import wandb

from sae_training.config import LanguageModelSAERunnerConfig

# from sae_training.activation_store import ActivationStore
Expand Down
8 changes: 8 additions & 0 deletions sae_training/sae_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,11 @@ def get_name(self):
layer_string = f"{layers[0]}"
sae_name = f"sae_group_{self.cfg.model_name}_{self.cfg.hook_point.format(layer=layer_string)}_{self.cfg.d_sae}"
return sae_name

def eval(self):
    """Put every autoencoder in this group into evaluation mode."""
    for autoencoder in self.autoencoders:
        autoencoder.eval()

def train(self):
    """Put every autoencoder in this group into training mode."""
    for autoencoder in self.autoencoders:
        autoencoder.train()
2 changes: 1 addition & 1 deletion sae_training/toy_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import einops
import torch

import wandb

from sae_training.sparse_autoencoder import SparseAutoencoder
from sae_training.toy_models import Config as ToyConfig
from sae_training.toy_models import Model as ToyModel
Expand Down
2 changes: 1 addition & 1 deletion sae_training/train_sae_on_language_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, cast

import torch
import wandb
from torch.optim import Adam
from tqdm import tqdm
from transformer_lens import HookedTransformer

import wandb
from sae_training.activations_store import ActivationsStore
from sae_training.evals import run_evals
from sae_training.geometric_median import compute_geometric_median
Expand Down
2 changes: 1 addition & 1 deletion sae_training/train_sae_on_toy_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, cast

import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb
from sae_training.sparse_autoencoder import SparseAutoencoder


Expand Down
17 changes: 16 additions & 1 deletion sae_training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sae_training.activations_store import ActivationsStore
from sae_training.sae_group import SAEGroup
from sae_training.sparse_autoencoder import SparseAutoencoder


class LMSparseAutoencoderSessionloader:
Expand Down Expand Up @@ -48,7 +49,21 @@ def load_session_from_pretrained(
# cfg = torch.load(path, map_location="cpu")["cfg"]

sparse_autoencoders = SAEGroup.load_from_pretrained(path)
model, _, activations_loader = cls(sparse_autoencoders.cfg).load_session()

# hacky code to deal with old SAE saves
if type(sparse_autoencoders) is dict:
sparse_autoencoder = SparseAutoencoder(cfg=sparse_autoencoders["cfg"])
sparse_autoencoder.load_state_dict(sparse_autoencoders["state_dict"])
model, sparse_autoencoders, activations_loader = cls(
sparse_autoencoder.cfg
).load_session()
sparse_autoencoders.autoencoders[0] = sparse_autoencoder
elif type(sparse_autoencoders) is SAEGroup:
model, _, activations_loader = cls(sparse_autoencoders.cfg).load_session()
else:
raise ValueError(
"The loaded sparse_autoencoders object is neither an SAE dict nor a SAEGroup"
)

return model, sparse_autoencoders, activations_loader

Expand Down

0 comments on commit 5acd89b

Please sign in to comment.