Merge pull request #60 from chanind/improve-config-typing
fixing config typing
jbloomAus authored Apr 6, 2024
2 parents 773bc02 + 9be3445 commit b8fba4f
Showing 13 changed files with 370 additions and 186 deletions.
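
This commit replaces the untyped `cfg: Any` argument to `ActivationsStore.__init__` with a typed `from_config` classmethod and explicit, typed constructor parameters. A minimal usage sketch of the new entry point, assuming the config fields shown in the diff below are constructor keyword arguments (the model name and field values here are illustrative, not taken from this PR):

    # Sketch only: field names follow the diff below; values are hypothetical,
    # and the config class may require additional fields.
    from transformer_lens import HookedTransformer

    from sae_training.activations_store import ActivationsStore
    from sae_training.config import LanguageModelSAERunnerConfig

    model = HookedTransformer.from_pretrained("gpt2")
    cfg = LanguageModelSAERunnerConfig(
        dataset_path="roneneldan/TinyStories",  # hypothetical dataset path
        hook_point="blocks.{layer}.hook_resid_pre",
        hook_point_layer=6,
        d_in=768,
    )

    # Typed entry point added by this PR; replaces passing an untyped cfg
    # straight into ActivationsStore.
    store = ActivationsStore.from_config(model, cfg)
    batch_tokens = store.get_batch_tokens()
    sae_batch = store.next_batch()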
181 changes: 128 additions & 53 deletions sae_training/activations_store.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import os
from typing import Any, Iterator, cast
from typing import Any, Iterator, Literal, TypeVar, cast

import torch
from datasets import (
@@ -12,6 +14,11 @@
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer

from sae_training.config import (
CacheActivationsRunnerConfig,
LanguageModelSAERunnerConfig,
)

HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset


@@ -21,75 +28,142 @@ class ActivationsStore:
while training SAEs.
"""

model: HookedTransformer
dataset: HfDataset
cached_activations_path: str | None
tokens_column: Literal["tokens", "input_ids", "text"]
hook_point_head_index: int | None

@classmethod
def from_config(
cls,
model: HookedTransformer,
cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
dataset: HfDataset | None = None,
create_dataloader: bool = True,
) -> "ActivationsStore":
cached_activations_path = cfg.cached_activations_path
# set cached_activations_path to None if we're not using cached activations
if (
isinstance(cfg, LanguageModelSAERunnerConfig)
and not cfg.use_cached_activations
):
cached_activations_path = None
return cls(
model=model,
dataset=dataset or cfg.dataset_path,
hook_point=cfg.hook_point,
hook_point_layers=listify(cfg.hook_point_layer),
hook_point_head_index=cfg.hook_point_head_index,
context_size=cfg.context_size,
d_in=cfg.d_in,
n_batches_in_buffer=cfg.n_batches_in_buffer,
total_training_tokens=cfg.total_training_tokens,
store_batch_size=cfg.store_batch_size,
train_batch_size=cfg.train_batch_size,
prepend_bos=cfg.prepend_bos,
device=cfg.device,
dtype=cfg.dtype,
cached_activations_path=cached_activations_path,
create_dataloader=create_dataloader,
)

def __init__(
self,
cfg: Any,
model: HookedTransformer,
dataset: HfDataset | None = None,
dataset: HfDataset | str,
hook_point: str,
hook_point_layers: list[int],
hook_point_head_index: int | None,
context_size: int,
d_in: int,
n_batches_in_buffer: int,
total_training_tokens: int,
store_batch_size: int,
train_batch_size: int,
prepend_bos: bool,
device: str | torch.device,
dtype: torch.dtype,
cached_activations_path: str | None = None,
create_dataloader: bool = True,
):
self.cfg = cfg
self.model = model
self.dataset = dataset or load_dataset(
cfg.dataset_path, split="train", streaming=True
self.dataset = (
load_dataset(dataset, split="train", streaming=True)
if isinstance(dataset, str)
else dataset
)
self.hook_point = hook_point
self.hook_point_layers = hook_point_layers
self.hook_point_head_index = hook_point_head_index
self.context_size = context_size
self.d_in = d_in
self.n_batches_in_buffer = n_batches_in_buffer
self.total_training_tokens = total_training_tokens
self.store_batch_size = store_batch_size
self.train_batch_size = train_batch_size
self.prepend_bos = prepend_bos
self.device = device
self.dtype = dtype
self.cached_activations_path = cached_activations_path

self.iterable_dataset = iter(self.dataset)

# Check if dataset is tokenized
dataset_sample = next(self.iterable_dataset)

# check if it's tokenized
if "tokens" in dataset_sample.keys():
self.cfg.is_dataset_tokenized = True
self.is_dataset_tokenized = True
self.tokens_column = "tokens"
elif "input_ids" in dataset_sample.keys():
self.cfg.is_dataset_tokenized = True
self.is_dataset_tokenized = True
self.tokens_column = "input_ids"
elif "text" in dataset_sample.keys():
self.cfg.is_dataset_tokenized = False
self.is_dataset_tokenized = False
self.tokens_column = "text"
else:
raise ValueError(
"Dataset must have a 'tokens', 'input_ids', or 'text' column."
)
self.iterable_dataset = iter(self.dataset) # Reset iterator after checking

if self.cfg.use_cached_activations: # EDIT: load from multi-layer acts
assert self.cfg.cached_activations_path is not None # keep pyright happy
if cached_activations_path is not None: # EDIT: load from multi-layer acts
assert self.cached_activations_path is not None # keep pyright happy
# Sanity check: does the cache directory exist?
assert os.path.exists(
self.cfg.cached_activations_path
), f"Cache directory {self.cfg.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."
self.cached_activations_path
), f"Cache directory {self.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."

self.next_cache_idx = 0 # which file to open next
self.next_idx_within_buffer = 0 # where to start reading from in that file

# Check that we have enough data on disk
first_buffer = torch.load(f"{self.cfg.cached_activations_path}/0.pt")
first_buffer = torch.load(f"{self.cached_activations_path}/0.pt")
buffer_size_on_disk = first_buffer.shape[0]
n_buffers_on_disk = len(os.listdir(self.cfg.cached_activations_path))
n_buffers_on_disk = len(os.listdir(self.cached_activations_path))
# Note: we're assuming all files have the same number of tokens
# (which seems reasonable imo since that's what our script does)
n_activations_on_disk = buffer_size_on_disk * n_buffers_on_disk
assert (
n_activations_on_disk > self.cfg.total_training_tokens
), f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but cfg.total_training_tokens is {self.cfg.total_training_tokens/1e6:.1f}M."
n_activations_on_disk > self.total_training_tokens
), f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but total_training_tokens is {self.total_training_tokens/1e6:.1f}M."

# TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)

if create_dataloader:
# fill the storage buffer with half a buffer, so we can mix it with a new buffer
self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
self.storage_buffer = self.get_buffer(self.n_batches_in_buffer // 2)
self.dataloader = self.get_data_loader()

def get_batch_tokens(self):
"""
Streams a batch of tokens from a dataset.
"""

batch_size = self.cfg.store_batch_size
context_size = self.cfg.context_size
device = self.cfg.device
batch_size = self.store_batch_size
context_size = self.context_size
device = self.device

batch_tokens = torch.zeros(
size=(0, context_size), device=device, dtype=torch.long, requires_grad=False
@@ -124,7 +198,7 @@ def get_batch_tokens(self):
token_len -= space_left

# only add BOS if it's not already the first token
if self.cfg.prepend_bos:
if self.prepend_bos:
bos_token_id_tensor = torch.tensor(
[self.model.tokenizer.bos_token_id],
device=tokens.device,
@@ -160,23 +234,19 @@ def get_activations(self, batch_tokens: torch.Tensor):
d_in may result from a concatenated head dimension.
"""
layers = (
self.cfg.hook_point_layer
if isinstance(self.cfg.hook_point_layer, list)
else [self.cfg.hook_point_layer]
)
act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers]
layers = self.hook_point_layers
act_names = [self.hook_point.format(layer=layer) for layer in layers]
hook_point_max_layer = max(layers)
layerwise_activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
prepend_bos=self.cfg.prepend_bos,
prepend_bos=self.prepend_bos,
)[1]
activations_list = [layerwise_activations[act_name] for act_name in act_names]
if self.cfg.hook_point_head_index is not None:
if self.hook_point_head_index is not None:
activations_list = [
act[:, :, self.cfg.hook_point_head_index] for act in activations_list
act[:, :, self.hook_point_head_index] for act in activations_list
]
elif activations_list[0].ndim > 3: # if we have a head dimension
# flatten the head dimension
@@ -190,39 +260,35 @@ def get_buffer(self, n_batches_in_buffer: int):
return stacked_activations

def get_buffer(self, n_batches_in_buffer: int):
context_size = self.cfg.context_size
batch_size = self.cfg.store_batch_size
d_in = self.cfg.d_in
context_size = self.context_size
batch_size = self.store_batch_size
d_in = self.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = (
len(self.cfg.hook_point_layer)
if isinstance(self.cfg.hook_point_layer, list)
else 1
) # Number of hook points or layers
num_layers = len(self.hook_point_layers) # Number of hook points or layers

if self.cfg.use_cached_activations:
if self.cached_activations_path is not None:
# Load the activations from disk
buffer_size = total_size * context_size
# Initialize an empty tensor with an additional dimension for layers
new_buffer = torch.zeros(
(buffer_size, num_layers, d_in),
dtype=self.cfg.dtype,
device=self.cfg.device,
dtype=self.dtype,
device=self.device,
)
n_tokens_filled = 0

# Assume activations for different layers are stored separately and need to be combined
while n_tokens_filled < buffer_size:
if not os.path.exists(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
f"{self.cached_activations_path}/{self.next_cache_idx}.pt"
):
print(
"\n\nWarning: Ran out of cached activation files earlier than expected."
)
print(
f"Expected to have {buffer_size} activations, but only found {n_tokens_filled}."
)
if buffer_size % self.cfg.total_training_tokens != 0:
if buffer_size % self.total_training_tokens != 0:
print(
"This might just be a rounding error — your batch_size * n_batches_in_buffer * context_size is not divisible by your total_training_tokens"
)
@@ -232,7 +298,7 @@ def get_buffer(self, n_batches_in_buffer: int):
return new_buffer

activations = torch.load(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
f"{self.cached_activations_path}/{self.next_cache_idx}.pt"
)
taking_subset_of_file = False
if n_tokens_filled + activations.shape[0] > buffer_size:
@@ -257,8 +323,8 @@ def get_buffer(self, n_batches_in_buffer: int):
# Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
new_buffer = torch.zeros(
(total_size, context_size, num_layers, d_in),
dtype=self.cfg.dtype,
device=self.cfg.device,
dtype=self.dtype,
device=self.device,
)

for refill_batch_idx_start in refill_iterator:
@@ -286,11 +352,11 @@ def get_data_loader(
"""

batch_size = self.cfg.train_batch_size
batch_size = self.train_batch_size

# 1. create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
[self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer],
[self.get_buffer(self.n_batches_in_buffer // 2), self.storage_buffer],
dim=0,
)

@@ -325,14 +391,14 @@ def next_batch(self):
return next(self.dataloader)

def _get_next_dataset_tokens(self) -> torch.Tensor:
device = self.cfg.device
if not self.cfg.is_dataset_tokenized:
device = self.device
if not self.is_dataset_tokenized:
s = next(self.iterable_dataset)[self.tokens_column]
tokens = self.model.to_tokens(
s,
truncate=True,
move_to_device=True,
prepend_bos=self.cfg.prepend_bos,
prepend_bos=self.prepend_bos,
).squeeze(0)
assert (
len(tokens.shape) == 1
@@ -345,8 +411,17 @@ def _get_next_dataset_tokens(self) -> torch.Tensor:
requires_grad=False,
)
if (
not self.cfg.prepend_bos
not self.prepend_bos
and tokens[0] == self.model.tokenizer.bos_token_id # type: ignore
):
tokens = tokens[1:]
return tokens


T = TypeVar("T")


def listify(x: T | list[T]) -> list[T]:
if isinstance(x, list):
return x
return [x]
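
The new module-level `listify` helper normalizes `hook_point_layer`, which the config may supply as a single value or a list, into a list before it reaches `__init__` (see `from_config` above). A behavior sketch with illustrative values:

    assert listify(5) == [5]          # a bare value is wrapped in a list
    assert listify([0, 3]) == [0, 3]  # a list passes through unchanged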