Support Apollo pretokenized datasets
jbloomAus committed Mar 20, 2024
1 parent 5acd89b commit e814054
Showing 2 changed files with 87 additions and 84 deletions.
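In practice, the change lets the activations store consume Apollo Research's pretokenized datasets, which store token ids in an "input_ids" column, alongside datasets that use a "tokens" column or raw "text". A minimal usage sketch, assuming a config object shaped like the SimpleNamespace fixture in the tests below (the real LanguageModelSAERunnerConfig may expose these fields differently):

from types import SimpleNamespace

# Hypothetical config sketch; field names and values mirror the test fixture below.
cfg = SimpleNamespace(
    model_name="gpt2",
    dataset_path="apollo-research/sae-monology-pile-uncopyrighted-tokenizer-gpt2",
    is_dataset_tokenized=True,  # the dataset ships pretokenized "input_ids"
    hook_point="blocks.1.hook_resid_pre",
    hook_point_layer=1,
    d_in=768,
)
# ActivationsStore(cfg, model) now detects the "input_ids" column instead of
# requiring a "tokens" or "text" column.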
23 changes: 17 additions & 6 deletions sae_training/activations_store.py
@@ -26,10 +26,21 @@ def __init__(

# Check if dataset is tokenized
dataset_sample = next(self.iterable_dataset)
self.cfg.is_dataset_tokenized = "tokens" in dataset_sample.keys()
print(
f"Dataset is {'tokenized' if self.cfg.is_dataset_tokenized else 'not tokenized'}! Updating config."
)

# check if it's tokenized
if "tokens" in dataset_sample.keys():
self.cfg.is_dataset_tokenized = True
self.tokens_column = "tokens"
elif "input_ids" in dataset_sample.keys():
self.cfg.is_dataset_tokenized = True
self.tokens_column = "input_ids"
elif "text" in dataset_sample.keys():
self.cfg.is_dataset_tokenized = False
self.tokens_column = "text"
else:
raise ValueError(
"Dataset must have a 'tokens', 'input_ids', or 'text' column."
)
self.iterable_dataset = iter(self.dataset) # Reset iterator after checking

if self.cfg.use_cached_activations: # EDIT: load from multi-layer acts
@@ -79,7 +90,7 @@ def get_batch_tokens(self):
# pbar = tqdm(total=batch_size, desc="Filling batches")
while batch_tokens.shape[0] < batch_size:
if not self.cfg.is_dataset_tokenized:
s = next(self.iterable_dataset)["text"]
s = next(self.iterable_dataset)[self.tokens_column]
tokens = self.model.to_tokens(
s,
truncate=True,
@@ -90,7 +101,7 @@ def get_batch_tokens(self):
), f"tokens.shape should be 1D but was {tokens.shape}"
else:
tokens = torch.tensor(
next(self.iterable_dataset)["tokens"],
next(self.iterable_dataset)[self.tokens_column],
dtype=torch.long,
device=device,
requires_grad=False,
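The detection above is what enables Apollo's datasets: they expose "input_ids" rather than "tokens". A self-contained sketch of the same branching, applied to one sample pulled from a streamed Hugging Face dataset (detect_tokens_column is a hypothetical helper for illustration, not part of the library):

from datasets import load_dataset

def detect_tokens_column(sample: dict) -> tuple[bool, str]:
    # Mirrors the branching added in ActivationsStore.__init__ above.
    if "tokens" in sample:        # pretokenized, original column name
        return True, "tokens"
    elif "input_ids" in sample:   # pretokenized, Apollo-style column name
        return True, "input_ids"
    elif "text" in sample:        # raw text, tokenized on the fly
        return False, "text"
    raise ValueError("Dataset must have a 'tokens', 'input_ids', or 'text' column.")

# Assumes the dataset exposes a "train" split and can be streamed.
sample = next(iter(load_dataset(
    "apollo-research/sae-monology-pile-uncopyrighted-tokenizer-gpt2",
    split="train",
    streaming=True,
)))
print(detect_tokens_column(sample))  # expected: (True, "input_ids")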
148 changes: 70 additions & 78 deletions tests/unit/test_activations_store.py
@@ -1,6 +1,5 @@
from collections.abc import Iterable
from types import SimpleNamespace
from typing import Any

import pytest
import torch
@@ -9,23 +8,55 @@

from sae_training.activations_store import ActivationsStore

TEST_MODEL = "tiny-stories-1M"
TEST_DATASET = "roneneldan/TinyStories"


@pytest.fixture
def cfg():
"""
Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig.
"""
# Create a mock object with the necessary attributes
# Define a new fixture for different configurations
@pytest.fixture(
params=[
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_point": "blocks.1.hook_resid_pre",
"hook_point_layer": 1,
"d_in": 64,
},
{
"model_name": "gelu-2l",
"dataset_path": "NeelNanda/c4-tokenized-2b",
"tokenized": True,
"hook_point": "blocks.1.hook_resid_pre",
"hook_point_layer": 1,
"d_in": 512,
},
{
"model_name": "gpt2",
"dataset_path": "apollo-research/sae-monology-pile-uncopyrighted-tokenizer-gpt2",
"tokenized": True,
"hook_point": "blocks.1.hook_resid_pre",
"hook_point_layer": 1,
"d_in": 768,
},
{
"model_name": "gpt2",
"dataset_path": "Skylion007/openwebtext",
"tokenized": False,
"hook_point": "blocks.1.hook_resid_pre",
"hook_point_layer": 1,
"d_in": 768,
},
],
ids=["tiny-stories-1M", "gelu-2l-tokenized", "gpt2-tokenized", "gpt2"],
)
def cfg(request: pytest.FixtureRequest) -> SimpleNamespace:
# This function will be called with each parameter set
params = request.param
mock_config = SimpleNamespace()
mock_config.model_name = TEST_MODEL
mock_config.hook_point = "blocks.0.hook_mlp_out"
mock_config.hook_point_layer = 1
mock_config.dataset_path = TEST_DATASET
mock_config.is_dataset_tokenized = False
mock_config.d_in = 64
mock_config.model_name = params["model_name"]
mock_config.dataset_path = params["dataset_path"]
mock_config.is_dataset_tokenized = params["tokenized"]
mock_config.hook_point = params["hook_point"]
mock_config.hook_point_layer = params["hook_point_layer"]
mock_config.d_in = params["d_in"]
mock_config.expansion_factor = 2
mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor
mock_config.l1_coefficient = 2e-3
@@ -34,51 +65,6 @@ def cfg():
mock_config.context_size = 16
mock_config.use_cached_activations = False
mock_config.hook_point_head_index = None
mock_config.lp_norm = 1

mock_config.feature_sampling_method = None
mock_config.feature_sampling_window = 50
mock_config.feature_reinit_scale = 0.1
mock_config.dead_feature_threshold = 1e-7

mock_config.n_batches_in_buffer = 4
mock_config.total_training_tokens = 1_000_000
mock_config.store_batch_size = 32

mock_config.log_to_wandb = False
mock_config.wandb_project = "test_project"
mock_config.wandb_entity = "test_entity"
mock_config.wandb_log_frequency = 10
mock_config.device = torch.device("cpu")
mock_config.seed = 24
mock_config.checkpoint_path = "test/checkpoints"
mock_config.dtype = torch.float32

return mock_config


@pytest.fixture
def cfg_head_hook():
"""
Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig.
"""
# Create a mock object with the necessary attributes
mock_config = SimpleNamespace()
mock_config.model_name = TEST_MODEL
mock_config.hook_point = "blocks.0.attn.hook_q"
mock_config.hook_point_layer = 1
mock_config.hook_point_head_index = 2
mock_config.dataset_path = TEST_DATASET
mock_config.is_dataset_tokenized = False
mock_config.d_in = 4
mock_config.expansion_factor = 2
mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor
mock_config.l1_coefficient = 2e-3
mock_config.lr = 2e-4
mock_config.train_batch_size = 32
mock_config.context_size = 128
mock_config.use_cached_activations = False
mock_config.hook_point_head_index = 0

mock_config.feature_sampling_method = None
mock_config.feature_sampling_window = 50
@@ -102,21 +88,23 @@ def cfg_head_hook():


@pytest.fixture
def model():
return HookedTransformer.from_pretrained(TEST_MODEL, device="cpu")
def model(cfg: SimpleNamespace):
return HookedTransformer.from_pretrained(cfg.model_name, device="cpu")


@pytest.fixture
def activation_store(cfg: Any, model: HookedTransformer):
def activation_store(cfg: SimpleNamespace, model: HookedTransformer):
return ActivationsStore(cfg, model)


@pytest.fixture
def activation_store_head_hook(cfg_head_hook: Any, model: HookedTransformer):
def activation_store_head_hook(
cfg_head_hook: SimpleNamespace, model: HookedTransformer
):
return ActivationsStore(cfg_head_hook, model)


def test_activations_store__init__(cfg: Any, model: HookedTransformer):
def test_activations_store__init__(cfg: SimpleNamespace, model: HookedTransformer):
store = ActivationsStore(cfg, model)

assert store.cfg == cfg
@@ -149,23 +137,27 @@ def test_activations_store__get_batch_tokens(activation_store: ActivationsStore)
assert batch.device == activation_store.cfg.device


def test_activations_store__get_activations(activation_store: ActivationsStore):
def test_activations_score_get_next_batch(
model: HookedTransformer, activation_store: ActivationsStore
):

batch = activation_store.get_batch_tokens()
activations = activation_store.get_activations(batch)
assert batch.shape == (
activation_store.cfg.store_batch_size,
activation_store.cfg.context_size,
)

cfg = activation_store.cfg
assert isinstance(activations, torch.Tensor)
assert activations.shape == (cfg.store_batch_size, cfg.context_size, 1, cfg.d_in)
assert activations.device == cfg.device
# if model.tokenizer.bos_token_id is not None:
# torch.testing.assert_close(
# batch[:, 0], torch.ones_like(batch[:, 0]) * model.tokenizer.bos_token_id
# )


def test_activations_store__get_activations_head_hook(
activation_store_head_hook: ActivationsStore,
):
batch = activation_store_head_hook.get_batch_tokens()
activations = activation_store_head_hook.get_activations(batch)
def test_activations_store__get_activations(activation_store: ActivationsStore):
batch = activation_store.get_batch_tokens()
activations = activation_store.get_activations(batch)

cfg = activation_store_head_hook.cfg
cfg = activation_store.cfg
assert isinstance(activations, torch.Tensor)
assert activations.shape == (cfg.store_batch_size, cfg.context_size, 1, cfg.d_in)
assert activations.device == cfg.device
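The cfg fixture is now parametrized, so every test that depends on it (directly, or through the model and activation_store fixtures) runs once per model/dataset pair, with the ids above appearing in the test names. A stripped-down sketch of the pattern, using hypothetical fixture and test names:

import pytest
from types import SimpleNamespace

# Each dict in `params` becomes one case; `ids` label the cases in test names.
@pytest.fixture(
    params=[{"model_name": "gpt2"}, {"model_name": "gelu-2l"}],
    ids=["gpt2", "gelu-2l"],
)
def cfg(request: pytest.FixtureRequest) -> SimpleNamespace:
    return SimpleNamespace(**request.param)

@pytest.fixture
def model_name(cfg: SimpleNamespace) -> str:
    return cfg.model_name  # dependent fixtures receive the parametrized cfg

def test_model_name(cfg: SimpleNamespace):
    assert isinstance(cfg.model_name, str)  # runs once per id: [gpt2], [gelu-2l]

A single parametrized case can be selected by its id, e.g. pytest tests/unit/test_activations_store.py -k "gpt2-tokenized".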
