diff --git a/.flake8 b/.flake8 index e7169d49..b613489b 100644 --- a/.flake8 +++ b/.flake8 @@ -5,4 +5,4 @@ max-complexity = 25 extend-select = E9, F63, F7, F82 show-source = true statistics = true -exclude = ./sae_training/geom_median/ +exclude = ./sae_training/geom_median/, ./wandb/*, ./research/wandb/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c45c4453..d56713f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,19 +7,19 @@ repos: - id: check-added-large-files args: [--maxkb=250000] - repo: /~https://github.com/psf/black - rev: 23.3.0 + rev: 24.2.0 hooks: - id: black - repo: /~https://github.com/PyCQA/flake8 rev: 6.0.0 hooks: - id: flake8 + args: ['--config=.flake8'] additional_dependencies: [ 'flake8-blind-except', - 'flake8-docstrings', + # 'flake8-docstrings', 'flake8-bugbear', 'flake8-comprehensions', - 'flake8-docstrings', 'flake8-implicit-str-concat', 'pydocstyle>=5.0.0', ] diff --git a/pyproject.toml b/pyproject.toml index f02784f8..e1c5caea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ ipykernel = "^6.29.2" matplotlib = "^3.8.3" matplotlib-inline = "^0.1.6" eindex = {git = "/~https://github.com/callummcdougall/eindex.git"} +datasets = "^2.17.1" [tool.poetry.group.dev.dependencies] @@ -33,4 +34,4 @@ profile = "black" [build-system] requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" \ No newline at end of file +build-backend = "poetry.core.masonry.api" diff --git a/sae_analysis/dashboard_runner.py b/sae_analysis/dashboard_runner.py index 4a3815ab..a12e46d3 100644 --- a/sae_analysis/dashboard_runner.py +++ b/sae_analysis/dashboard_runner.py @@ -17,10 +17,10 @@ import plotly import plotly.express as px import torch -import wandb from torch.nn.functional import cosine_similarity from tqdm import tqdm +import wandb from sae_analysis.visualizer.data_fns import get_feature_data from sae_training.utils import LMSparseAutoencoderSessionloader diff --git a/sae_training/config.py b/sae_training/config.py index 570c78bf..f417cac8 100644 --- a/sae_training/config.py +++ b/sae_training/config.py @@ -3,6 +3,7 @@ from typing import Optional import torch + import wandb @@ -21,9 +22,9 @@ class RunnerConfig(ABC): is_dataset_tokenized: bool = True context_size: int = 128 use_cached_activations: bool = False - cached_activations_path: Optional[ - str - ] = None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}" + cached_activations_path: Optional[str] = ( + None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}" + ) # SAE Parameters d_in: int = 512 @@ -61,7 +62,9 @@ class LanguageModelSAERunnerConfig(RunnerConfig): l1_coefficient: float = 1e-3 lp_norm: float = 1 lr: float = 3e-4 - lr_scheduler_name: str = "constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup + lr_scheduler_name: str = ( + "constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup + ) lr_warm_up_steps: int = 500 train_batch_size: int = 4096 diff --git a/sae_training/evals.py b/sae_training/evals.py index 3603a6c4..58ecd273 100644 --- a/sae_training/evals.py +++ b/sae_training/evals.py @@ -2,11 +2,11 @@ import pandas as pd import torch -import wandb from tqdm import tqdm from transformer_lens import HookedTransformer from transformer_lens.utils import get_act_name +import wandb from sae_training.activations_store import ActivationsStore from sae_training.sparse_autoencoder import SparseAutoencoder diff --git a/sae_training/sae_group.py b/sae_training/sae_group.py index 7197b682..61a0713d 100644 --- a/sae_training/sae_group.py +++ b/sae_training/sae_group.py @@ -1,10 +1,12 @@ +import dataclasses import gzip import os import pickle -import dataclasses -from sae_training.sparse_autoencoder import SparseAutoencoder + import torch +from sae_training.sparse_autoencoder import SparseAutoencoder + class SAEGroup: def __init__(self, cfg): diff --git a/sae_training/toy_model_runner.py b/sae_training/toy_model_runner.py index 67aa267a..9fc5c3db 100644 --- a/sae_training/toy_model_runner.py +++ b/sae_training/toy_model_runner.py @@ -2,8 +2,8 @@ import einops import torch -import wandb +import wandb from sae_training.sparse_autoencoder import SparseAutoencoder from sae_training.toy_models import Config as ToyConfig from sae_training.toy_models import Model as ToyModel diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py index 854aa805..a4056b2c 100644 --- a/sae_training/train_sae_on_language_model.py +++ b/sae_training/train_sae_on_language_model.py @@ -1,14 +1,14 @@ import torch -import wandb from torch.optim import Adam from tqdm import tqdm from transformer_lens import HookedTransformer +import wandb from sae_training.activations_store import ActivationsStore from sae_training.evals import run_evals +from sae_training.geom_median.src.geom_median.torch import compute_geometric_median from sae_training.optim import get_scheduler from sae_training.sae_group import SAEGroup -from sae_training.geom_median.src.geom_median.torch import compute_geometric_median def train_sae_on_language_model( diff --git a/sae_training/train_sae_on_toy_model.py b/sae_training/train_sae_on_toy_model.py index c43b75c4..f5008713 100644 --- a/sae_training/train_sae_on_toy_model.py +++ b/sae_training/train_sae_on_toy_model.py @@ -1,8 +1,8 @@ import torch -import wandb from torch.utils.data import DataLoader from tqdm import tqdm +import wandb from sae_training.sparse_autoencoder import SparseAutoencoder diff --git a/scripts/generate_dashboards.py b/scripts/generate_dashboards.py index bbfa0af7..40838392 100644 --- a/scripts/generate_dashboards.py +++ b/scripts/generate_dashboards.py @@ -16,10 +16,10 @@ import plotly import plotly.express as px import torch -import wandb from torch.nn.functional import cosine_similarity from tqdm import tqdm +import wandb from sae_analysis.visualizer.data_fns import get_feature_data from sae_training.utils import LMSparseAutoencoderSessionloader diff --git a/scripts/run.ipynb b/scripts/run.ipynb index 43515d3d..624e7924 100644 --- a/scripts/run.ipynb +++ b/scripts/run.ipynb @@ -37,8 +37,9 @@ ], "source": [ "import torch\n", - "import os \n", + "import os\n", "import sys\n", + "\n", "sys.path.append(\"..\")\n", "\n", "from sae_training.config import LanguageModelSAERunnerConfig\n", @@ -77,55 +78,46 @@ "metadata": {}, "outputs": [], "source": [ - "\n", - "\n", "cfg = LanguageModelSAERunnerConfig(\n", - "\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name = \"gelu-2l\",\n", - " hook_point = \"blocks.0.hook_mlp_out\",\n", - " hook_point_layer = 0,\n", - " d_in = 512,\n", - " dataset_path = \"NeelNanda/c4-tokenized-2b\",\n", + " model_name=\"gelu-2l\",\n", + " hook_point=\"blocks.0.hook_mlp_out\",\n", + " hook_point_layer=0,\n", + " d_in=512,\n", + " dataset_path=\"NeelNanda/c4-tokenized-2b\",\n", " is_dataset_tokenized=True,\n", - " \n", " # SAE Parameters\n", - " expansion_factor = 32,\n", - " b_dec_init_method=\"mean\", # geometric median is better but slower to get started\n", - " \n", + " expansion_factor=32,\n", + " b_dec_init_method=\"mean\", # geometric median is better but slower to get started\n", " # Training Parameters\n", - " lr = 0.0012,\n", + " lr=0.0012,\n", " lr_scheduler_name=\"constantwithwarmup\",\n", - " l1_coefficient = 0.00016,\n", - " train_batch_size = 4096,\n", - " context_size = 128,\n", - " \n", + " l1_coefficient=0.00016,\n", + " train_batch_size=4096,\n", + " context_size=128,\n", " # Activation Store Parameters\n", - " n_batches_in_buffer = 128,\n", - " total_training_tokens = 1_000_000 * 100, \n", - " store_batch_size = 32,\n", - " \n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 100,\n", + " store_batch_size=32,\n", " # Resampling protocol\n", " use_ghost_grads=True,\n", - " feature_sampling_window = 5000,\n", + " feature_sampling_window=5000,\n", " dead_feature_window=5000,\n", - " dead_feature_threshold = 1e-4,\n", - " \n", + " dead_feature_threshold=1e-4,\n", " # WANDB\n", - " log_to_wandb = True,\n", - " wandb_project= \"mats_sae_training_language_models_gelu_2l_test\",\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_models_gelu_2l_test\",\n", " wandb_log_frequency=10,\n", - " \n", " # Misc\n", - " device = device,\n", - " seed = 42,\n", - " n_checkpoints = 0,\n", - " checkpoint_path = \"checkpoints\",\n", - " dtype = torch.float32,\n", - " )\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", "\n", "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)\n" + "sparse_autoencoder = language_model_sae_runner(cfg)" ] }, { @@ -453,53 +445,46 @@ "\n", "layer = 3\n", "cfg = LanguageModelSAERunnerConfig(\n", - "\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name = \"gpt2-small\",\n", - " hook_point = f\"blocks.{layer}.hook_resid_pre\",\n", - " hook_point_layer = layer,\n", - " d_in = 768,\n", - " dataset_path = \"Skylion007/openwebtext\",\n", + " model_name=\"gpt2-small\",\n", + " hook_point=f\"blocks.{layer}.hook_resid_pre\",\n", + " hook_point_layer=layer,\n", + " d_in=768,\n", + " dataset_path=\"Skylion007/openwebtext\",\n", " is_dataset_tokenized=False,\n", - " \n", " # SAE Parameters\n", - " expansion_factor = 32, # determines the dimension of the SAE.\n", - " b_dec_init_method = \"mean\", # geometric median is better but slower to get started\n", - " \n", + " expansion_factor=32, # determines the dimension of the SAE.\n", + " b_dec_init_method=\"mean\", # geometric median is better but slower to get started\n", " # Training Parameters\n", - " lr = 0.0004,\n", - " l1_coefficient = 0.00008,\n", + " lr=0.0004,\n", + " l1_coefficient=0.00008,\n", " lr_scheduler_name=\"constantwithwarmup\",\n", - " train_batch_size = 4096,\n", - " context_size = 128,\n", + " train_batch_size=4096,\n", + " context_size=128,\n", " lr_warm_up_steps=5000,\n", - " \n", " # Activation Store Parameters\n", - " n_batches_in_buffer = 128,\n", - " total_training_tokens = 1_000_000 * 300, # 200M tokens seems doable overnight.\n", - " store_batch_size = 32,\n", - " \n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 300, # 200M tokens seems doable overnight.\n", + " store_batch_size=32,\n", " # Resampling protocol\n", " use_ghost_grads=True,\n", - " feature_sampling_window = 2500,\n", + " feature_sampling_window=2500,\n", " dead_feature_window=5000,\n", - " dead_feature_threshold = 1e-8,\n", - " \n", + " dead_feature_threshold=1e-8,\n", " # WANDB\n", - " log_to_wandb = True,\n", - " wandb_project= \"mats_sae_training_language_models_resid_pre_test\",\n", - " wandb_entity = None,\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_models_resid_pre_test\",\n", + " wandb_entity=None,\n", " wandb_log_frequency=100,\n", - " \n", " # Misc\n", - " device = \"cuda\",\n", - " seed = 42,\n", - " n_checkpoints = 10,\n", - " checkpoint_path = \"checkpoints\",\n", - " dtype = torch.float32,\n", - " )\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)\n" + " device=\"cuda\",\n", + " seed=42,\n", + " n_checkpoints=10,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" ] }, { @@ -516,8 +501,8 @@ "outputs": [], "source": [ "import torch\n", - "import os \n", - "import sys \n", + "import os\n", + "import sys\n", "\n", "sys.path.append(\"..\")\n", "from sae_training.config import LanguageModelSAERunnerConfig\n", @@ -528,54 +513,47 @@ "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "cfg = LanguageModelSAERunnerConfig(\n", - "\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name = \"pythia-70m-deduped\",\n", - " hook_point = \"blocks.0.hook_mlp_out\",\n", - " hook_point_layer = 0,\n", - " d_in = 512,\n", - " dataset_path = \"EleutherAI/the_pile_deduplicated\",\n", + " model_name=\"pythia-70m-deduped\",\n", + " hook_point=\"blocks.0.hook_mlp_out\",\n", + " hook_point_layer=0,\n", + " d_in=512,\n", + " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", " is_dataset_tokenized=False,\n", - " \n", " # SAE Parameters\n", - " expansion_factor = 64,\n", - " \n", + " expansion_factor=64,\n", " # Training Parameters\n", - " lr = 3e-4,\n", - " l1_coefficient = 4e-5,\n", - " train_batch_size = 8192,\n", - " context_size = 128,\n", + " lr=3e-4,\n", + " l1_coefficient=4e-5,\n", + " train_batch_size=8192,\n", + " context_size=128,\n", " lr_scheduler_name=\"constantwithwarmup\",\n", " lr_warm_up_steps=10_000,\n", - " \n", " # Activation Store Parameters\n", - " n_batches_in_buffer = 64,\n", - " total_training_tokens = 1_000_000 * 800, \n", - " store_batch_size = 32,\n", - " \n", + " n_batches_in_buffer=64,\n", + " total_training_tokens=1_000_000 * 800,\n", + " store_batch_size=32,\n", " # Resampling protocol\n", - " feature_sampling_method = 'anthropic',\n", - " feature_sampling_window = 2000, # Doesn't currently matter.\n", - " feature_reinit_scale = 0.2,\n", + " feature_sampling_method=\"anthropic\",\n", + " feature_sampling_window=2000, # Doesn't currently matter.\n", + " feature_reinit_scale=0.2,\n", " dead_feature_window=40000,\n", - " dead_feature_threshold = 1e-8,\n", - " \n", + " dead_feature_threshold=1e-8,\n", " # WANDB\n", - " log_to_wandb = True,\n", - " wandb_project= \"mats_sae_training_language_benchmark_tests\",\n", - " wandb_entity = None,\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", + " wandb_entity=None,\n", " wandb_log_frequency=20,\n", - " \n", " # Misc\n", - " device = \"cuda\",\n", - " seed = 42,\n", - " n_checkpoints = 0,\n", - " checkpoint_path = \"checkpoints\",\n", - " dtype = torch.float32,\n", - " )\n", + " device=\"cuda\",\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", "\n", "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)\n" + "sparse_autoencoder = language_model_sae_runner(cfg)" ] }, { @@ -592,8 +570,9 @@ "outputs": [], "source": [ "import torch\n", - "import os \n", + "import os\n", "import sys\n", + "\n", "sys.path.append(\"../\")\n", "\n", "from sae_training.config import LanguageModelSAERunnerConfig\n", @@ -601,55 +580,48 @@ "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "cfg = LanguageModelSAERunnerConfig(\n", - "\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name = \"pythia-70m-deduped\",\n", - " hook_point = \"blocks.2.attn.hook_q\",\n", - " hook_point_layer = 2,\n", + " model_name=\"pythia-70m-deduped\",\n", + " hook_point=\"blocks.2.attn.hook_q\",\n", + " hook_point_layer=2,\n", " hook_point_head_index=7,\n", - " d_in = 64,\n", - " dataset_path = \"EleutherAI/the_pile_deduplicated\",\n", + " d_in=64,\n", + " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", " is_dataset_tokenized=False,\n", - " \n", " # SAE Parameters\n", - " expansion_factor = 16,\n", - " \n", + " expansion_factor=16,\n", " # Training Parameters\n", - " lr = 0.0012,\n", - " l1_coefficient = 0.003,\n", + " lr=0.0012,\n", + " l1_coefficient=0.003,\n", " lr_scheduler_name=\"constantwithwarmup\",\n", - " lr_warm_up_steps=1000, # about 4 million tokens.\n", - " train_batch_size = 4096,\n", - " context_size = 128,\n", - " \n", + " lr_warm_up_steps=1000, # about 4 million tokens.\n", + " train_batch_size=4096,\n", + " context_size=128,\n", " # Activation Store Parameters\n", - " n_batches_in_buffer = 128,\n", - " total_training_tokens = 1_000_000 * 1500,\n", - " store_batch_size = 32,\n", - " \n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 1500,\n", + " store_batch_size=32,\n", " # Resampling protocol\n", - " feature_sampling_method = 'anthropic',\n", - " feature_sampling_window = 1000,# doesn't do anything currently.\n", - " feature_reinit_scale = 0.2,\n", + " feature_sampling_method=\"anthropic\",\n", + " feature_sampling_window=1000, # doesn't do anything currently.\n", + " feature_reinit_scale=0.2,\n", " resample_batches=8,\n", " dead_feature_window=60000,\n", - " dead_feature_threshold = 1e-5,\n", - " \n", + " dead_feature_threshold=1e-5,\n", " # WANDB\n", - " log_to_wandb = True,\n", - " wandb_project= \"mats_sae_training_pythia_70M_hook_q_L2H7\",\n", - " wandb_entity = None,\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_pythia_70M_hook_q_L2H7\",\n", + " wandb_entity=None,\n", " wandb_log_frequency=100,\n", - " \n", " # Misc\n", - " device = \"mps\",\n", - " seed = 42,\n", - " n_checkpoints = 15,\n", - " checkpoint_path = \"checkpoints\",\n", - " dtype = torch.float32,\n", - " )\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)\n" + " device=\"mps\",\n", + " seed=42,\n", + " n_checkpoints=15,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" ] }, { @@ -666,60 +638,52 @@ "outputs": [], "source": [ "import torch\n", - "import os \n", + "import os\n", "\n", "from sae_training.config import LanguageModelSAERunnerConfig\n", "from sae_training.lm_runner import language_model_sae_runner\n", "\n", "\n", - "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "cfg = LanguageModelSAERunnerConfig(\n", - "\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name = \"tiny-stories-2L-33M\",\n", - " hook_point = \"blocks.1.mlp.hook_post\",\n", - " hook_point_layer = 1,\n", - " d_in = 4096,\n", - " dataset_path = \"roneneldan/TinyStories\",\n", + " model_name=\"tiny-stories-2L-33M\",\n", + " hook_point=\"blocks.1.mlp.hook_post\",\n", + " hook_point_layer=1,\n", + " d_in=4096,\n", + " dataset_path=\"roneneldan/TinyStories\",\n", " is_dataset_tokenized=False,\n", - " \n", " # SAE Parameters\n", - " expansion_factor = 4,\n", - " \n", + " expansion_factor=4,\n", " # Training Parameters\n", - " lr = 1e-4,\n", - " l1_coefficient = 3e-4,\n", - " train_batch_size = 4096,\n", - " context_size = 128,\n", - " \n", + " lr=1e-4,\n", + " l1_coefficient=3e-4,\n", + " train_batch_size=4096,\n", + " context_size=128,\n", " # Activation Store Parameters\n", - " n_batches_in_buffer = 128,\n", - " total_training_tokens = 1_000_000 * 10, # want 500M eventually.\n", - " store_batch_size = 32,\n", - " \n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 10, # want 500M eventually.\n", + " store_batch_size=32,\n", " # Resampling protocol\n", - " feature_sampling_method = 'l2',\n", - " feature_sampling_window = 2500, # Doesn't currently matter.\n", - " feature_reinit_scale = 0.2,\n", + " feature_sampling_method=\"l2\",\n", + " feature_sampling_window=2500, # Doesn't currently matter.\n", + " feature_reinit_scale=0.2,\n", " dead_feature_window=1250,\n", - " dead_feature_threshold = 0.0005,\n", - " \n", + " dead_feature_threshold=0.0005,\n", " # WANDB\n", - " log_to_wandb = True,\n", - " wandb_project= \"mats_sae_training_language_benchmark_tests\",\n", - " wandb_entity = None,\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", + " wandb_entity=None,\n", " wandb_log_frequency=10,\n", - " \n", " # Misc\n", - " device = \"mps\",\n", - " seed = 42,\n", - " n_checkpoints = 0,\n", - " checkpoint_path = \"checkpoints\",\n", - " dtype = torch.float32,\n", - " )\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)\n" + " device=\"mps\",\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" ] }, { @@ -735,12 +699,10 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner\n", "\n", "\n", "cfg = SAEToyModelRunnerConfig(\n", - " \n", " # Model Details\n", " n_features=200,\n", " n_hidden=5,\n", @@ -748,18 +710,15 @@ " n_anticorrelated_pairs=0,\n", " feature_probability=0.025,\n", " model_training_steps=10_000,\n", - " \n", " # SAE Parameters\n", " d_sae=240,\n", " l1_coefficient=0.001,\n", - " \n", " # SAE Train Config\n", " train_batch_size=1028,\n", " feature_sampling_window=3_000,\n", " dead_feature_window=1_000,\n", " feature_reinit_scale=0.5,\n", - " total_training_tokens=4096*300,\n", - " \n", + " total_training_tokens=4096 * 300,\n", " # Other parameters\n", " log_to_wandb=True,\n", " wandb_project=\"sae-training-test\",\n", @@ -769,7 +728,7 @@ "\n", "trained_sae = toy_model_sae_runner(cfg)\n", "\n", - "assert trained_sae is not None\n" + "assert trained_sae is not None" ] }, { @@ -793,8 +752,9 @@ "outputs": [], "source": [ "import torch\n", - "import os \n", + "import os\n", "import sys\n", + "\n", "sys.path.append(\"..\")\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", @@ -803,32 +763,28 @@ "from sae_training.cache_activations_runner import cache_activations_runner\n", "\n", "cfg = CacheActivationsRunnerConfig(\n", - "\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name = \"gpt2-small\",\n", - " hook_point = \"blocks.10.attn.hook_q\",\n", - " hook_point_layer = 10,\n", + " model_name=\"gpt2-small\",\n", + " hook_point=\"blocks.10.attn.hook_q\",\n", + " hook_point_layer=10,\n", " hook_point_head_index=7,\n", - " d_in = 64,\n", - " dataset_path = \"Skylion007/openwebtext\",\n", + " d_in=64,\n", + " dataset_path=\"Skylion007/openwebtext\",\n", " is_dataset_tokenized=False,\n", " cached_activations_path=\"../activations/\",\n", - " \n", " # Activation Store Parameters\n", - " n_batches_in_buffer = 16,\n", - " total_training_tokens = 500_000_000, \n", - " store_batch_size = 32,\n", - "\n", + " n_batches_in_buffer=16,\n", + " total_training_tokens=500_000_000,\n", + " store_batch_size=32,\n", " # Activation caching shuffle parameters\n", - " n_shuffles_final = 16,\n", - " \n", + " n_shuffles_final=16,\n", " # Misc\n", - " device = \"mps\",\n", - " seed = 42,\n", - " dtype = torch.float32,\n", - " )\n", + " device=\"mps\",\n", + " seed=42,\n", + " dtype=torch.float32,\n", + ")\n", "\n", - "cache_activations_runner(cfg)\n" + "cache_activations_runner(cfg)" ] }, { @@ -846,60 +802,54 @@ "outputs": [], "source": [ "import torch\n", - "import os \n", + "import os\n", + "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", "from sae_training.config import LanguageModelSAERunnerConfig\n", "from sae_training.lm_runner import language_model_sae_runner\n", "\n", "cfg = LanguageModelSAERunnerConfig(\n", - "\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name = \"gpt2-small\",\n", - " hook_point = \"blocks.10.hook_resid_pre\",\n", - " hook_point_layer = 11,\n", - " d_in = 768,\n", - " dataset_path = \"Skylion007/openwebtext\",\n", + " model_name=\"gpt2-small\",\n", + " hook_point=\"blocks.10.hook_resid_pre\",\n", + " hook_point_layer=11,\n", + " d_in=768,\n", + " dataset_path=\"Skylion007/openwebtext\",\n", " is_dataset_tokenized=False,\n", " use_cached_activations=True,\n", - " \n", " # SAE Parameters\n", - " expansion_factor = 64, # determines the dimension of the SAE.\n", - " \n", + " expansion_factor=64, # determines the dimension of the SAE.\n", " # Training Parameters\n", - " lr = 1e-5,\n", - " l1_coefficient = 5e-4,\n", + " lr=1e-5,\n", + " l1_coefficient=5e-4,\n", " lr_scheduler_name=None,\n", - " train_batch_size = 4096,\n", - " context_size = 128,\n", - " \n", + " train_batch_size=4096,\n", + " context_size=128,\n", " # Activation Store Parameters\n", - " n_batches_in_buffer = 64,\n", - " total_training_tokens = 200_000,\n", - " store_batch_size = 32,\n", - " \n", + " n_batches_in_buffer=64,\n", + " total_training_tokens=200_000,\n", + " store_batch_size=32,\n", " # Resampling protocol\n", - " feature_sampling_method = 'l2',\n", - " feature_sampling_window = 1000,\n", - " feature_reinit_scale = 0.2,\n", + " feature_sampling_method=\"l2\",\n", + " feature_sampling_window=1000,\n", + " feature_reinit_scale=0.2,\n", " dead_feature_window=5000,\n", - " dead_feature_threshold = 1e-7,\n", - " \n", + " dead_feature_threshold=1e-7,\n", " # WANDB\n", - " log_to_wandb = True,\n", - " wandb_project= \"mats_sae_training_gpt2_small\",\n", - " wandb_entity = None,\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_gpt2_small\",\n", + " wandb_entity=None,\n", " wandb_log_frequency=50,\n", - " \n", " # Misc\n", - " device = \"mps\",\n", - " seed = 42,\n", - " n_checkpoints = 5,\n", - " checkpoint_path = \"checkpoints\",\n", - " dtype = torch.float32,\n", - " )\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)\n" + " device=\"mps\",\n", + " seed=42,\n", + " n_checkpoints=5,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" ] }, { @@ -916,8 +866,9 @@ "outputs": [], "source": [ "import torch\n", - "import os \n", - "import sys \n", + "import os\n", + "import sys\n", + "\n", "sys.path.append(\"..\")\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", @@ -925,57 +876,49 @@ "from sae_training.lm_runner import language_model_sae_runner\n", "\n", "\n", - "\n", "# for l1_coefficient in [9e-4,8e-4,7e-4]:\n", "cfg = LanguageModelSAERunnerConfig(\n", - "\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name = \"gpt2-small\",\n", - " hook_point = \"blocks.10.attn.hook_q\",\n", - " hook_point_layer = 10,\n", + " model_name=\"gpt2-small\",\n", + " hook_point=\"blocks.10.attn.hook_q\",\n", + " hook_point_layer=10,\n", " hook_point_head_index=7,\n", - " d_in = 64,\n", - " dataset_path = \"Skylion007/openwebtext\",\n", + " d_in=64,\n", + " dataset_path=\"Skylion007/openwebtext\",\n", " is_dataset_tokenized=False,\n", " use_cached_activations=True,\n", " cached_activations_path=\"../activations/\",\n", - " \n", " # SAE Parameters\n", - " expansion_factor = 64, # determines the dimension of the SAE. (64*64 = 4096, 64*4*64 = 32768)\n", - " \n", + " expansion_factor=64, # determines the dimension of the SAE. (64*64 = 4096, 64*4*64 = 32768)\n", " # Training Parameters\n", - " lr = 1e-3,\n", - " l1_coefficient = 2e-4,\n", + " lr=1e-3,\n", + " l1_coefficient=2e-4,\n", " # lr_scheduler_name=\"LinearWarmupDecay\",\n", " lr_warm_up_steps=2200,\n", - " train_batch_size = 4096,\n", - " context_size = 128,\n", - " \n", + " train_batch_size=4096,\n", + " context_size=128,\n", " # Activation Store Parameters\n", - " n_batches_in_buffer = 512,\n", - " total_training_tokens = 3_000_000,\n", - " store_batch_size = 32,\n", - " \n", + " n_batches_in_buffer=512,\n", + " total_training_tokens=3_000_000,\n", + " store_batch_size=32,\n", " # Resampling protocol\n", - " feature_sampling_method = 'l2',\n", - " feature_sampling_window = 1000,\n", - " feature_reinit_scale = 0.2,\n", + " feature_sampling_method=\"l2\",\n", + " feature_sampling_window=1000,\n", + " feature_reinit_scale=0.2,\n", " dead_feature_window=200,\n", - " dead_feature_threshold = 5e-6,\n", - " \n", + " dead_feature_threshold=5e-6,\n", " # WANDB\n", - " log_to_wandb = True,\n", - " wandb_project= \"mats_sae_training_gpt2_small_hook_q_dev\",\n", - " wandb_entity = None,\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_gpt2_small_hook_q_dev\",\n", + " wandb_entity=None,\n", " wandb_log_frequency=5,\n", - " \n", " # Misc\n", - " device = \"mps\",\n", - " seed = 42,\n", - " n_checkpoints = 0,\n", - " checkpoint_path = \"checkpoints\",\n", - " dtype = torch.float32,\n", - " )\n", + " device=\"mps\",\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", "\n", "# cfg.d_sae\n", "sparse_autoencoder = language_model_sae_runner(cfg)\n", diff --git a/tests/unit/test_activations_store.py b/tests/unit/test_activations_store.py index f6a6b68c..bb1caf92 100644 --- a/tests/unit/test_activations_store.py +++ b/tests/unit/test_activations_store.py @@ -33,6 +33,7 @@ def cfg(): mock_config.context_size = 16 mock_config.use_cached_activations = False mock_config.hook_point_head_index = None + mock_config.lp_norm = 1 mock_config.feature_sampling_method = None mock_config.feature_sampling_window = 50 diff --git a/tests/unit/test_sparse_autoencoder.py b/tests/unit/test_sparse_autoencoder.py index e30b2103..b68a6acb 100644 --- a/tests/unit/test_sparse_autoencoder.py +++ b/tests/unit/test_sparse_autoencoder.py @@ -32,6 +32,7 @@ def cfg(): mock_config.expansion_factor = 2 mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor mock_config.l1_coefficient = 2e-3 + mock_config.lp_norm = 1 mock_config.lr = 2e-4 mock_config.train_batch_size = 2048 mock_config.context_size = 64 diff --git a/tutorials/evaluating_your_sae.ipynb b/tutorials/evaluating_your_sae.ipynb index df9ef387..7f71efc8 100644 --- a/tutorials/evaluating_your_sae.ipynb +++ b/tutorials/evaluating_your_sae.ipynb @@ -27,7 +27,7 @@ "import plotly.express as px\n", "from transformer_lens import utils\n", "from datasets import load_dataset\n", - "from typing import Dict\n", + "from typing import Dict\n", "from pathlib import Path\n", "\n", "from functools import partial\n", @@ -38,7 +38,7 @@ "from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData\n", "\n", "if torch.backends.mps.is_available():\n", - " device = \"mps\" \n", + " device = \"mps\"\n", "else:\n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", @@ -61,10 +61,11 @@ "source": [ "# Start by downloading them from huggingface\n", "from huggingface_hub import hf_hub_download\n", + "\n", "REPO_ID = \"jbloom/GPT2-Small-SAEs\"\n", "\n", "\n", - "layer = 8 # any layer from 0 - 11 works here\n", + "layer = 8 # any layer from 0 - 11 works here\n", "FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt\"\n", "\n", "# this is great because if you've already downloaded the SAE it won't download it twice!\n", @@ -78,8 +79,8 @@ "outputs": [], "source": [ "# We can then load the SAE, dataset and model using the session loader\n", - "model, sparse_autoencoders, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(\n", - " path = path\n", + "model, sparse_autoencoders, activation_store = (\n", + " LMSparseAutoencoderSessionloader.load_session_from_pretrained(path=path)\n", ")" ] }, @@ -91,7 +92,9 @@ "source": [ "for i, sae in enumerate(sparse_autoencoders):\n", " hyp = sae.cfg\n", - " print(f\"{i}: Layer {hyp.hook_point_layer}, p_norm {hyp.lp_norm}, alpha {hyp.l1_coefficient}\")" + " print(\n", + " f\"{i}: Layer {hyp.hook_point_layer}, p_norm {hyp.lp_norm}, alpha {hyp.l1_coefficient}\"\n", + " )" ] }, { @@ -124,7 +127,7 @@ "metadata": {}, "outputs": [], "source": [ - "sparse_autoencoder.eval() # prevents error if we're expecting a dead neuron mask for who grads\n", + "sparse_autoencoder.eval() # prevents error if we're expecting a dead neuron mask for who grads\n", "with torch.no_grad():\n", " batch_tokens = activation_store.get_batch_tokens()\n", " _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)\n", @@ -132,11 +135,11 @@ " cache[sparse_autoencoder.cfg.hook_point]\n", " )\n", " del cache\n", - " \n", + "\n", " # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position\n", - " l0 = (feature_acts[:,1:] > 0).float().sum(-1).detach()\n", + " l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()\n", " print(\"average l0\", l0.mean().item())\n", - " px.histogram(l0.flatten().cpu().numpy()).show()\n" + " px.histogram(l0.flatten().cpu().numpy()).show()" ] }, { @@ -149,9 +152,11 @@ "def reconstr_hook(activation, hook, sae_out):\n", " return sae_out\n", "\n", + "\n", "def zero_abl_hook(activation, hook):\n", " return torch.zeros_like(activation)\n", "\n", + "\n", "print(\"Orig\", model(batch_tokens, return_type=\"loss\").item())\n", "print(\n", " \"reconstr\",\n", @@ -160,7 +165,7 @@ " fwd_hooks=[\n", " (\n", " utils.get_act_name(\"resid_pre\", 10),\n", - " partial(reconstr_hook, sae_out= sae_out),\n", + " partial(reconstr_hook, sae_out=sae_out),\n", " )\n", " ],\n", " return_type=\"loss\",\n", @@ -197,10 +202,11 @@ "\n", "logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)\n", "tokens = model.to_tokens(example_prompt)\n", - "sae_out, feature_acts, loss, mse_loss, l1_loss,_ = sparse_autoencoder(\n", + "sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(\n", " cache[sparse_autoencoder.cfg.hook_point]\n", ")\n", "\n", + "\n", "def reconstr_hook(activations, hook, sae_out):\n", " return sae_out\n", "\n", @@ -208,6 +214,7 @@ "def zero_abl_hook(mlp_out, hook):\n", " return torch.zeros_like(mlp_out)\n", "\n", + "\n", "hook_point = sparse_autoencoder.cfg.hook_point\n", "\n", "print(\"Orig\", model(tokens, return_type=\"loss\").item())\n", @@ -258,7 +265,7 @@ "metadata": {}, "outputs": [], "source": [ - "vals, inds = torch.topk(feature_acts[0,-1].detach().cpu(),10)\n", + "vals, inds = torch.topk(feature_acts[0, -1].detach().cpu(), 10)\n", "px.bar(x=[str(i) for i in inds], y=vals).show()" ] }, @@ -269,16 +276,20 @@ "outputs": [], "source": [ "vocab_dict = model.tokenizer.vocab\n", - "vocab_dict = {v: k.replace(\"Ġ\", \" \").replace(\"\\n\", \"\\\\n\") for k, v in vocab_dict.items()}\n", + "vocab_dict = {\n", + " v: k.replace(\"Ġ\", \" \").replace(\"\\n\", \"\\\\n\") for k, v in vocab_dict.items()\n", + "}\n", "\n", "vocab_dict_filepath = Path(os.getcwd()) / \"vocab_dict.json\"\n", "if not vocab_dict_filepath.exists():\n", " with open(vocab_dict_filepath, \"w\") as f:\n", " json.dump(vocab_dict, f)\n", - " \n", + "\n", "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "data = load_dataset(\"NeelNanda/c4-code-20k\", split=\"train\") # currently use this dataset to avoid deal with tokenization while streaming\n", + "data = load_dataset(\n", + " \"NeelNanda/c4-code-20k\", split=\"train\"\n", + ") # currently use this dataset to avoid deal with tokenization while streaming\n", "tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)\n", "tokenized_data = tokenized_data.shuffle(42)\n", "all_tokens = tokenized_data[\"tokens\"]\n", @@ -288,7 +299,7 @@ "# make the entire sequence indexing parallelized, but that's possibly not worth it right now.\n", "\n", "max_batch_size = 512\n", - "total_batch_size = 4096*5\n", + "total_batch_size = 4096 * 5\n", "feature_idx = list(inds.flatten().cpu().numpy())\n", "# max_batch_size = 512\n", "# total_batch_size = 16384\n", @@ -305,12 +316,12 @@ " tokens=tokens,\n", " feature_idx=feature_idx,\n", " max_batch_size=max_batch_size,\n", - " left_hand_k = 3,\n", - " buffer = (5, 5),\n", - " n_groups = 10,\n", - " first_group_size = 20,\n", - " other_groups_size = 5,\n", - " verbose = True,\n", + " left_hand_k=3,\n", + " buffer=(5, 5),\n", + " n_groups=10,\n", + " first_group_size=20,\n", + " other_groups_size=5,\n", + " verbose=True,\n", ")\n", "\n", "\n", diff --git a/upload_to_huggingface.ipynb b/upload_to_huggingface.ipynb index 4ce71bef..1e53d142 100644 --- a/upload_to_huggingface.ipynb +++ b/upload_to_huggingface.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "from huggingface_hub import HfApi\n", + "\n", "api = HfApi()\n", "\n", "uuid_str = \"kng5efo4\"\n", @@ -18,7 +19,7 @@ " path_in_repo=hf_folder,\n", " repo_id=repo_id,\n", " repo_type=\"model\",\n", - ")\n" + ")" ] } ],