
Commit

adding some unit tests to _train_step()
chanind committed Mar 21, 2024
1 parent 2d5ec98 commit dbf3f01
Showing 8 changed files with 102 additions and 6 deletions.
1 change: 1 addition & 0 deletions sae_training/config.py
@@ -3,6 +3,7 @@
from typing import Any, Optional, cast

import torch

import wandb


2 changes: 1 addition & 1 deletion sae_training/evals.py
@@ -3,11 +3,11 @@

import pandas as pd
import torch
import wandb
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name

import wandb
from sae_training.activations_store import ActivationsStore
from sae_training.sparse_autoencoder import SparseAutoencoder

1 change: 0 additions & 1 deletion sae_training/lm_runner.py
@@ -1,7 +1,6 @@
from typing import Any, cast

import wandb

from sae_training.config import LanguageModelSAERunnerConfig

# from sae_training.activation_store import ActivationStore
2 changes: 1 addition & 1 deletion sae_training/toy_model_runner.py
@@ -3,8 +3,8 @@

import einops
import torch
import wandb

import wandb
from sae_training.sparse_autoencoder import SparseAutoencoder
from sae_training.toy_models import Config as ToyConfig
from sae_training.toy_models import Model as ToyModel
2 changes: 1 addition & 1 deletion sae_training/train_sae_on_language_model.py
@@ -4,10 +4,10 @@
import torch
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import LRScheduler
import wandb
from tqdm import tqdm
from transformer_lens import HookedTransformer

import wandb
from sae_training.activations_store import ActivationsStore
from sae_training.evals import run_evals
from sae_training.geometric_median import compute_geometric_median
2 changes: 1 addition & 1 deletion sae_training/train_sae_on_toy_model.py
@@ -1,10 +1,10 @@
from typing import Any, cast

import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb
from sae_training.sparse_autoencoder import SparseAutoencoder


1 change: 0 additions & 1 deletion tests/benchmark/test_language_model_sae_runner.py
@@ -1,5 +1,4 @@
import torch
from transformer_lens import HookedTransformer

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner
97 changes: 97 additions & 0 deletions tests/unit/test_train_sae_on_language_model.py
@@ -0,0 +1,97 @@
import torch
from torch import Tensor

from sae_training.optim import get_scheduler
from sae_training.sparse_autoencoder import SparseAutoencoder
from sae_training.train_sae_on_language_model import SAETrainContext, _train_step
from tests.unit.helpers import build_sae_cfg


def build_train_ctx(
sae: SparseAutoencoder,
act_freq_scores: Tensor | None = None,
n_forward_passes_since_fired: Tensor | None = None,
n_frac_active_tokens: int = 0,
) -> SAETrainContext:
"""
Factory helper to build a default SAETrainContext object.
"""
assert sae.cfg.d_sae is not None
optimizer = torch.optim.Adam(sae.parameters(), lr=sae.cfg.lr)
return SAETrainContext(
act_freq_scores=(
torch.zeros(sae.cfg.d_sae) if act_freq_scores is None else act_freq_scores
),
n_forward_passes_since_fired=(
torch.zeros(sae.cfg.d_sae)
if n_forward_passes_since_fired is None
else n_forward_passes_since_fired
),
n_frac_active_tokens=n_frac_active_tokens,
optimizer=optimizer,
scheduler=get_scheduler(None, optimizer=optimizer),
)


def test_train_step_reduces_loss_when_called_repeatedly_on_same_acts() -> None:
cfg = build_sae_cfg(d_in=64, d_sae=128, hook_point_layer=0)
sae = SparseAutoencoder(cfg)
ctx = build_train_ctx(sae)

layer_acts = torch.randn(10, 1, 64)

# intentionally train on the same activations 5 times to ensure loss decreases
train_outputs = [
_train_step(
sparse_autoencoder=sae,
ctx=ctx,
layer_acts=layer_acts,
all_layers=[0],
feature_sampling_window=1000,
use_wandb=False,
n_training_steps=10,
batch_size=10,
wandb_suffix="",
)
for _ in range(5)
]

# ensure loss decreases with each training step
for output, next_output in zip(train_outputs[:-1], train_outputs[1:]):
assert output.loss > next_output.loss
assert ctx.n_frac_active_tokens == 50 # should increment each step by batch_size


def test_train_step_output_looks_reasonable() -> None:
cfg = build_sae_cfg(d_in=64, d_sae=128, hook_point_layer=0, dead_feature_window=100)
sae = SparseAutoencoder(cfg)
ctx = build_train_ctx(sae)

layer_acts = torch.randn(10, 2, 64)

output = _train_step(
sparse_autoencoder=sae,
ctx=ctx,
layer_acts=layer_acts,
all_layers=[0],
feature_sampling_window=1000,
use_wandb=False,
n_training_steps=10,
batch_size=10,
wandb_suffix="",
)

assert output.loss > 0
# only hook_point_layer=0 acts should be passed to the SAE
assert torch.allclose(output.sae_in, layer_acts[:, 0, :])
assert output.sae_out.shape == output.sae_in.shape
assert output.feature_acts.shape == (10, 128) # batch_size, d_sae
assert output.ghost_grad_neuron_mask.shape == (128,)
# ghost grads shouldn't trigger until dead_feature_window, which hasn't been reached yet
assert torch.all(output.ghost_grad_neuron_mask == False) # noqa
assert output.ghost_grad_loss == 0
assert ctx.n_frac_active_tokens == 10
assert ctx.act_freq_scores.sum() > 0 # at least SOME acts should have fired
assert torch.allclose(
ctx.act_freq_scores, (output.feature_acts.abs() > 0).float().sum(0)
)

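For orientation, the assertions above exercise the fields of the object returned by _train_step (loss, sae_in, sae_out, feature_acts, ghost_grad_neuron_mask, ghost_grad_loss). The sketch below is not part of the commit; it is an assumed reconstruction inferred from those asserts, and the actual return type in sae_training.train_sae_on_language_model may be named and typed differently.

from dataclasses import dataclass

from torch import Tensor


@dataclass
class TrainStepOutput:  # hypothetical name, inferred from the fields asserted in the tests above
    sae_in: Tensor  # activations passed to the SAE, shape (batch_size, d_in)
    sae_out: Tensor  # SAE reconstruction, same shape as sae_in
    feature_acts: Tensor  # feature activations, shape (batch_size, d_sae)
    loss: Tensor  # total training loss for the step (scalar)
    ghost_grad_neuron_mask: Tensor  # boolean mask over features, shape (d_sae,)
    ghost_grad_loss: Tensor | float  # ghost-gradient term; 0 before dead_feature_window is reached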