Skip to content

Commit

Permalink
minor refactoring to SAE and adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Mar 28, 2024
1 parent 277f35b commit 92a98dd
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 237 deletions.
163 changes: 62 additions & 101 deletions sae_training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
import gzip
import os
import pickle
from typing import Any, NamedTuple
from typing import NamedTuple

import einops
import torch
from torch import nn
from transformer_lens.hook_points import HookedRootModule, HookPoint

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.geometric_median import compute_geometric_median


class ForwardOutput(NamedTuple):
Expand Down Expand Up @@ -64,7 +63,7 @@ def __init__(

with torch.no_grad():
# Anthropic normalize this to have unit columns
self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
self.set_decoder_norm_to_unit_norm()

self.b_dec = nn.Parameter(
torch.zeros(self.d_in, dtype=self.dtype, device=self.device)
Expand Down Expand Up @@ -104,96 +103,42 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)
)

# add config for whether l2 is normalized:
x_centred = x - x.mean(dim=0, keepdim=True)
mse_loss = (
torch.pow((sae_out - x.float()), 2)
/ (x_centred**2).sum(dim=-1, keepdim=True).sqrt()
)

mse_loss_ghost_resid = torch.tensor(0.0, dtype=self.dtype, device=self.device)
per_item_mse_loss = _per_item_mse_loss_with_target_norm(sae_out, x)
ghost_grad_loss = torch.tensor(0.0, dtype=self.dtype, device=self.device)
# gate on config and training so evals is not slowed down.
if (
self.cfg.use_ghost_grads
and self.training
and dead_neuron_mask is not None
and dead_neuron_mask.sum() > 0
):
# ghost protocol

# 1.
residual = x - sae_out
residual_centred = residual - residual.mean(dim=0, keepdim=True)
l2_norm_residual = torch.norm(residual, dim=-1)

# 2.
feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_neuron_mask])
ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
ghost_out = ghost_out * norm_scaling_factor[:, None].detach()

# 3.
mse_loss_ghost_resid = (
torch.pow((ghost_out - residual.detach().float()), 2)
/ (residual_centred.detach() ** 2).sum(dim=-1, keepdim=True).sqrt()
ghost_grad_loss = self.calculate_ghost_grad_loss(
x=x,
sae_out=sae_out,
per_item_mse_loss=per_item_mse_loss,
hidden_pre=hidden_pre,
dead_neuron_mask=dead_neuron_mask,
)
mse_rescaling_factor = (mse_loss / (mse_loss_ghost_resid + 1e-6)).detach()
mse_loss_ghost_resid = mse_rescaling_factor * mse_loss_ghost_resid

mse_loss_ghost_resid = mse_loss_ghost_resid.mean()

mse_loss = mse_loss.mean()
mse_loss = per_item_mse_loss.mean()
sparsity = feature_acts.norm(p=self.lp_norm, dim=1).mean(dim=(0,))
l1_loss = self.l1_coefficient * sparsity
loss = mse_loss + l1_loss + mse_loss_ghost_resid
loss = mse_loss + l1_loss + ghost_grad_loss

return ForwardOutput(
sae_out=sae_out,
feature_acts=feature_acts,
loss=loss,
mse_loss=mse_loss,
l1_loss=l1_loss,
ghost_grad_loss=mse_loss_ghost_resid,
ghost_grad_loss=ghost_grad_loss,
)

@torch.no_grad()
def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
out = torch.tensor(origin, dtype=self.dtype, device=self.device)
self.b_dec.data = out

@torch.no_grad()
def initialize_b_dec(self, all_activations: torch.Tensor):
if self.cfg.b_dec_init_method == "geometric_median":
self.initialize_b_dec_with_geometric_median(all_activations)
elif self.cfg.b_dec_init_method == "mean":
self.initialize_b_dec_with_mean(all_activations)
elif self.cfg.b_dec_init_method == "zeros":
pass
else:
raise ValueError(
f"Unexpected b_dec_init_method: {self.cfg.b_dec_init_method}"
)

@torch.no_grad()
def initialize_b_dec_with_geometric_median(self, all_activations: torch.Tensor):
previous_b_dec = self.b_dec.clone().cpu()
out = compute_geometric_median(
all_activations,
maxiter=100,
).median

previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
distances = torch.norm(all_activations - out, dim=-1)

print("Reinitializing b_dec with geometric median of activations")
print(
f"Previous distances: {previous_distances.median(0).values.mean().item()}"
)
print(f"New distances: {distances.median(0).values.mean().item()}")

out = torch.tensor(out, dtype=self.dtype, device=self.device)
self.b_dec.data = out

@torch.no_grad()
def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
previous_b_dec = self.b_dec.clone().cpu()
Expand All @@ -210,37 +155,6 @@ def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):

