Merge branch 'main' into activations_store_tests
chanind committed Mar 20, 2024
2 parents cc9899c + e814054 commit 4896d0a
Showing 41 changed files with 1,911 additions and 1,418 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -50,8 +50,8 @@ jobs:
         run: poetry run flake8 .
       - name: black code formatting
         run: poetry run black . --check
-      - name: isort linting
-        run: poetry run isort . --check-only --diff
+      # - name: isort linting
+      #   run: poetry run isort . --check-only --diff
       - name: type checking
         run: poetry run pyright
       - name: Run Unit Tests
3 changes: 0 additions & 3 deletions .gitmodules

This file was deleted.

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -7,7 +7,7 @@ repos:
       - id: check-added-large-files
         args: [--maxkb=250000]
   - repo: /~https://github.com/psf/black
-    rev: 24.2.0
+    rev: 24.3.0
     hooks:
      - id: black
   - repo: /~https://github.com/PyCQA/flake8
2 changes: 1 addition & 1 deletion README.md
@@ -164,7 +164,7 @@ I wrote a tutorial to show users how to do some basic exploration of their SAE.
 
 WandB Dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.
 
-![screenshot](content/dashboard_screenshot.pngdashboard_screenshot.png)
+![screenshot](content/dashboard_screenshot.png)
 
 
 ## Example Output
2 changes: 2 additions & 0 deletions makefile
@@ -7,6 +7,8 @@ check-format:
 	poetry run black --check .
 	poetry run isort --check-only --diff .
 
+check-type:
+	poetry run pyright .
 
 test:
 	make unit-test
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -19,6 +19,8 @@ matplotlib = "^3.8.3"
 matplotlib-inline = "^0.1.6"
 eindex = {git = "/~https://github.com/callummcdougall/eindex.git"}
 datasets = "^2.17.1"
+babe = "^0.0.7"
+nltk = "^3.8.1"
 
 
 [tool.poetry.group.dev.dependencies]
2 changes: 1 addition & 1 deletion sae_analysis/dashboard_runner.py
@@ -18,10 +18,10 @@
 import plotly
 import plotly.express as px
 import torch
+import wandb
 from torch.nn.functional import cosine_similarity
 from tqdm import tqdm
 
-import wandb
 from sae_analysis.visualizer.data_fns import get_feature_data
 from sae_training.utils import LMSparseAutoencoderSessionloader
103 changes: 103 additions & 0 deletions sae_analysis/feature_statistics.py
@@ -0,0 +1,103 @@
import pandas as pd
import torch
from tqdm import tqdm
from transformer_lens import HookedTransformer

from sae_training.sparse_autoencoder import SparseAutoencoder


@torch.no_grad()
def get_feature_property_df(
    sparse_autoencoder: SparseAutoencoder, feature_sparsity: torch.Tensor
):
    """
    feature_property_df = get_feature_property_df(sparse_autoencoder, log_feature_density.cpu())
    """

    W_dec_normalized = (
        sparse_autoencoder.W_dec.cpu()
    )  # / sparse_autoencoder.W_dec.cpu().norm(dim=-1, keepdim=True)
    W_enc_normalized = (
        sparse_autoencoder.W_enc.cpu()
        / sparse_autoencoder.W_enc.cpu().norm(dim=-1, keepdim=True)
    ).T

    d_e_projection = (W_dec_normalized * W_enc_normalized).sum(-1)
    b_dec_projection = sparse_autoencoder.b_dec.cpu() @ W_dec_normalized.T

    temp_df = pd.DataFrame(
        {
            "log_feature_sparsity": feature_sparsity + 1e-10,
            "d_e_projection": d_e_projection,
            # "d_e_projection_normalized": d_e_projection_normalized,
            "b_enc": sparse_autoencoder.b_enc.detach().cpu(),
            "b_dec_projection": b_dec_projection,
            "feature": list(range(sparse_autoencoder.cfg.d_sae)),  # type: ignore
            "dead_neuron": (feature_sparsity < -9).cpu(),
        }
    )

    return temp_df


@torch.no_grad()
def get_stats_df(projection: torch.Tensor):
    """
    Returns a dataframe with the mean, std, skewness and kurtosis of the projection
    """
    mean = projection.mean(dim=1, keepdim=True)
    diffs = projection - mean
    var = (diffs**2).mean(dim=1, keepdim=True)
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0), dim=1)
    kurtosis = torch.mean(torch.pow(zscores, 4.0), dim=1)

    stats_df = pd.DataFrame(
        {
            "feature": range(len(skews)),
            "mean": mean.numpy().squeeze(),
            "std": std.numpy().squeeze(),
            "skewness": skews.numpy(),
            "kurtosis": kurtosis.numpy(),
        }
    )

    return stats_df


@torch.no_grad()
def get_all_stats_dfs(
    gpt2_small_sparse_autoencoders: dict[str, SparseAutoencoder],  # [hook_point, sae]
    gpt2_small_sae_sparsities: dict[str, torch.Tensor],  # [hook_point, sparsity]
    model: HookedTransformer,
    cosine_sim: bool = False,
):
    stats_dfs = []
    pbar = tqdm(gpt2_small_sparse_autoencoders.keys())
    for key in pbar:
        layer = int(key.split(".")[1])
        sparse_autoencoder = gpt2_small_sparse_autoencoders[key]
        pbar.set_description(f"Processing layer {sparse_autoencoder.cfg.hook_point}")
        W_U_stats_df_dec, _ = get_W_U_W_dec_stats_df(
            sparse_autoencoder.W_dec.cpu(), model, cosine_sim
        )
        log_feature_sparsity = gpt2_small_sae_sparsities[key].detach().cpu()
        W_U_stats_df_dec["log_feature_sparsity"] = log_feature_sparsity
        W_U_stats_df_dec["layer"] = layer + (1 if "post" in key else 0)
        stats_dfs.append(W_U_stats_df_dec)

    W_U_stats_df_dec_all_layers = pd.concat(stats_dfs, axis=0)
    return W_U_stats_df_dec_all_layers


@torch.no_grad()
def get_W_U_W_dec_stats_df(
    W_dec: torch.Tensor, model: HookedTransformer, cosine_sim: bool = False
) -> tuple[pd.DataFrame, torch.Tensor]:
    W_U = model.W_U.detach().cpu()
    if cosine_sim:
        W_U = W_U / W_U.norm(dim=0, keepdim=True)
    dec_projection_onto_W_U = W_dec @ W_U
    W_U_stats_df = get_stats_df(dec_projection_onto_W_U)
    return W_U_stats_df, dec_projection_onto_W_U
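
For orientation, a minimal usage sketch (not part of this commit) showing how these helpers compose with the `get_all_gpt2_small_saes` loader added in `sae_analysis/toolkit.py` below; the model name and hook-point string are illustrative. Note that `get_stats_df` computes non-excess kurtosis, so a Gaussian projection scores roughly 3.

```python
from transformer_lens import HookedTransformer

from sae_analysis.feature_statistics import get_all_stats_dfs, get_feature_property_df
from sae_analysis.toolkit import get_all_gpt2_small_saes

model = HookedTransformer.from_pretrained("gpt2")
saes, log_sparsities = get_all_gpt2_small_saes()

# Per-feature summary for a single SAE (layer-0 resid_pre as an example).
hook_point = "blocks.0.hook_resid_pre"
feature_df = get_feature_property_df(saes[hook_point], log_sparsities[hook_point].cpu())

# Mean/std/skewness/kurtosis of each feature's decoder projection onto W_U.
all_stats = get_all_stats_dfs(saes, log_sparsities, model, cosine_sim=True)
print(all_stats.groupby("layer")[["skewness", "kurtosis"]].mean())
```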
51 changes: 51 additions & 0 deletions sae_analysis/toolkit.py
@@ -0,0 +1,51 @@
import webbrowser

import torch
from huggingface_hub import hf_hub_download

from sae_training.sparse_autoencoder import SparseAutoencoder


def get_all_gpt2_small_saes() -> (
    tuple[dict[str, SparseAutoencoder], dict[str, torch.Tensor]]
):

    REPO_ID = "jbloom/GPT2-Small-SAEs"
    gpt2_small_sparse_autoencoders = {}
    gpt2_small_saes_log_feature_sparsities = {}
    for layer in range(12):
        FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
        path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
        sae = SparseAutoencoder.load_from_pretrained(f"{path}")
        sae.cfg.use_ghost_grads = False
        gpt2_small_sparse_autoencoders[sae.cfg.hook_point] = sae

        FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576_log_feature_sparsity.pt"
        path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
        log_feature_density = torch.load(path, map_location=sae.cfg.device)
        gpt2_small_saes_log_feature_sparsities[sae.cfg.hook_point] = log_feature_density

    # get the final one
    layer = 11
    FILENAME = (
        f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_post_24576.pt"
    )
    path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
    sae = SparseAutoencoder.load_from_pretrained(f"{path}")
    sae.cfg.use_ghost_grads = False
    gpt2_small_sparse_autoencoders[sae.cfg.hook_point] = sae

    FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_post_24576_log_feature_sparsity.pt"
    path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
    log_feature_density = torch.load(path, map_location=sae.cfg.device)
    gpt2_small_saes_log_feature_sparsities[sae.cfg.hook_point] = log_feature_density

    return gpt2_small_sparse_autoencoders, gpt2_small_saes_log_feature_sparsities


def open_neuronpedia(feature_id: int, layer: int = 0):

    path_to_html = f"https://www.neuronpedia.org/gpt2-small/{layer}-res-jb/{feature_id}"

    print(f"Feature {feature_id}")
    webbrowser.open_new_tab(path_to_html)
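
Likewise, a short sketch of the toolkit in use (illustrative; the `feature_id` and `layer` values are arbitrary, and the download requires access to the Hugging Face Hub):

```python
from sae_analysis.toolkit import get_all_gpt2_small_saes, open_neuronpedia

# Downloads all 13 residual-stream SAEs (12 resid_pre + the final resid_post).
saes, log_sparsities = get_all_gpt2_small_saes()
print(sorted(saes.keys()))  # blocks.0.hook_resid_pre ... blocks.11.hook_resid_post

# Open the Neuronpedia dashboard for one feature in the default browser.
open_neuronpedia(feature_id=1234, layer=3)
```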