Commit

more tests
chanind committed Mar 22, 2024
1 parent 7c1cb6b commit 01978e6
Showing 1 changed file with 61 additions and 5 deletions.
66 changes: 61 additions & 5 deletions tests/unit/test_train_sae_on_language_model.py
@@ -4,8 +4,11 @@

import pytest
import torch
from datasets import Dataset
from torch import Tensor
from transformer_lens import HookedTransformer

from sae_training.activations_store import ActivationsStore
from sae_training.optim import get_scheduler
from sae_training.sae_group import SAEGroup
from sae_training.sparse_autoencoder import ForwardOutput, SparseAutoencoder
@@ -16,6 +19,7 @@
_log_feature_sparsity,
_save_checkpoint,
_train_step,
train_sae_group_on_language_model,
)
from tests.unit.helpers import build_sae_cfg

@@ -61,7 +65,7 @@ def modified_forward(*args: Any, **kwargs: Any):
return modified_forward


def test_train_step_reduces_loss_when_called_repeatedly_on_same_acts() -> None:
def test_train_step__reduces_loss_when_called_repeatedly_on_same_acts() -> None:
cfg = build_sae_cfg(d_in=64, d_sae=128, hook_point_layer=0)
sae = SparseAutoencoder(cfg)
ctx = build_train_ctx(sae)
@@ -90,7 +94,7 @@ def test_train_step_reduces_loss_when_called_repeatedly_on_same_acts() -> None:
assert ctx.n_frac_active_tokens == 50 # should increment each step by batch_size


def test_train_step_output_looks_reasonable() -> None:
def test_train_step__output_looks_reasonable() -> None:
cfg = build_sae_cfg(d_in=64, d_sae=128, hook_point_layer=0)
sae = SparseAutoencoder(cfg)
ctx = build_train_ctx(sae)
@@ -128,7 +132,30 @@ def test_train_step_output_looks_reasonable() -> None:
)


def test_train_step_sparsity_updates_based_on_feature_act_sparsity() -> None:
def test_train_step__ghost_grads_mask() -> None:
cfg = build_sae_cfg(d_in=2, d_sae=4, dead_feature_window=5)
sae = SparseAutoencoder(cfg)
ctx = build_train_ctx(
sae, n_forward_passes_since_fired=torch.tensor([0, 4, 7, 9]).float()
)

output = _train_step(
sparse_autoencoder=sae,
ctx=ctx,
layer_acts=torch.randn(10, 1, 2),
all_layers=[0],
feature_sampling_window=1000,
use_wandb=False,
n_training_steps=10,
batch_size=10,
wandb_suffix="",
)
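# with dead_feature_window=5, only features that last fired more than 5 passes ago
# (the entries 7 and 9 above) should be flagged as dead and receive ghost grads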
assert torch.all(
output.ghost_grad_neuron_mask == torch.Tensor([False, False, True, True])
)


def test_train_step__sparsity_updates_based_on_feature_act_sparsity() -> None:
cfg = build_sae_cfg(d_in=2, d_sae=4, hook_point_layer=0)
sae = SparseAutoencoder(cfg)

@@ -176,14 +203,14 @@ def test_train_step_sparsity_updates_based_on_feature_act_sparsity() -> None:
assert train_output.feature_acts is feature_acts


def test_log_feature_sparsity_handles_zeroes_by_default_fp32() -> None:
def test_log_feature_sparsity__handles_zeroes_by_default_fp32() -> None:
fp32_zeroes = torch.tensor([0], dtype=torch.float32)
assert _log_feature_sparsity(fp32_zeroes).item() != float("-inf")


# TODO: currently doesn't work for fp16, we should address this
@pytest.mark.skip(reason="Currently doesn't work for fp16")
def test_log_feature_sparsity_handles_zeroes_by_default_fp16() -> None:
def test_log_feature_sparsity__handles_zeroes_by_default_fp16() -> None:
fp16_zeroes = torch.tensor([0], dtype=torch.float16)
assert _log_feature_sparsity(fp16_zeroes).item() != float("-inf")

Expand Down Expand Up @@ -272,3 +299,32 @@ def test_save_checkpoint(tmp_path: Path) -> None:
assert torch.allclose(
loaded_log_sparsities[0], _log_feature_sparsity(ctx.feature_sparsity)
)


def test_train_sae_group_on_language_model__runs_and_outputs_look_reasonable(
ts_model: HookedTransformer,
tmp_path: Path,
) -> None:
checkpoint_dir = tmp_path / "checkpoint"
cfg = build_sae_cfg(
checkpoint_path=checkpoint_dir,
train_batch_size=32,
total_training_tokens=100,
context_size=8,
)
# just a tiny dataset which will run quickly
dataset = Dataset.from_list([{"text": "hello world"}] * 1000)
activation_store = ActivationsStore(cfg, model=ts_model, dataset=dataset)
sae_group = SAEGroup(cfg)
res = train_sae_group_on_language_model(
model=ts_model,
sae_group=sae_group,
activation_store=activation_store,
batch_size=32,
)
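# with only 100 training tokens, a single final checkpoint should be written for the group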
assert res.checkpoint_paths == [
str(checkpoint_dir / f"final_{sae_group.get_name()}.pt")
]
assert len(res.log_feature_sparsities) == 1
assert res.log_feature_sparsities[0].shape == (cfg.d_sae,)
assert res.sae_group is sae_group