self.b_dec.data = out.to(self.dtype).to(self.device)

@torch.no_grad()
def get_test_loss(self, batch_tokens: torch.Tensor, model: HookedRootModule):
"""
A method for running the model with the SAE activations in order to return the loss.
returns per token loss when activations are substituted in.
"""
head_index = self.cfg.hook_point_head_index

def standard_replacement_hook(activations: torch.Tensor, hook: Any):
activations = self.forward(activations)[0].to(activations.dtype)
return activations

def head_replacement_hook(activations: torch.Tensor, hook: Any):
new_actions = self.forward(activations[:, :, head_index])[0].to(
activations.dtype
)
activations[:, :, head_index] = new_actions
return activations

replacement_hook = (
standard_replacement_hook if head_index is None else head_replacement_hook
)

ce_loss_with_recons = model.run_with_hooks(
batch_tokens,
return_type="loss",
fwd_hooks=[(self.cfg.hook_point, replacement_hook)],
)

return ce_loss_with_recons

@torch.no_grad()
def set_decoder_norm_to_unit_norm(self):
self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
Expand All @@ -251,14 +165,13 @@ def remove_gradient_parallel_to_decoder_directions(self):
Update grads so that they remove the parallel component
(d_sae, d_in) shape
"""
assert self.W_dec.grad is not None # keep pyright happy

parallel_component = einops.einsum(
self.W_dec.grad,
self.W_dec.data,
"d_sae d_in, d_sae d_in -> d_sae",
)
assert parallel_component is not None # keep pyright happy

self.W_dec.grad -= einops.einsum(
parallel_component,
self.W_dec.data,
Expand All @@ -278,6 +191,9 @@ def save_model(self, path: str):

if path.endswith(".pt"):
torch.save(state_dict, path)
elif path.endswith(".pkl"):
with open(path, "wb") as f:
pickle.dump(state_dict, f)
elif path.endswith("pkl.gz"):
with gzip.open(path, "wb") as f:
pickle.dump(state_dict, f)
Expand Down Expand Up @@ -344,3 +260,48 @@ def load_from_pretrained(cls, path: str):
def get_name(self):
sae_name = f"sparse_autoencoder_{self.cfg.model_name}_{self.cfg.hook_point}_{self.cfg.d_sae}"
return sae_name

def calculate_ghost_grad_loss(
self,
x: torch.Tensor,
sae_out: torch.Tensor,
per_item_mse_loss: torch.Tensor,
hidden_pre: torch.Tensor,
dead_neuron_mask: torch.Tensor,
) -> torch.Tensor:
# 1.
residual = x - sae_out
l2_norm_residual = torch.norm(residual, dim=-1)

# 2.
feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_neuron_mask])
ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
ghost_out = ghost_out * norm_scaling_factor[:, None].detach()

# 3.
per_item_mse_loss_ghost_resid = _per_item_mse_loss_with_target_norm(
ghost_out, residual.detach()
)
mse_rescaling_factor = (
per_item_mse_loss / (per_item_mse_loss_ghost_resid + 1e-6)
).detach()
per_item_mse_loss_ghost_resid = (
mse_rescaling_factor * per_item_mse_loss_ghost_resid
)

return per_item_mse_loss_ghost_resid.mean()


def _per_item_mse_loss_with_target_norm(
preds: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
"""
Calculate MSE loss per item in the batch, without taking a mean.
Then, normalizes by the L2 norm of the centered target.
This normalization seems to improve performance.
"""
target_centered = target - target.mean(dim=0, keepdim=True)
normalization = target_centered.norm(dim=-1, keepdim=True)
return torch.nn.functional.mse_loss(preds, target, reduction="none") / normalization
Loading

0 comments on commit 92a98dd

Please sign in to comment.