
Commit

add-geometric-mean-b_dec-init
jbloom-md committed Jan 10, 2024
1 parent 4c7f6f2 commit d5853f8
Showing 2 changed files with 49 additions and 21 deletions.
27 changes: 27 additions & 0 deletions sae_training/sparse_autoencoder.py
@@ -11,6 +11,7 @@
import einops
import torch
import torch.nn.functional as F
from geom_median.torch import compute_geometric_median
from jaxtyping import Float
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
@@ -104,6 +105,32 @@ def forward(self, x):

return sae_out, feature_acts, loss, mse_loss, l1_loss

@torch.no_grad()
def initialize_b_dec_with_geometric_median(self, activation_store):

previous_b_dec = self.b_dec.clone()

activations_list = []
for _ in range(50):  # collect a fixed sample of activation batches
activations = activation_store.next_batch()
activations_list.append(activations)

all_activations = torch.concat(activations_list, dim=0)
out = compute_geometric_median(
all_activations.detach().cpu(),
skip_typechecks=True,
maxiter=100_000, per_component=True).median
out = torch.tensor(out, dtype=self.dtype, device=self.device)

previous_distances = torch.norm(all_activations - previous_b_dec.to(self.device), dim=-1)
distances = torch.norm(all_activations - out.to(self.device), dim=-1)

print("Reinitializing b_dec with the geometric median of activations")
print(f"Previous mean distance: {previous_distances.mean().item()}")
print(f"New mean distance: {distances.mean().item()}")

self.b_dec.data = out

@torch.no_grad()
def resample_neurons_l2(
self,
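For context, a minimal standalone sketch of the initialization flow added above. It is not part of the commit: it assumes an SAE object exposing b_dec, dtype, and device, and an activation store whose next_batch() returns a [batch, d_in] tensor; the compute_geometric_median options mirror those used in the diff (50 batches, per_component=True, maxiter=100_000).

import torch
from geom_median.torch import compute_geometric_median

@torch.no_grad()
def init_b_dec_with_geometric_median(sae, activation_store, n_batches=50):
    # Sample a fixed number of activation batches to estimate a robust center of the data.
    acts = torch.concat(
        [activation_store.next_batch() for _ in range(n_batches)], dim=0
    )
    # Options mirror the commit above; .median holds the estimated center.
    median = compute_geometric_median(
        acts.detach().cpu(), skip_typechecks=True, maxiter=100_000, per_component=True
    ).median
    # Write the result into the SAE's decoder bias on its own device and dtype.
    sae.b_dec.data = torch.tensor(median, dtype=sae.dtype, device=sae.device)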
43 changes: 22 additions & 21 deletions sae_training/train_sae_on_language_model.py
@@ -52,6 +52,7 @@ def train_sae_on_language_model(
training_steps=total_training_steps,
lr_end=sparse_autoencoder.cfg.lr / 10, # heuristic for now.
)
sparse_autoencoder.initialize_b_dec_with_geometric_median(activation_store)
sparse_autoencoder.train()


@@ -68,10 +69,10 @@
feature_sparsity = act_freq_scores / n_frac_active_tokens

# if reset criterion is frequency in window, then use that to generate indices.
dead_neuron_indices = (feature_sparsity < sparse_autoencoder.cfg.dead_feature_threshold).nonzero(as_tuple=False)[:, 0]
# dead_neuron_indices = (feature_sparsity < sparse_autoencoder.cfg.dead_feature_threshold).nonzero(as_tuple=False)[:, 0]

# if reset criterion is has_fired, then use that to generate indices.
# dead_neuron_indices = (act_freq_scores == 0).nonzero(as_tuple=False)[:, 0]
dead_neuron_indices = (act_freq_scores == 0).nonzero(as_tuple=False)[:, 0]

if len(dead_neuron_indices) > 0:
sparse_autoencoder.resample_neurons_anthropic(
@@ -92,10 +93,10 @@

# for now, we'll hardcode this.
current_lr = scheduler.get_last_lr()[0]
reduced_lr = current_lr * 0.1
increment = (current_lr - reduced_lr) / 1000
reduced_lr = current_lr / 10_000
increment = (current_lr - reduced_lr) / 10_000
optimizer.param_groups[0]['lr'] = reduced_lr
steps_before_reset = 1000
steps_before_reset = 10_000

# Resample dead neurons
if (feature_sampling_method == "l2") and ((n_training_steps + 1) % dead_feature_window == 0):
@@ -106,7 +107,7 @@

# # if standard resampling <- do this
# n_resampled_neurons = sparse_autoencoder.resample_neurons(
# activation_store.next_batch(),
#     activation_store.next_batch(),
# feature_sparsity,
# feature_reinit_scale,
# optimizer
@@ -210,23 +211,23 @@ def train_sae_on_language_model(
if use_wandb and ((n_training_steps + 1) % (wandb_log_frequency * 10) == 0):
run_evals(sparse_autoencoder, activation_store, model, n_training_steps)

log_feature_sparsity = torch.log10(feature_sparsity + 1e-10).detach().cpu()
# log_feature_sparsity = torch.log10(feature_sparsity + 1e-10).detach().cpu()

# sparsity_line_chart = px.scatter(
# y = log_feature_sparsity,
# title="Feature Sparsity",
# labels={"y": "log10(sparsity)", "x": "FeatureID"},
# range_y=[-8, 0],
# marginal_y="histogram",
# # sparsity_line_chart = px.scatter(
# # y = log_feature_sparsity,
# # title="Feature Sparsity",
# # labels={"y": "log10(sparsity)", "x": "FeatureID"},
# # range_y=[-8, 0],
# # marginal_y="histogram",
# # )
# wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy())
# wandb.log(
# {
# "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
# "plots/feature_density_line_chart": wandb_histogram,
# },
# step=n_training_steps,
# )
wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy())
wandb.log(
{
"metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
"plots/feature_density_line_chart": wandb_histogram,
},
step=n_training_steps,
)



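The hunks above drop the learning rate to current_lr / 10_000 right after a resampling event and precompute an increment and steps_before_reset of 10_000, but the ramp-back code itself is outside this diff. A hypothetical sketch of how such a linear warm-up back to the pre-reset rate could behave (all names and the loop structure are assumptions, not the repository's code):

import torch

params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = torch.optim.Adam(params, lr=4e-4)

current_lr = optimizer.param_groups[0]["lr"]
reduced_lr = current_lr / 10_000
steps_before_reset = 10_000
increment = (current_lr - reduced_lr) / steps_before_reset

optimizer.param_groups[0]["lr"] = reduced_lr  # drop the LR immediately after resampling
for _ in range(steps_before_reset):
    # ... one training step would run here ...
    optimizer.param_groups[0]["lr"] += increment  # after 10_000 steps the LR is back at current_lr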

