Commit

fixl0_plus_other_stuff
jbloomAus committed Dec 3, 2023
1 parent 4cacbfc commit 2f162f0
Showing 13 changed files with 561 additions and 270 deletions.
577 changes: 365 additions & 212 deletions dev.ipynb

Large diffs are not rendered by default.

Binary file added image.png
Binary file added image1.png
Binary file added image2.png
Binary file added image3.png
sae_training/activations_store.py
@@ -1,16 +1,15 @@
import os

import einops
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformer_lens import HookedTransformer

from sae_training.lm_datasets import preprocess_tokenized_dataset


class DataLoaderBuffer:
class ActivationsStore:
"""
Class for streaming tokens and generating and storing activations
while training SAEs.
"""
def __init__(
self, cfg, model: HookedTransformer,
data_path="NeelNanda/c4-code-tokenized-2b",
@@ -22,7 +21,10 @@ def __init__(
self.is_dataset_tokenized = is_dataset_tokenized
self.dataset = load_dataset(data_path, split="train", streaming=True)
self.iterable_dataset = iter(self.dataset)
self.buffer = torch.zeros(0, self.cfg.d_in, device=self.cfg.device)

# fill the storage buffer with half a buffer, so we can mix it with a freshly generated half later
self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
self.dataloader = self.get_data_loader()

def get_batch_tokens(self):
"""
@@ -108,11 +110,10 @@ def get_activations(self, batch_tokens):

return activations

def get_buffer(self):
def get_buffer(self, n_batches_in_buffer):
context_size = self.cfg.context_size
batch_size = self.cfg.store_batch_size
d_in = self.cfg.d_in
n_batches_in_buffer = self.cfg.n_batches_in_buffer
total_size = batch_size * n_batches_in_buffer

refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
@@ -137,3 +138,44 @@ def get_buffer(self):
new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]

return new_buffer

def get_data_loader(self) -> DataLoader:
'''
Return a torch.utils.data.DataLoader which you can get batches from.
The buffer is refilled and reshuffled whenever a new loader is created
(better mixing if you refill and shuffle regularly).
'''

batch_size = self.cfg.train_batch_size

# 1. create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
[self.get_buffer(self.cfg.n_batches_in_buffer //2),
self.storage_buffer]
)

mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

# 2. put 50% of the mixed buffer into storage
self.storage_buffer = mixing_buffer[:mixing_buffer.shape[0]//2]

# 3. put the other 50% in a dataloader (the second half, so it does not overlap with storage)
dataloader = iter(DataLoader(mixing_buffer[mixing_buffer.shape[0]//2:], batch_size=batch_size, shuffle=True))

return dataloader


def next_batch(self):
"""
Get the next batch from the current DataLoader.
If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
"""
try:
# Try to get the next batch
return next(self.dataloader)
except StopIteration:
# If the DataLoader is exhausted, create a new one
self.dataloader = self.get_data_loader()
return next(self.dataloader)
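For orientation, here is a minimal, self-contained sketch of the half-and-half mixing scheme the new methods implement, with random tensors standing in for model activations; ToyActivationsStore and all of its sizes are illustrative stand-ins rather than code from this repository, and only the get_buffer / get_data_loader / next_batch flow mirrors the diff above.

import torch
from torch.utils.data import DataLoader


class ToyActivationsStore:
    """Toy stand-in for ActivationsStore: random tensors instead of model activations."""

    def __init__(self, d_in=4, n_batches_in_buffer=8, store_batch_size=16, train_batch_size=32):
        self.d_in = d_in
        self.n_batches_in_buffer = n_batches_in_buffer
        self.store_batch_size = store_batch_size
        self.train_batch_size = train_batch_size
        # keep half a buffer around so every new dataloader mixes old and fresh activations
        self.storage_buffer = self.get_buffer(n_batches_in_buffer // 2)
        self.dataloader = self.get_data_loader()

    def get_buffer(self, n_batches):
        # stands in for running the model and collecting activations at the hook point
        return torch.randn(n_batches * self.store_batch_size, self.d_in)

    def get_data_loader(self):
        # mix a fresh half-buffer with the stored half, shuffle, then split again
        mixing_buffer = torch.cat([self.get_buffer(self.n_batches_in_buffer // 2), self.storage_buffer])
        mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
        half = mixing_buffer.shape[0] // 2
        self.storage_buffer = mixing_buffer[:half]  # held back for the next mix
        return iter(DataLoader(mixing_buffer[half:], batch_size=self.train_batch_size, shuffle=True))

    def next_batch(self):
        try:
            return next(self.dataloader)
        except StopIteration:
            # dataloader exhausted: remix, refill, and retry
            self.dataloader = self.get_data_loader()
            return next(self.dataloader)


store = ToyActivationsStore()
print(store.next_batch().shape)  # torch.Size([32, 4])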
2 changes: 1 addition & 1 deletion sae_training/config.py
@@ -39,7 +39,7 @@ class LanguageModelSAERunnerConfig:
# Activation Store Parameters
n_batches_in_buffer: int = 20
total_training_tokens: int = 2_000_000
store_batch_size: int = 4096
store_batch_size: int = 1024

# WANDB
log_to_wandb: bool = True
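The only change here is the store_batch_size default. As a rough, hedged sanity check on what that buys, the estimate below assumes the buffer holds store_batch_size * n_batches_in_buffer * context_size activation vectors of width d_in in float32, and plugs in values visible elsewhere in this diff (n_batches_in_buffer=20, context_size=128, d_in=512); these are assumptions, not measurements.

def buffer_gb(store_batch_size, n_batches_in_buffer=20, context_size=128, d_in=512, bytes_per_float=4):
    # one d_in-wide activation vector per buffered token
    n_activations = store_batch_size * n_batches_in_buffer * context_size
    return n_activations * d_in * bytes_per_float / 1e9


print(buffer_gb(4096))  # old default: ~21.5 GB of buffered activations
print(buffer_gb(1024))  # new default: ~5.4 GB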
Empty file added sae_training/tmp.py
Empty file.
46 changes: 13 additions & 33 deletions sae_training/train_sae_on_language_model.py
@@ -7,13 +7,14 @@
from transformer_lens import HookedTransformer

import wandb
from sae_training.activations_store import ActivationsStore
from sae_training.sparse_autoencoder import SparseAutoencoder


def train_sae_on_language_model(
model: HookedTransformer,
sparse_autoencoder: SparseAutoencoder,
data_loader_buffer,
activation_store: ActivationsStore,
batch_size: int = 1024,
n_checkpoints: int = 0,
feature_sampling_method: str = "l2", # None, l2, or anthropic
@@ -35,11 +36,6 @@ def train_sae_on_language_model(

if n_checkpoints > 0:
checkpoint_thresholds = list(range(0, total_training_tokens, total_training_tokens // n_checkpoints))[1:]

# start the buffer
buffer = data_loader_buffer.get_buffer()
dataloader = iter(DataLoader(buffer, batch_size=batch_size, shuffle=True))
n_remaining_batches_in_buffer = len(dataloader)

pbar = tqdm(total=total_training_tokens, desc="Training SAE")
while n_training_tokens < total_training_tokens:
Expand All @@ -57,14 +53,14 @@ def train_sae_on_language_model(
feature_sampling_window * batch_size
)
# if standard resampling <- do this
n_resampled_neurons = sparse_autoencoder.resample_neurons(next(dataloader), feature_sparsity, feature_reinit_scale)
n_remaining_batches_in_buffer -= 1
n_resampled_neurons = sparse_autoencoder.resample_neurons(
activation_store.next_batch(),
feature_sparsity,
feature_reinit_scale)

# elif anthropic resampling <- do this
# run the model and reinit where recons loss is high.
if n_remaining_batches_in_buffer == 0:
dataloader, n_remaining_batches_in_buffer = get_new_dataloader(
data_loader_buffer, n_remaining_batches_in_buffer, batch_size)

else:
n_resampled_neurons = 0

@@ -74,14 +70,8 @@ def train_sae_on_language_model(

# Forward and Backward Passes
optimizer.zero_grad()
_, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(next(dataloader))
_, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(activation_store.next_batch())
n_training_tokens += batch_size
n_remaining_batches_in_buffer -= 1

# Update the buffer if we've run out of batches
if n_remaining_batches_in_buffer == 0:
dataloader, n_remaining_batches_in_buffer = get_new_dataloader(
data_loader_buffer, n_remaining_batches_in_buffer, batch_size)

with torch.no_grad():
# Calculate the sparsities, and add it to a list, calculate sparsity metrics
@@ -103,7 +93,7 @@ def train_sae_on_language_model(
)

# metrics for currents acts
l0 = (feature_acts > 0).float().sum(0).mean()
l0 = (feature_acts > 0).float().sum(1).mean()
l2_norm = torch.norm(feature_acts, dim=1).mean()

if use_wandb and ((n_training_steps + 1) % wandb_log_frequency == 0):
@@ -146,8 +136,7 @@ def train_sae_on_language_model(
)

# Now we want the reconstruction loss.
recons_score, _, _, _ = get_recons_loss(
sparse_autoencoder, model, data_loader_buffer=data_loader_buffer, num_batches=5)
recons_score, _, _, _ = get_recons_loss(sparse_autoencoder, model, activation_store, num_batches=5)

wandb.log(
{
@@ -186,21 +175,12 @@ def train_sae_on_language_model(

return sparse_autoencoder


def get_new_dataloader(data_loader_buffer, n_remaining_batches_in_buffer, batch_size):
buffer = data_loader_buffer.get_buffer()
dataloader = iter(DataLoader(buffer, batch_size=batch_size, shuffle=True))
n_remaining_batches_in_buffer = len(dataloader) // 2 # only ever use half the buffer .
return dataloader, n_remaining_batches_in_buffer



@torch.no_grad()
def get_recons_loss(sparse_autoencder, model, data_loader_buffer, num_batches=5):
hook_point = data_loader_buffer.cfg.hook_point
def get_recons_loss(sparse_autoencder, model, activation_store, num_batches=5):
hook_point = activation_store.cfg.hook_point
loss_list = []
for _ in range(num_batches):
batch_tokens = data_loader_buffer.get_batch_tokens()
batch_tokens = activation_store.get_batch_tokens()
loss = model(batch_tokens, return_type="loss")

# mean_abl_loss = model.run_with_hooks(tokens, return_type="loss",
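The commit title refers to the l0 line changed above: summing the active-feature mask over dim 0 averages how often each feature fires across the batch, whereas summing over dim 1 counts active features per example, which is what L0 usually means. A small sketch, assuming feature_acts has shape (batch, d_sae) as the neighbouring l2_norm over dim=1 suggests:

import torch

batch_size, d_sae = 4, 8
feature_acts = torch.relu(torch.randn(batch_size, d_sae))  # stand-in SAE feature activations
active = (feature_acts > 0).float()

l0 = active.sum(1).mean()          # new: average number of active features per example
old_metric = active.sum(0).mean()  # old: average number of examples each feature fires on

# the two differ by a batch_size / d_sae factor, so the old value was not an L0 at all
assert torch.isclose(old_metric, l0 * batch_size / d_sae)
print(l0.item(), old_metric.item())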
8 changes: 4 additions & 4 deletions sae_training/utils.py
@@ -3,7 +3,7 @@
import torch
from transformer_lens import HookedTransformer

from sae_training.activations_buffer import DataLoaderBuffer
from sae_training.activations_store import ActivationsStore
from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.sparse_autoencoder import SparseAutoencoder

@@ -20,7 +20,7 @@ def __init__(self, cfg: LanguageModelSAERunnerConfig):
self.cfg = cfg


def load_session(self) -> Tuple[HookedTransformer, SparseAutoencoder, DataLoaderBuffer]:
def load_session(self) -> Tuple[HookedTransformer, SparseAutoencoder, ActivationsStore]:
'''
Loads a session for training a sparse autoencoder on a language model.
'''
@@ -32,7 +32,7 @@ def load_session_from_pretrained(cls, path: str) -> Tuple[HookedTransformer, SparseAutoencoder, DataLoader
return model, sparse_autoencoder, activations_loader

@classmethod
def load_session_from_pretrained(cls, path: str) -> Tuple[HookedTransformer, SparseAutoencoder, DataLoaderBuffer]:
def load_session_from_pretrained(cls, path: str) -> Tuple[HookedTransformer, SparseAutoencoder, ActivationsStore]:
'''
Loads a session for analysing a pretrained sparse autoencoder.
'''
@@ -67,7 +67,7 @@ def get_activations_loader(self, cfg: LanguageModelSAERunnerConfig, model: Hooke
Loads a DataLoaderBuffer for the activations of a language model.
'''

activations_loader = DataLoaderBuffer(
activations_loader = ActivationsStore(
cfg, model,
data_path=cfg.dataset_path,
is_dataset_tokenized=cfg.is_dataset_tokenized,
16 changes: 8 additions & 8 deletions tests/benchmark/test_language_model_sae_runner.py
@@ -11,29 +11,29 @@ def test_language_model_sae_runner_mlp_out():

# Data Generating Function (Model + Training Distribution)
model_name = "gelu-2l",
hook_point = "blocks.1.hook_mlp_out",
hook_point_layer = 1,
hook_point = "blocks.0.hook_mlp_out",
hook_point_layer = 0,
d_in = 512,
dataset_path = "NeelNanda/c4-tokenized-2b",
is_dataset_tokenized=True,

# SAE Parameters
expansion_factor = 32, # determines the dimension of the SAE.
expansion_factor = 64, # determines the dimension of the SAE.

# Training Parameters
lr = 1e-4,
l1_coefficient = 1e-4,
l1_coefficient = 3e-4,
train_batch_size = 4096,
context_size = 128,

# Activation Store Parameters
n_batches_in_buffer = 16,
total_training_tokens = 5_000_00 * 1,
n_batches_in_buffer = 24,
total_training_tokens = 5_000_00 * 100, # 15 minutes on an A100
store_batch_size = 32,

# Resampling protocol
feature_sampling_method = None,
feature_sampling_window = 500,
feature_sampling_method = 'l2',
feature_sampling_window = 1000, # would fire ~5 times on 500 million tokens
feature_reinit_scale = 0.2,
dead_feature_threshold = 1e-8,

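For scale, a short sketch of what these updated test parameters imply, assuming the SAE width is expansion_factor * d_in (the usual convention; the runner's internals are not shown in this diff):

d_in = 512
expansion_factor = 64
train_batch_size = 4096
total_training_tokens = 5_000_00 * 100  # 50,000,000 tokens

d_sae = expansion_factor * d_in
n_steps = total_training_tokens // train_batch_size

print(f"SAE features: {d_sae}")       # 32768
print(f"optimizer steps: {n_steps}")  # ~12,200 at 4096 tokens per step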
