Commit: Merge branch 'main' into sae_group_pr

Benw8888 authored Feb 28, 2024
2 parents d3cafa3 + ad84706 commit 082c813
Showing 37 changed files with 177 additions and 141 deletions.
7 changes: 4 additions & 3 deletions .flake8
@@ -1,7 +1,8 @@
 [flake8]
-ignore = E203, E266, E501, W503
+extend-ignore = E203, E266, E501, W503, E721, F722, E731
 max-line-length = 79
-max-complexity = 10
-select = E9, F63, F7, F82
+max-complexity = 25
+extend-select = E9, F63, F7, F82
 show-source = true
 statistics = true
+exclude = ./sae_training/geom_median/
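Two of the newly ignored codes map directly onto patterns visible elsewhere in this diff (my reading of the motivation; the commit doesn't state it): F722 ("syntax error in forward annotation") fires on jaxtyping's string shape annotations, and E731 fires on the lambda assignments in `html_fns.py`. Ignoring them globally is what lets this commit delete the per-line `# noqa` comments throughout `sae_analysis/`. Note also the switch from `ignore`/`select` to `extend-ignore`/`extend-select`, which appends to flake8's defaults rather than replacing them. A minimal sketch of the two patterns:

```python
# Both definitions below are clean under the new config; under the old one,
# each line needed a trailing "# noqa" to silence flake8.
from jaxtyping import Float
from torch import Tensor


def update(x: Float[Tensor, "X N"]) -> None:  # F722 without the ignore
    ...


fn = lambda m: str(m)  # E731 without the ignore
```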
18 changes: 8 additions & 10 deletions .github/workflows/tests.yml
@@ -42,18 +42,16 @@ jobs:
       uses: actions/setup-python@v4
       with:
         python-version: ${{ matrix.python-version }}
-        cache: 'pip'
+    - name: Install Poetry
+      uses: snok/install-poetry@v1
     - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install flake8 pytest
-        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+      run: poetry install --no-interaction
     - name: Lint with flake8
-      run: |
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+      run: poetry run flake8 .
+    - name: black code formatting
+      run: poetry run black . --check
+    - name: isort linting
+      run: poetry run isort . --check-only --diff
     - name: Run Unit Tests
       run: |
        make unit-test
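(A side note: the new CI steps mirror the `check-format` and `unit-test` targets in the makefile shown later in this diff, so the same checks can be reproduced locally with `make check-format` and `make unit-test`.)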
2 changes: 1 addition & 1 deletion .gitignore
@@ -99,7 +99,7 @@ ipython_config.py
 # This is especially recommended for binary packages to ensure reproducibility, and is more
 # commonly ignored for libraries.
 # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
+poetry.lock
 
 # pdm
 # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
8 changes: 3 additions & 5 deletions README.md
@@ -5,12 +5,10 @@ This codebase contains training scripts and analysis code for Sparse AutoEncoder
 
 ## Set Up
 
+This project uses [Poetry](https://python-poetry.org/) for dependency management. Ensure Poetry is installed, then to install the dependencies, run:
+
 ```
-conda create --name mats_sae_training python=3.11 -y
-conda activate mats_sae_training
-pip install -r requirements.txt
+poetry install
 ```
 
 ## Background
9 changes: 7 additions & 2 deletions makefile
@@ -1,14 +1,19 @@
 format:
	poetry run black .
	poetry run isort .
 
+check-format:
+	poetry run flake8 .
+	poetry run black --check .
+	poetry run isort --check-only --diff .
+
 
 test:
	make unit-test
	make acceptance-test
 
 unit-test:
-	pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/unit
+	poetry run pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/unit
 
 acceptance-test:
-	pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/acceptance
+	poetry run pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/acceptance
36 changes: 36 additions & 0 deletions pyproject.toml
@@ -0,0 +1,36 @@
+[tool.poetry]
+name = "mats_sae_training"
+version = "0.1.0"
+description = "Training Sparse Autoencoders (SAEs)"
+authors = ["Joseph Bloom"]
+readme = "README.md"
+packages = [{include = "sae_analysis"}, {include = "sae_training"}]
+
+[tool.poetry.dependencies]
+python = "^3.10"
+transformer-lens = "^1.14.0"
+transformers = "^4.38.1"
+jupyter = "^1.0.0"
+plotly = "^5.19.0"
+plotly-express = "^0.4.1"
+nbformat = "^5.9.2"
+ipykernel = "^6.29.2"
+matplotlib = "^3.8.3"
+matplotlib-inline = "^0.1.6"
+eindex = {git = "/~https://github.com/callummcdougall/eindex.git"}
+
+
+[tool.poetry.group.dev.dependencies]
+black = "^24.2.0"
+pytest = "^8.0.2"
+pytest-cov = "^4.1.0"
+pre-commit = "^3.6.2"
+flake8 = "^7.0.0"
+isort = "^5.13.2"
+
+[tool.isort]
+profile = "black"
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,7 +7,8 @@ nbformat==5.9.2
 ipykernel==6.27.1
 matplotlib==3.8.2
 matplotlib-inline==0.1.6
-pylint==3.0.2
+flake8==7.0.0
+isort==5.13.2
 black==23.11.0
 pytest==7.4.3
 pytest-cov==4.1.0
Binary file removed sae_analysis/.DS_Store
Binary file not shown.
9 changes: 5 additions & 4 deletions sae_analysis/dashboard_runner.py
@@ -1,3 +1,6 @@
+# flake8: noqa: E402
+# TODO: are these sys.path.append calls really necessary?
+
 import sys
 
 sys.path.append("..")
@@ -14,10 +17,10 @@
 import plotly
 import plotly.express as px
 import torch
+import wandb
 from torch.nn.functional import cosine_similarity
 from tqdm import tqdm
 
-import wandb
 from sae_analysis.visualizer.data_fns import get_feature_data
 from sae_training.utils import LMSparseAutoencoderSessionloader
@@ -148,9 +151,7 @@ def init_sae_session(self):
             self.activation_store,
         ) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
 
-    def get_tokens(
-        self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6
-    ):
+    def get_tokens(self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6):
        """
        Get the tokens needed for dashboard generation.
        """
31 changes: 13 additions & 18 deletions sae_analysis/visualizer/data_fns.py
@@ -1,13 +1,10 @@
 import gzip
 import json
 import os
 import pickle
 import time
 from collections import defaultdict
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
-from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Literal, Optional, Tuple, Union
 
 import einops
 import numpy as np
@@ -23,8 +20,6 @@
 from transformer_lens import HookedTransformer, utils
 from transformer_lens.hook_points import HookPoint
 
-Arr = np.ndarray
-
 from sae_analysis.visualizer.html_fns import (
     CSS,
     HTML_HOVERTEXT_SCRIPT,
@@ -42,6 +37,8 @@
     to_str_tokens,
 )
 
+Arr = np.ndarray
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -500,7 +497,7 @@ def __init__(self):
         self.x2_sum = 0
         self.y2_sum = 0
 
-    def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]):  # noqa
+    def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]):
         assert x.ndim == 2 and y.ndim == 2, "Both x and y should be 2D"
         assert (
             x.shape[-1] == y.shape[-1]
@@ -513,7 +510,7 @@ def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]):  # noqa
         self.x2_sum += einops.reduce(x**2, "X N -> X", "sum")
         self.y2_sum += einops.reduce(y**2, "Y N -> Y", "sum")
 
-    def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]:  # noqa
+    def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]:
         cossim_numer = self.xy_sum
         cossim_denom = torch.sqrt(torch.outer(self.x2_sum, self.y2_sum)) + 1e-6
         cossim = cossim_numer / cossim_denom
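For orientation, the class these hunks touch is a streaming accumulator: `update()` adds per-batch sufficient statistics, and `corrcoef()` only forms the cosine-similarity matrix at the end, so the full activation stream never has to be held in memory. A minimal standalone sketch of the same pattern (names and shapes assumed from the hunk, not the repo's exact class):

```python
import torch

# Sufficient statistics for X row-features vs Y row-features.
xy_sum = torch.zeros(4, 5)
x2_sum = torch.zeros(4)
y2_sum = torch.zeros(5)

for _ in range(3):  # a stream of batches, N=16 samples each
    x, y = torch.randn(4, 16), torch.randn(5, 16)
    xy_sum += x @ y.T  # sum over n of x[i, n] * y[j, n]
    x2_sum += (x**2).sum(dim=-1)
    y2_sum += (y**2).sum(dim=-1)

# Same formula as corrcoef() above: cosine similarity with an epsilon.
cossim = xy_sum / (torch.sqrt(torch.outer(x2_sum, y2_sum)) + 1e-6)
print(cossim.shape)  # torch.Size([4, 5])
```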
@@ -552,7 +549,7 @@ def get_feature_data(
     hook_point: str,
     hook_point_layer: int,
     hook_point_head_index: Optional[int],
-    tokens: Int[Tensor, "batch seq"],  # noqa
+    tokens: Int[Tensor, "batch seq"],
     feature_idx: Union[int, List[int]],
     max_batch_size: Optional[int] = None,
     left_hand_k: int = 3,
@@ -627,10 +624,8 @@
 
     # ! Define hook function to perform feature ablation
 
-    def hook_fn_act_post(
-        act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint  # noqa
-    ):  # noqa
-        """
+    def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint):
+        r"""
         Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where:
         - f_i are the feature activations
         - d_i are the feature output directions
@@ -666,10 +661,9 @@
     # )
 
     def hook_fn_query(
-        hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint  # noqa
+        hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint
     ):
-        """
+        r"""
         Replace act_post with projection of query onto the resid by W_k^T.
         Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where:
         - f_i are the feature activations
@@ -701,7 +695,7 @@ def hook_fn_query(
     )
 
     def hook_fn_resid_post(
-        resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint  # noqa
+        resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint
     ):
         """
         This hook function stores the residual activations, which we'll need later on to calculate the effect of feature ablation.
@@ -1023,7 +1017,8 @@ def hook_fn_resid_post(
                 save_obj, filename=filename, save_type=save_type
             )
             t1 = time.time()
-            loaded_obj = FeatureData.load_batch(
+            # TODO: is this doing anything? the result isn't read
+            FeatureData.load_batch(
                 filename, save_type=save_type, vocab_dict=vocab_dict
             )
             t2 = time.time()
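A side note on the `"""` to `r"""` changes in the hook-function docstrings above (my inference; the commit doesn't say): those docstrings contain LaTeX-style sequences like `\approx` and `\sum_i`, and in a normal string `\a` is silently parsed as the BEL control character while `\s` triggers an invalid-escape warning (flake8's W605, and a SyntaxWarning on newer Pythons). The raw-string prefix keeps the backslashes literal:

```python
# Without the r-prefix the docstring text is silently corrupted:
print(len("\approx"))   # 6 -- "\a" collapsed into one BEL control character
print(len(r"\approx"))  # 7 -- raw string keeps the backslash plus "approx"
```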
16 changes: 8 additions & 8 deletions sae_analysis/visualizer/html_fns.py
@@ -82,20 +82,20 @@ def generate_tok_html(
 
     # Make all the substitutions
     html_output = re.sub(
-        "pos_str_(\d)",
+        r"pos_str_(\d)",
         lambda m: pos_str[int(m.group(1))].replace(" ", "&nbsp;"),
         html_output,
     )
     html_output = re.sub(
-        "neg_str_(\d)",
+        r"neg_str_(\d)",
         lambda m: neg_str[int(m.group(1))].replace(" ", "&nbsp;"),
         html_output,
     )
     html_output = re.sub(
-        "pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output
+        r"pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output
     )
     html_output = re.sub(
-        "neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output
+        r"neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output
     )
 
     # If the effect on loss is nothing (because feature isn't active), replace the HTML output with smth saying this
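The changes in this file are all one fix (as the hunks suggest): the regex patterns were written as plain strings, where `\d` is an invalid escape sequence that newer Python versions warn about (and flake8 flags as W605); the `r` prefix makes them raw strings so the backslash reaches `re` intact. A quick illustration:

```python
import re

# r"\d" passes the two characters backslash-d to the regex engine, which is
# what these substitutions rely on; a plain "\d" only works by accident and
# raises a SyntaxWarning on recent Python versions.
html = "pos_str_0 pos_val_1"
out = re.sub(r"pos_str_(\d)", lambda m: f"token{m.group(1)}", html)
print(out)  # token0 pos_val_1
```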
@@ -230,12 +230,12 @@ def generate_tables_html(
         ],
         [None, "+.2f", ".1%", None, "+.2f", "+.2f", None, "+.2f", "+.2f"],
     ):
-        fn = (
-            lambda m: str(mylist[int(m.group(1))])
+        fn = lambda m: (
+            str(mylist[int(m.group(1))])
             if myformat is None
             else format(mylist[int(m.group(1))], myformat)
         )
-        html_output = re.sub(letter + "(\d)", fn, html_output, count=3)
+        html_output = re.sub(letter + r"(\d)", fn, html_output, count=3)
 
     html_output_2 = HTML_LOGIT_TABLES
 
@@ -258,7 +258,7 @@
             fn = lambda m: format(mylist[int(m.group(1))], "+.2f")
         elif letter == "C":
             fn = lambda m: str(mylist[int(m.group(1))])
-        html_output_2 = re.sub(letter + "(\d)", fn, html_output_2, count=10)
+        html_output_2 = re.sub(letter + r"(\d)", fn, html_output_2, count=10)
 
     return (html_output, html_output_2)
9 changes: 4 additions & 5 deletions sae_analysis/visualizer/model_fns.py
@@ -1,11 +1,10 @@
-from transformer_lens import utils
-import torch
-import pprint
-import torch.nn as nn
-import torch.nn.functional as F
-import tqdm.notebook as tqdm
 from dataclasses import dataclass
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformer_lens import utils
 
 DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
13 changes: 5 additions & 8 deletions sae_analysis/visualizer/utils_fns.py
@@ -1,23 +1,20 @@
 import re
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import einops
 import numpy as np
 import torch
 from eindex import eindex
 from jaxtyping import Float, Int
 from torch import Tensor
-from transformer_lens import HookedTransformer
 
 Arr = np.ndarray
 
 
 def k_largest_indices(
-    x: Float[Tensor, "rows cols"],  # noqa
+    x: Float[Tensor, "rows cols"],
     k: int,
     largest: bool = True,
     buffer: Tuple[int, int] = (5, 5),
-) -> Int[Tensor, "k 2"]:  # noqa
+) -> Int[Tensor, "k 2"]:
     """w
     Given a 2D array, returns the indices of the top or bottom `k` elements.
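From the signature alone, a rough sketch of what such a helper plausibly does (assumed behavior, not the repo's implementation): take the top-`k` values of a 2D tensor while excluding a `buffer` of columns at each edge, and return `[k, 2]` row/column indices:

```python
import torch
from torch import Tensor


def k_largest_indices_sketch(
    x: Tensor, k: int, largest: bool = True, buffer: tuple[int, int] = (5, 5)
) -> Tensor:
    # Ignore the first buffer[0] and last buffer[1] columns of each row.
    inner = x[:, buffer[0] : x.shape[1] - buffer[1]]
    indices = inner.flatten().topk(k, largest=largest).indices
    rows = indices // inner.shape[1]
    cols = indices % inner.shape[1] + buffer[0]  # shift back to original columns
    return torch.stack((rows, cols), dim=-1)  # shape [k, 2]


print(k_largest_indices_sketch(torch.randn(8, 64), k=3))
```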
@@ -40,11 +37,11 @@ def sample_unique_indices(large_number, small_number):
 
 
 def random_range_indices(
-    x: Float[Tensor, "batch seq"],  # noqa
+    x: Float[Tensor, "batch seq"],
     bounds: Tuple[float, float],
     k: int,
     buffer: Tuple[int, int] = (5, 5),
-) -> Int[Tensor, "k 2"]:  # noqa
+) -> Int[Tensor, "k 2"]:
     """
     Given a 2D array, returns the indices of `k` elements whose values are in the range `bounds`.
     Will return fewer than `k` values if there aren't enough values in the range.
2 changes: 1 addition & 1 deletion sae_training/activations_store.py
@@ -3,7 +3,6 @@
 import torch
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from tqdm import tqdm
 from transformer_lens import HookedTransformer
@@ -218,6 +217,7 @@ def get_buffer(self, n_batches_in_buffer):
 
             new_buffer[n_tokens_filled : n_tokens_filled + activations.shape[0], ...] = activations
 
+
             if taking_subset_of_file:
                 self.next_idx_within_buffer = activations.shape[0]
             else: