From 836298a897e759e1f99a35c7d8195bcfb580afe3 Mon Sep 17 00:00:00 2001
From: jbloom-md
Date: Thu, 30 Nov 2023 13:57:01 +0000
Subject: [PATCH] happy with hyperpars on benchmark

---
 .pylintrc                          |  2 +-
 sae_training/SAE.py                | 66 ++++++++++++++-------
 sae_training/toy_model_runner.py   | 15 +++--
 sae_training/train_sae.py          | 93 +++++++++++++++++++++---------
 tests/benchmark/test_sae_runner.py | 28 +++++----
 5 files changed, 141 insertions(+), 63 deletions(-)

diff --git a/.pylintrc b/.pylintrc
index 89f2a294..347171ef 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -16,4 +16,4 @@ default-docstring-type = numpy
 max-line-length = 88
 
 [MESSAGES CONTROL]
-disable = C0330, C0326, C0199, C0411
\ No newline at end of file
+disable = C0330, C0326, C0199, C0411, C0103
\ No newline at end of file
diff --git a/sae_training/SAE.py b/sae_training/SAE.py
index 8b4292d3..69def7fe 100644
--- a/sae_training/SAE.py
+++ b/sae_training/SAE.py
@@ -7,7 +7,9 @@
 
 import einops
 import torch
-from torch import nn
+from jaxtyping import Float, Int
+from torch import Tensor, nn
+from torch.distributions.categorical import Categorical
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 
 
@@ -93,22 +95,46 @@ def forward(self, x, return_mode: Literal["sae_out", "hidden_post", "both"]="bot
         else:
             raise ValueError(f"Unexpected {return_mode=}")
 
-    def reinit_neurons(self, indices):
-        new_W_enc = torch.nn.init.kaiming_uniform_(
-            torch.empty(
-                self.d_in, indices.shape[0], dtype=self.dtype, device=self.device
-            )
-        ) * self.cfg["resample_factor"]
-        new_b_enc = torch.zeros(
-            indices.shape[0], dtype=self.dtype, device=self.device
-        )
-        new_W_dec = torch.nn.init.kaiming_uniform_(
-            torch.empty(
-                indices.shape[0], self.d_in, dtype=self.dtype, device=self.device
-            )
-        )
-        self.W_enc.data[:, indices] = new_W_enc
-        self.b_enc.data[indices] = new_b_enc
-        self.W_dec.data[indices, :] = new_W_dec
-        self.W_dec /= torch.norm(self.W_dec, dim=1, keepdim=True)
-
+    @torch.no_grad()
+    def resample_neurons(
+        self,
+        x: Float[Tensor, "batch_size n_hidden"],
+        frac_active_in_window: Float[Tensor, "window n_hidden_ae"],
+        neuron_resample_scale: float,
+    ) -> None:
+        '''
+        Resamples neurons that have been dead for `dead_neuron_window` steps, according to `frac_active`.
+        '''
+        sae_out = self.forward(x, return_mode="sae_out")
+        per_token_l2_loss = (sae_out - x).pow(2).sum(dim=-1).squeeze()
+
+        # Find the dead neurons in this instance. If all neurons are alive, continue
+        is_dead = (frac_active_in_window.sum(0) < 1e-8)
+        dead_neurons = torch.nonzero(is_dead).squeeze(-1)
+        alive_neurons = torch.nonzero(~is_dead).squeeze(-1)
+        n_dead = dead_neurons.numel()
+
+        if n_dead == 0:
+            return  # If there are no dead neurons, we don't need to resample neurons
+
+        # Compute L2 loss for each element in the batch
+        # TODO: Check whether we need to go through more batches as features get sparse to find high l2 loss examples.
+        if per_token_l2_loss.max() < 1e-6:
+            return  # If we have zero reconstruction loss, we don't need to resample neurons
+
+        # Draw `n_hidden_ae` samples from [0, 1, ..., batch_size-1], with probabilities proportional to l2_loss
+        distn = Categorical(probs = per_token_l2_loss / per_token_l2_loss.sum())
+        replacement_indices = distn.sample((n_dead,))  # shape [n_dead]
+
+        # Index into the batch of hidden activations to get our replacement values
+        replacement_values = (x - self.b_dec)[replacement_indices]  # shape [n_dead n_input_ae]
+
+        # Get the norm of alive neurons (or 1.0 if there are no alive neurons)
+        W_enc_norm_alive_mean = 1.0 if len(alive_neurons) == 0 else self.W_enc[:, alive_neurons].norm(dim=0).mean().item()
+
+        # Use this to renormalize the replacement values
+        replacement_values = (replacement_values / (replacement_values.norm(dim=1, keepdim=True) + 1e-8)) * W_enc_norm_alive_mean * neuron_resample_scale
+
+        # Lastly, set the new weights & biases
+        self.W_enc.data[:, dead_neurons] = replacement_values.T.squeeze(1)
+        self.b_enc.data[dead_neurons] = 0.0
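
Note (illustrative only, not part of the patch): a minimal sketch of how the new `resample_neurons` API is meant to be driven. It assumes an already-constructed `sae` instance (with a `d_sae` attribute, as used elsewhere in this patch) and a `[batch, d_in]` tensor of input activations `acts`; both names are hypothetical here.

    import torch

    # `sae` and `acts` are assumed to exist; this is not runnable on its own.
    window = 100
    # Pretend no latent fired at all during the last `window` steps, so every
    # latent counts as dead and gets resampled.
    frac_active_in_window = torch.zeros(window, sae.d_sae)
    sae.resample_neurons(acts, frac_active_in_window, neuron_resample_scale=0.2)
    # Dead encoder columns are re-initialised from high-reconstruction-loss inputs
    # (shifted by -b_dec, renormalised, scaled), and their encoder biases reset to zero.
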
diff --git a/sae_training/toy_model_runner.py b/sae_training/toy_model_runner.py
index 3526b681..c0af6a66 100644
--- a/sae_training/toy_model_runner.py
+++ b/sae_training/toy_model_runner.py
@@ -22,17 +22,21 @@ class SAEToyModelRunnerConfig:
     # Relu Model Training Parameters
     model_training_steps: int = 10_000
     # SAE Parameters
-    expansion_factor: int = 4
+    d_sae: int = 5
     # Training Parameters
     n_sae_training_tokens: int = 25_000
     l1_coefficient: float = 1e-3
     lr: float = 3e-4
-    train_batch_size: int = 32 # Shouldn't be as big as the batch size for language models
+    train_batch_size: int = 1024 # Shouldn't be as big as the batch size for language models
     train_epochs: int = 10
+    feature_sampling_window: int = 100
+    feature_reinit_scale: float = 0.2
+    dead_feature_threshold: float = 1e-8
     # WANDB
     log_to_wandb: bool = True
     wandb_project: str = "mats_sae_training_toy_model"
     wandb_entity: str = None
+    wandb_log_frequency: int = 50
     # Misc
     device: str = "cpu"
     seed: int = 42
@@ -43,8 +47,6 @@
 
     def __post_init__(self):
         self.d_in = self.n_hidden  # hidden for the ReLu model is the input for the SAE
-        self.d_sae = self.n_hidden * self.expansion_factor
-
 
 def toy_model_sae_runner(cfg):
     '''
@@ -83,12 +85,17 @@ def toy_model_sae_runner(cfg):
         wandb.init(project="sae-training-test", config=cfg)
 
     sae = train_sae(
+        model,  # need model so we can do evals for neuron resampling
         sae,
         hidden.detach().squeeze(),
         use_wandb=cfg.log_to_wandb,
         l1_coeff=cfg.l1_coefficient,
         batch_size=cfg.train_batch_size,
         n_epochs=cfg.train_epochs,
+        feature_sampling_window=cfg.feature_sampling_window,
+        feature_reinit_scale=cfg.feature_reinit_scale,
+        dead_feature_threshold=cfg.dead_feature_threshold,
+        wandb_log_frequency=cfg.wandb_log_frequency,
     )
 
     if cfg.log_to_wandb:
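
Note (illustrative only, not part of the patch): the toy runner no longer derives the SAE width from `expansion_factor` in `__post_init__`; `d_sae` is now set directly on the config alongside the new resampling fields. A hedged sketch of the new-style config:

    from sae_training.toy_model_runner import SAEToyModelRunnerConfig

    cfg = SAEToyModelRunnerConfig(
        n_features=5,
        n_hidden=2,
        d_sae=5,                      # was: n_hidden * expansion_factor
        train_batch_size=1024,
        feature_sampling_window=100,  # steps between dead-feature checks
        feature_reinit_scale=0.2,
        dead_feature_threshold=1e-8,
        log_to_wandb=False,
    )
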
diff --git a/sae_training/train_sae.py b/sae_training/train_sae.py
index fae97ee2..908f33e1 100644
--- a/sae_training/train_sae.py
+++ b/sae_training/train_sae.py
@@ -7,59 +7,98 @@
 import wandb
 
 from sae_training.activation_store import ActivationStore
 from sae_training.SAE import SAE
+from sae_training.toy_models import Model as ToyModel
 
 
 #%%
-def train_sae(sae: SAE,
+def train_sae(model: ToyModel,
+              sae: SAE,
               activation_store: ActivationStore,
               n_epochs: int = 10,
-              batch_size: int = 32,
+              batch_size: int = 1024,
               l1_coeff: float = 0.001,
+              feature_sampling_window: int = 100,  # how many training steps between resampling the features / considering neurons dead
+              feature_reinit_scale: float = 0.2,  # how much to scale the resampled features by
+              dead_feature_threshold: float = 1e-8,  # how infrequently a feature has to be active to be considered dead
               use_wandb: bool = False,
-              wandb_log_freq: int = 10,):
+              wandb_log_frequency: int = 50,):
     """
     Takes an SAE and a bunch of activations and does a bunch of training steps
     """
 
     dataloader = DataLoader(activation_store, batch_size=batch_size, shuffle=True)
 
-    optimizer = torch.optim.Adam(sae.parameters())
+    frac_active_list = []  # track active features
 
     sae.train()
 
     n_training_steps = 0
    for epoch in range(n_epochs):
         pbar = tqdm(dataloader)
         for step, batch in enumerate(pbar):
-            optimizer.zero_grad()
-            sae_out, hidden_post = sae(batch)
+
+            # Make sure the W_dec still has unit norm
+            sae.W_dec.data /= (torch.norm(sae.W_dec.data, dim=1, keepdim=True) + 1e-8)
+
+            # Resample dead neurons
+            if (feature_sampling_window is not None) and ((step + 1) % feature_sampling_window == 0):
+
+                # Get the fraction of neurons active in the previous window
+                frac_active_in_window = torch.stack(frac_active_list[-feature_sampling_window:], dim=0)
+
+                # Compute batch of hidden activations which we'll use in resampling
+                resampling_batch = model.generate_batch(batch_size)
+
+                # Our version of running the model
+                hidden = einops.einsum(
+                    resampling_batch,
+                    model.W,
+                    "batch_size instances features, instances hidden features -> batch_size instances hidden",
+                )
+
+                # Resample
+                sae.resample_neurons(hidden, frac_active_in_window, feature_reinit_scale)
+
+
+            # Update learning rate here if using scheduler.
+
+            # Forward and Backward Passes
+            optimizer.zero_grad()
+            sae_out, feature_acts = sae(batch)
 
             # loss = reconstruction MSE + L1 regularization
             mse_loss = ((sae_out - batch)**2).mean()
-            l1_loss = torch.abs(hidden_post).sum()
+            l1_loss = torch.abs(feature_acts).sum()
             loss = mse_loss + l1_coeff * l1_loss
 
             with torch.no_grad():
-                batch_size = batch.shape[0]
-                frac_feature_activation = (hidden_post > 0).float().mean(0)
-                log_frac_feature_activation = torch.log(frac_feature_activation + 1e-8)
-                n_dead_features = (frac_feature_activation > 0).sum()
-                l0 = ((hidden_post != 0) / batch_size).sum()
-                l2_norm = torch.norm(hidden_post, dim=1).mean()
+                # Calculate the sparsities, and add it to a list
+                frac_active = einops.reduce(
+                    (feature_acts.abs() > dead_feature_threshold).float(),
+                    "batch_size hidden_ae -> hidden_ae", "mean")
+                frac_active_list.append(frac_active)
+                batch_size = batch.shape[0]
+                log_frac_feature_activation = torch.log(frac_active + 1e-8)
+                n_dead_features = (frac_active < dead_feature_threshold).sum()
 
-            if use_wandb and (step % wandb_log_freq == 0):
-                wandb.log({
-                    "losses/mse_loss": mse_loss.item(),
-                    "losses/l1_loss": l1_loss.item(),
-                    "losses/overall_loss": loss.item(),
-                    "metrics/l0": l0.item(),
-                    "metrics/l2": l2_norm.item(),
-                    "metrics/feature_density_histogram": wandb.Histogram(log_frac_feature_activation.tolist()),
-                    "metrics/n_dead_features": n_dead_features,
-                }, step=n_training_steps)
+                l0 = (feature_acts > 0).float().mean()
+                l2_norm = torch.norm(feature_acts, dim=1).mean()
 
-            pbar.set_description(f"{epoch}/{step}| MSE Loss {mse_loss.item():.3f} | L0 {l0.item():.3f} | n_dead_features {n_dead_features}")
+
+            if use_wandb and ((step + 1) % wandb_log_frequency == 0):
+                wandb.log({
+                    "losses/mse_loss": mse_loss.item(),
+                    "losses/l1_loss": batch_size*l1_loss.item(),
+                    "losses/overall_loss": loss.item(),
+                    "metrics/l0": l0.item(),
+                    "metrics/l2": l2_norm.item(),
+                    # "metrics/feature_density_histogram": wandb.Histogram(log_frac_feature_activation.tolist()),
+                    "metrics/n_dead_features": n_dead_features,
+                    "metrics/n_alive_features": sae.d_sae - n_dead_features,
+                }, step=n_training_steps)
+
+                pbar.set_description(f"{epoch}/{step}| MSE Loss {mse_loss.item():.3f} | L0 {l0.item():.3f} | n_dead_features {n_dead_features}")
 
             loss.backward()
 
@@ -81,10 +120,8 @@ def train_sae(sae: SAE,
 
             optimizer.step()
 
-            # Make sure the W_dec is still zero-norm
-            with torch.no_grad():
-                sae.W_dec.data /= (torch.norm(sae.W_dec.data, dim=1, keepdim=True) + 1e-8)
-
+
+
             n_training_steps += 1
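
Note (illustrative only, not part of the patch): a self-contained sketch of the dead-feature bookkeeping used above, with plain torch standing in for the SAE and the einops call; all names and shapes here are made up for the example.

    import torch

    feature_sampling_window = 100
    dead_feature_threshold = 1e-8
    frac_active_list = []
    for step in range(300):
        feature_acts = torch.relu(torch.randn(1024, 16))  # stand-in for SAE feature activations
        # fraction of the batch on which each latent fired
        frac_active = (feature_acts.abs() > dead_feature_threshold).float().mean(0)
        frac_active_list.append(frac_active)
        if (step + 1) % feature_sampling_window == 0:
            window = torch.stack(frac_active_list[-feature_sampling_window:], dim=0)
            dead = window.sum(0) < 1e-8  # latents silent for the whole window
            print(f"step {step + 1}: {int(dead.sum())} dead latents")
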
diff --git a/tests/benchmark/test_sae_runner.py b/tests/benchmark/test_sae_runner.py
index afc06e3e..500b78b3 100644
--- a/tests/benchmark/test_sae_runner.py
+++ b/tests/benchmark/test_sae_runner.py
@@ -8,18 +8,26 @@
 
 
 def test_toy_model_sae_runner():
-
     cfg = SAEToyModelRunnerConfig(
-        n_features = 5,
-        n_hidden = 2,
-        n_correlated_pairs = 0,
-        n_anticorrelated_pairs = 0,
-        feature_probability = 0.025,
-        model_training_steps = 10_000,
-        n_sae_training_tokens = 50_000,
-        log_to_wandb = True,
+        n_features=5,
+        n_hidden=2,
+        n_correlated_pairs=0,
+        n_anticorrelated_pairs=0,
+        feature_probability=0.025,
+        # SAE Parameters
+        d_sae=5,
+        l1_coefficient=0.005,
+        # SAE Train Config
+        train_batch_size=1024,
+        feature_sampling_window=3_000,
+        feature_reinit_scale=0.5,
+        model_training_steps=10_000,
+        n_sae_training_tokens=1024*10_000,
+        train_epochs=1,
+        log_to_wandb=False,
+        wandb_log_frequency=5,
     )
 
     trained_sae = toy_model_sae_runner(cfg)
 
-    assert trained_sae is not None
\ No newline at end of file
+    assert trained_sae is not None
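
Note (illustrative only, not part of the patch): assuming the file layout used in this patch, the benchmark above can presumably be reproduced outside pytest by calling the runner directly with the same hyperparameters:

    from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner

    cfg = SAEToyModelRunnerConfig(
        n_features=5,
        n_hidden=2,
        n_correlated_pairs=0,
        n_anticorrelated_pairs=0,
        feature_probability=0.025,
        d_sae=5,
        l1_coefficient=0.005,
        train_batch_size=1024,
        feature_sampling_window=3_000,
        feature_reinit_scale=0.5,
        model_training_steps=10_000,
        n_sae_training_tokens=1024 * 10_000,
        train_epochs=1,
        log_to_wandb=False,
    )
    trained_sae = toy_model_sae_runner(cfg)
    assert trained_sae is not None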