From 11a71e1b95576ef6dc3dbec7eb1c76ce7ca44dfd Mon Sep 17 00:00:00 2001 From: Joseph Bloom Date: Tue, 16 Apr 2024 20:18:32 +0000 Subject: [PATCH] get decoder fine tuning working --- sae_lens/training/activations_store.py | 2 +- sae_lens/training/cache_activations_runner.py | 4 +- sae_lens/training/config.py | 30 +- sae_lens/training/sae_group.py | 4 + sae_lens/training/sparse_autoencoder.py | 16 +- .../training/train_sae_on_language_model.py | 44 +- scripts/run.ipynb | 1011 +++++++---------- .../test_language_model_sae_runner.py | 2 +- tests/unit/helpers.py | 2 +- .../test_train_sae_on_language_model.py | 5 +- tutorials/training_a_sparse_autoencoder.ipynb | 2 +- 11 files changed, 488 insertions(+), 634 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 604f7ba4..f0efa3ff 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -59,7 +59,7 @@ def from_config( context_size=cfg.context_size, d_in=cfg.d_in, n_batches_in_buffer=cfg.n_batches_in_buffer, - total_training_tokens=cfg.total_training_tokens, + total_training_tokens=cfg.training_tokens, store_batch_size=cfg.store_batch_size, train_batch_size=cfg.train_batch_size, prepend_bos=cfg.prepend_bos, diff --git a/sae_lens/training/cache_activations_runner.py b/sae_lens/training/cache_activations_runner.py index 443b399e..db9b470f 100644 --- a/sae_lens/training/cache_activations_runner.py +++ b/sae_lens/training/cache_activations_runner.py @@ -31,11 +31,11 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig): else: os.makedirs(activations_store.cached_activations_path) - print(f"Started caching {cfg.total_training_tokens} activations") + print(f"Started caching {cfg.training_tokens} activations") tokens_per_buffer = ( cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer ) - n_buffers = math.ceil(cfg.total_training_tokens / tokens_per_buffer) + n_buffers = math.ceil(cfg.training_tokens / tokens_per_buffer) # for i in tqdm(range(n_buffers), desc="Caching activations"): for i in range(n_buffers): buffer = activations_store.get_buffer(cfg.n_batches_in_buffer) diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 2a91b86f..22857e90 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -45,7 +45,8 @@ class LanguageModelSAERunnerConfig: # Activation Store Parameters n_batches_in_buffer: int = 20 - total_training_tokens: int = 2_000_000 + training_tokens: int = 2_000_000 + finetuning_tokens: int = 0 store_batch_size: int = 32 train_batch_size: int = 4096 @@ -56,11 +57,20 @@ class LanguageModelSAERunnerConfig: prepend_bos: bool = True # Training Parameters + + ## Batch size + train_batch_size: int = 4096 + + ## Adam adam_beta1: float | list[float] = 0 adam_beta2: float | list[float] = 0.999 + + ## Loss Function mse_loss_normalization: Optional[str] = None l1_coefficient: float | list[float] = 1e-3 lp_norm: float | list[float] = 1 + + ## Learning Rate Schedule lr: float | list[float] = 3e-4 lr_scheduler_name: str | list[str] = ( "constant" # constant, cosineannealing, cosineannealingwarmrestarts @@ -71,7 +81,9 @@ class LanguageModelSAERunnerConfig: ) lr_decay_steps: int | list[int] = 0 n_restart_cycles: int | list[int] = 1 # used only for cosineannealingwarmrestarts - train_batch_size: int = 4096 + + ## FineTuning + finetuning_method: Optional[str] = None # scale, decoder or unrotated_decoder # Resampling protocol args use_ghost_grads: bool | list[bool] = ( @@ -111,7 +123,7 @@ def __post_init__(self): ) if self.run_name is None: - self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" + self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}" if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]: raise ValueError( @@ -129,6 +141,12 @@ def __post_init__(self): elif isinstance(self.dtype, str): self.dtype: torch.dtype = DTYPE_MAP[self.dtype] + # if we use decoder fine tuning, we can't be applying b_dec to the input + if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input): + raise ValueError( + "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False." + ) + self.device: str | torch.device = torch.device(self.device) if self.lr_end is None: @@ -144,7 +162,7 @@ def __post_init__(self): if self.verbose: print( - f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" + f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}" ) # Print out some useful info: n_tokens_per_buffer = ( @@ -156,7 +174,7 @@ def __post_init__(self): f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}" ) - total_training_steps = self.total_training_tokens // self.train_batch_size + total_training_steps = self.training_tokens // self.train_batch_size print(f"Total training steps: {total_training_steps}") total_wandb_updates = total_training_steps // self.wandb_log_frequency @@ -209,7 +227,7 @@ class CacheActivationsRunnerConfig: # Activation Store Parameters n_batches_in_buffer: int = 20 - total_training_tokens: int = 2_000_000 + training_tokens: int = 2_000_000 store_batch_size: int = 32 train_batch_size: int = 4096 diff --git a/sae_lens/training/sae_group.py b/sae_lens/training/sae_group.py index b7ae75f9..15b651ba 100644 --- a/sae_lens/training/sae_group.py +++ b/sae_lens/training/sae_group.py @@ -133,6 +133,10 @@ def load_from_pretrained_legacy(cls, path: str) -> "SparseAutoencoderDictionary" if not hasattr(cfg, "model_kwargs"): cfg.model_kwargs = {} sparse_autoencoder = SparseAutoencoder(cfg=cfg) + # add dummy scaling factor to the state dict + group["state_dict"]["scaling_factor"] = torch.ones( + cfg.d_sae, dtype=cfg.dtype, device=cfg.device + ) sparse_autoencoder.load_state_dict(group["state_dict"]) group = cls(cfg) for key in group.autoencoders: diff --git a/sae_lens/training/sparse_autoencoder.py b/sae_lens/training/sparse_autoencoder.py index 33d89d2a..9272a4d9 100644 --- a/sae_lens/training/sparse_autoencoder.py +++ b/sae_lens/training/sparse_autoencoder.py @@ -98,6 +98,11 @@ def __init__( torch.zeros(self.d_in, dtype=self.dtype, device=self.device) ) + # scaling factor for fine-tuning (not to be used in initial training) + self.scaling_factor = nn.Parameter( + torch.ones(self.d_sae, dtype=self.dtype, device=self.device) + ) + self.hook_sae_in = HookPoint() self.hook_hidden_pre = HookPoint() self.hook_hidden_post = HookPoint() @@ -124,7 +129,8 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None) sae_out = self.hook_sae_out( einops.einsum( - feature_acts, + feature_acts + * self.scaling_factor, # need to make sure this handled when loading old models. self.W_dec, "... d_sae, d_sae d_in -> ... d_in", ) @@ -330,6 +336,14 @@ def load_from_pretrained(cls, path: str, device: str = "cpu"): with safe_open(weight_path, framework="pt", device=device) as f: # type: ignore for k in f.keys(): tensors[k] = f.get_tensor(k) + + # old saves may not have scaling factors. + if "scaling_factor" not in tensors: + assert isinstance(config.d_sae, int) + tensors["scaling_factor"] = torch.ones( + config.d_sae, dtype=config.dtype, device=config.device + ) + sae.load_state_dict(tensors) return sae diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index efcd3473..59285289 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -17,6 +17,13 @@ from sae_lens.training.sae_group import SparseAutoencoderDictionary from sae_lens.training.sparse_autoencoder import SparseAutoencoder +# used to map between parameters which are updated during finetuning and the config str. +FINETUNING_PARAMETERS = { + "scale": ["scaling_factor"], + "decoder": ["scaling_factor", "W_dec", "b_dec"], + "unrotated_decoder": ["scaling_factor", "b_dec"], +} + def _log_feature_sparsity( feature_sparsity: torch.Tensor, eps: float = 1e-10 @@ -35,6 +42,7 @@ class SAETrainContext: n_frac_active_tokens: int optimizer: Optimizer scheduler: LRScheduler + finetuning: bool = False @property def feature_sparsity(self) -> torch.Tensor: @@ -44,6 +52,21 @@ def feature_sparsity(self) -> torch.Tensor: def log_feature_sparsity(self) -> torch.Tensor: return _log_feature_sparsity(self.feature_sparsity) + def begin_finetuning(self, sae: SparseAutoencoder): + + # finetuning method should be set in the config + # if not, then we don't finetune + if not isinstance(sae.cfg.finetuning_method, str): + return + + for name, param in sae.named_parameters(): + if name in FINETUNING_PARAMETERS[sae.cfg.finetuning_method]: + param.requires_grad = True + else: + param.requires_grad = False + + self.finetuning = True + @dataclass class TrainSAEGroupOutput: @@ -88,10 +111,13 @@ def train_sae_group_on_language_model( use_wandb: bool = False, wandb_log_frequency: int = 50, ) -> TrainSAEGroupOutput: - total_training_tokens = sae_group.cfg.total_training_tokens + total_training_tokens = ( + sae_group.cfg.training_tokens + sae_group.cfg.finetuning_tokens + ) total_training_steps = total_training_tokens // batch_size n_training_steps = 0 n_training_tokens = 0 + started_fine_tuning = False checkpoint_thresholds = [] if n_checkpoints > 0: @@ -180,6 +206,16 @@ def train_sae_group_on_language_model( ) pbar.update(batch_size) + ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already) + if (not started_fine_tuning) and ( + n_training_tokens > sae_group.cfg.training_tokens + ): + started_fine_tuning = True + for name, sparse_autoencoder in sae_group.autoencoders.items(): + ctx = train_contexts[name] + # this should turn grads on for the scaling factor and other parameters. + ctx.begin_finetuning(sae_group.autoencoders[name]) + # save final sae group to checkpoints folder final_checkpoint = _save_checkpoint( sae_group, @@ -248,6 +284,12 @@ def _build_train_context( ) n_frac_active_tokens = 0 + # we don't train the scaling factor (initially) + # set requires grad to false for the scaling factor + for name, param in sae.named_parameters(): + if "scaling_factor" in name: + param.requires_grad = False + optimizer = Adam( sae.parameters(), lr=sae.cfg.lr, diff --git a/scripts/run.ipynb b/scripts/run.ipynb index 8e6cc3a4..824b431c 100644 --- a/scripts/run.ipynb +++ b/scripts/run.ipynb @@ -24,14 +24,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Using device: mps\n" + "Using device: cuda\n" ] } ], @@ -60,273 +60,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Gelu-2L\n", - "\n", - "An example of a toy language model we're able to train on." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### MLP Out" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gelu-2l\",\n", - " hook_point=\"blocks.0.hook_mlp_out\",\n", - " hook_point_layer=0,\n", - " d_in=512,\n", - " dataset_path=\"NeelNanda/c4-tokenized-2b\",\n", - " is_dataset_tokenized=True,\n", - " # SAE Parameters\n", - " expansion_factor=[16, 32, 64],\n", - " b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n", - " # Training Parameters\n", - " lr=0.0012,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " l1_coefficient=0.00016,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 100,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " use_ghost_grads=True,\n", - " feature_sampling_window=5000,\n", - " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-4,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_models_gelu_2l_test\",\n", - " wandb_log_frequency=10,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## GPT2 - Small" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Residual Stream" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "layer = 3\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gpt2-small\",\n", - " hook_point=f\"blocks.{layer}.hook_resid_pre\",\n", - " hook_point_layer=layer,\n", - " d_in=768,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=32, # determines the dimension of the SAE.\n", - " b_dec_init_method=\"mean\", # geometric median is better but slower to get started\n", - " # Training Parameters\n", - " lr=0.0004,\n", - " l1_coefficient=0.00008,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " lr_warm_up_steps=5000,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 300, # 200M tokens seems doable overnight.\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " use_ghost_grads=True,\n", - " feature_sampling_window=2500,\n", - " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-8,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_models_resid_pre_test\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=100,\n", - " # Misc\n", - " device=\"cuda\",\n", - " seed=42,\n", - " n_checkpoints=10,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pythia 70-M" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "import cProfile\n", - "\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"pythia-70m-deduped\",\n", - " hook_point=\"blocks.0.hook_mlp_out\",\n", - " hook_point_layer=0,\n", - " d_in=512,\n", - " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=64,\n", - " # Training Parameters\n", - " lr=3e-4,\n", - " l1_coefficient=4e-5,\n", - " train_batch_size=8192,\n", - " context_size=128,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " lr_warm_up_steps=10_000,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=64,\n", - " total_training_tokens=1_000_000 * 800,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " feature_sampling_window=2000, # Doesn't currently matter.\n", - " dead_feature_window=40000,\n", - " dead_feature_threshold=1e-8,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=20,\n", - " # Misc\n", - " device=\"cuda\",\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pythia 70M Hook Q" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"../\")\n", - "\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"pythia-70m-deduped\",\n", - " hook_point=\"blocks.2.attn.hook_q\",\n", - " hook_point_layer=2,\n", - " hook_point_head_index=7,\n", - " d_in=64,\n", - " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=16,\n", - " # Training Parameters\n", - " lr=0.0012,\n", - " l1_coefficient=0.003,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " lr_warm_up_steps=1000, # about 4 million tokens.\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 1500,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " feature_sampling_method=\"anthropic\",\n", - " feature_sampling_window=1000, # doesn't do anything currently.\n", - " feature_reinit_scale=0.2,\n", - " resample_batches=8,\n", - " dead_feature_window=60000,\n", - " dead_feature_threshold=1e-5,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_pythia_70M_hook_q_L2H7\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=100,\n", - " # Misc\n", - " device=\"mps\",\n", - " seed=42,\n", - " n_checkpoints=15,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Tiny Stories" + "# Tiny Stories - 1L" ] }, { @@ -338,49 +72,222 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.002048\n", + "Total training steps: 6103\n", + "Total wandb updates: 610\n", + "n_tokens_per_feature_sampling_window (millions): 524.288\n", + "n_tokens_per_dead_feature_window (millions): 524.288\n", + "We will reset the sparsity calculation 6 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n", + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n", + "Moving model to device: cuda\n", + "Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.002048\n", + "Total training steps: 6103\n", + "Total wandb updates: 610\n", + "n_tokens_per_feature_sampling_window (millions): 524.288\n", + "n_tokens_per_dead_feature_window (millions): 524.288\n", + "We will reset the sparsity calculation 6 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n", + "Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.002048\n", + "Total training steps: 6103\n", + "Total wandb updates: 610\n", + "n_tokens_per_feature_sampling_window (millions): 524.288\n", + "n_tokens_per_dead_feature_window (millions): 524.288\n", + "We will reset the sparsity calculation 6 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.6" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/paperspace/mats_sae_training/scripts/wandb/run-20240416_135218-opqs9dgl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/jbloom/sae_lens_tutorial" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/jbloom/sae_lens_tutorial/runs/opqs9dgl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Objective value: 1781464.6250: 4%|▍ | 4/100 [00:00<00:00, 206.25it/s]\n", + "/home/paperspace/mats_sae_training/sae_lens/training/sparse_autoencoder.py:176: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", + "135| MSE Loss 0.257 | L1 1.354: 1%| | 552960/50000000 [00:13<19:08, 43042.90it/s] /home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (v0kr8hz9) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n", + "6104| MSE Loss 0.072 | L1 0.024: : 25001984it [18:07, 22981.57it/s]\n", + "12208| MSE Loss 0.070 | L1 0.024: 100%|█████████▉| 49999872/50000000 [20:15<00:00, 30551.50it/s]" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bb94759b99e14133aece0058a423e305", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='128.448 MB of 128.448 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_learning_rate▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████████
details/n_training_tokens▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▄▄▅▆▆▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation▅▂▅▅█▆▆▇▅▅▆▅▄▅▅▄▅▃▄▄▄▄▃▆▆▄▁▄▆▃▆▃▅▆▂▃▆▄▃▅
metrics/ce_loss_with_sae█▅▅▄▃▃▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/ce_loss_without_sae▆▂▂▁▇▄▇▆▃▃▅▅▆▇▃▆▃▄▄▄█▂▂▄▄▃▁▅▄▄▂▅▄█▃▄▄▄▅▆
metrics/explained_variance▁▄▄▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/explained_variance_std▆▁▄▇███▇▇▆▆▆▆▆▆▆▆▅▅▅▅▅▄▅▄▄▅▄▄▄▄▄▄▄▄▄▄▄▄▄
metrics/l0█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm█▁▂▃▄▄▃▅▅▄▅▅▄▅▄▅▅▅▅▄▆▅▆▅▆▆▅▆▆▇█▇▇▆▆▇▆▆▇▇
metrics/l2_ratio█▁▂▃▃▃▃▄▄▃▄▄▄▄▄▄▄▄▄▄▅▅▅▅▆▆▅▆▆▆▆▆▆▅▆▆▅▆▆▆
metrics/mean_log10_feature_sparsity█▅▃▃▂▁▁▁▁▁▁▁
sparsity/below_1e-5▁▁▁▁▂▆███▅██
sparsity/below_1e-6▁▁▁▁▁▁█▇█▂▆▆
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██▅▅▁▁▁▁▅▁▁▁▁▁▅▅▁▁
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▂▂▂▃▂▄▄▅▇▇▇▇██▆▇▆███▆▅▅█▇▇▇▇██

Run summary:


details/current_learning_rate0.0008
details/n_training_tokens49971200
losses/ghost_grad_loss0.0
losses/l1_loss15.59199
losses/mse_loss0.07019
losses/overall_loss0.09358
metrics/CE_loss_score0.86351
metrics/ce_loss_with_ablation8.5168
metrics/ce_loss_with_sae3.00156
metrics/ce_loss_without_sae2.12988
metrics/explained_variance0.56934
metrics/explained_variance_std0.14386
metrics/l019.32129
metrics/l2_norm15.93428
metrics/l2_ratio0.86545
metrics/mean_log10_feature_sparsity-4.81775
sparsity/below_1e-56329
sparsity/below_1e-681
sparsity/dead_features0
sparsity/mean_passes_since_fired29.02307

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07 at: https://wandb.ai/jbloom/sae_lens_tutorial/runs/opqs9dgl
View project at: https://wandb.ai/jbloom/sae_lens_tutorial
Synced 7 W&B file(s), 0 media file(s), 3 artifact file(s) and 1 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240416_135218-opqs9dgl/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "12208| MSE Loss 0.070 | L1 0.024: : 50003968it [20:27, 30551.50it/s] /home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (opqs9dgl) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n" + ] + } + ], "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "if device == \"cpu\" and torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "cfg = LanguageModelSAERunnerConfig(\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"tiny-stories-1M\",\n", - " hook_point=\"blocks.1.mlp.hook_post\",\n", - " hook_point_layer=1,\n", - " d_in=256,\n", - " # dataset_path=\"roneneldan/TinyStories\",\n", - " # is_dataset_tokenized=False,\n", - " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", + " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + " hook_point=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", + " hook_point_layer=0, # Only one layer in the model.\n", + " d_in=1024, # the width of the mlp output.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", " is_dataset_tokenized=True,\n", + " \n", " # SAE Parameters\n", - " expansion_factor=16,\n", + " mse_loss_normalization=None, # We won't normalize the mse loss,\n", + " expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training.\n", + " b_dec_init_method=\"geometric_median\", # The geometric median can be used to initialize the decoder weights.\n", + " apply_b_dec_to_input=False, # We won't apply the decoder to the input.\n", + " \n", " # Training Parameters\n", - " lr=1e-4,\n", - " lp_norm=1.0,\n", - " l1_coefficient=2e-4,\n", + " lr=0.0008, # lower the better, we'll go fairly high to speed up the tutorial.\n", + " lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", + " lr_warm_up_steps=10000, # this can help avoid too many dead features initially.\n", + " l1_coefficient=0.0015, # will control how sparse the feature activations are\n", + " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", " train_batch_size=4096,\n", - " context_size=128,\n", + " context_size=128, # will control the lenght of the prompts we feed to the model. Larger is better but slower.\n", + " \n", " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 20,\n", + " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", + " training_tokens=1_000_000 * 25, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", + " finetuning_method=\"decoder\",\n", + " finetuning_tokens=1_000_000 * 25,\n", " store_batch_size=32,\n", - " feature_sampling_window=500, # So we see the histograms.\n", - " dead_feature_window=250,\n", + " \n", + " \n", + " # Resampling protocol\n", + " use_ghost_grads=False,\n", + " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", + " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", + " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", + " log_to_wandb=True, # always use wandb unless you are just testing code.\n", + " wandb_project=\"sae_lens_tutorial\",\n", " wandb_log_frequency=10,\n", " # Misc\n", " device=device,\n", @@ -390,82 +297,145 @@ " dtype=torch.float32,\n", ")\n", "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" + "# look at the next cell to see some instruction for what to do while this is running.\n", + "sparse_autoencoder_dictionary = language_model_sae_runner(cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GPT2 - Small" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Hook Z\n", - "\n" + "### Residual Stream" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", + "Run name: 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08\n", + "n_tokens_per_buffer (millions): 1.048576\n", "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 4882\n", + "Total training steps: 48828\n", "Total wandb updates: 488\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 9 times.\n", - "Number tokens in sparsity calculation window: 2.05e+06\n", - "Loaded pretrained model tiny-stories-1M into HookedTransformer\n", - "Moving model to device: mps\n" + "n_tokens_per_feature_sampling_window (millions): 2621.44\n", + "n_tokens_per_dead_feature_window (millions): 5242.88\n", + "We will reset the sparsity calculation 19 times.\n", + "Number tokens in sparsity calculation window: 1.02e+07\n", + "Loaded pretrained model gpt2-small into HookedTransformer\n", + "Moving model to device: cuda\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fee8922d83f04003a2f1441eeb30200d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/73 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ea686292dff7449a9846fcfa29d6ff74", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='0.064 MB of 0.064 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_learning_rate▁▁▂▂▃▄▄▅▅▆▆▇▇███████████████████████████
details/n_training_tokens▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▆▇▇▇█████████
metrics/ce_loss_with_ablation▄▂▁▄▄▂▄▄▁▄▁▄▂█
metrics/ce_loss_with_sae█▄▂▂▂▂▁▂▁▁▁▁▁▁
metrics/ce_loss_without_sae▂▇▄█▅▃▄▇▆▅▄▆▁▁
metrics/explained_variance▁▃▄▆▆▇▇▇▇▇▇▇████████████████████████████
metrics/explained_variance_std█▄█▇▅▄▄▄▃▃▂▂▂▂▂▂▂▂▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▅▆▇▇▇█▇██████
metrics/l2_ratio▁▅▆▇▇▇▇▇██████
metrics/mean_log10_feature_sparsity█▃▂▁▁
sparsity/below_1e-5▁▁▅██
sparsity/below_1e-6▁▁▁▄█
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃▄▅▆█
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▆▆▇▇█

Run summary:


details/current_learning_rate0.0004
details/n_training_tokens59801600
losses/ghost_grad_loss0.0
losses/l1_loss160.66861
losses/mse_loss1.68098
losses/overall_loss2.96633
metrics/CE_loss_score0.96258
metrics/ce_loss_with_ablation11.49633
metrics/ce_loss_with_sae3.62324
metrics/ce_loss_without_sae3.3166
metrics/explained_variance0.78709
metrics/explained_variance_std0.05978
metrics/l050.03076
metrics/l2_norm102.32782
metrics/l2_ratio0.8864
metrics/mean_log10_feature_sparsity-5.31744
sparsity/below_1e-519194
sparsity/below_1e-611736
sparsity/dead_features60
sparsity/mean_passes_since_fired640.44727

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "wandb version 0.16.5 is available! To upgrade, please run:\n", - " $ pip install wandb --upgrade" + " View run 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08 at: https://wandb.ai/jbloom/gpt2_small_experiments_april/runs/pq5q3x9s
View project at: https://wandb.ai/jbloom/gpt2_small_experiments_april
Synced 7 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)" ], "text/plain": [ "" @@ -477,7 +447,7 @@ { "data": { "text/html": [ - "Tracking run with wandb version 0.16.3" + "Find logs at: ./wandb/run-20240416_155117-pq5q3x9s/logs" ], "text/plain": [ "" @@ -489,7 +459,7 @@ { "data": { "text/html": [ - "Run data is saved locally in /Users/josephbloom/GithubRepositories/mats_sae_training/scripts/wandb/run-20240326_191703-ec6k6v87" + "Successfully finished last run (ID:pq5q3x9s). Initializing new run:
" ], "text/plain": [ "" @@ -498,10 +468,24 @@ "metadata": {}, "output_type": "display_data" }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fd02fd0295cc4afda9bb0e1367c87f84", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011112805799995032, max=1.0…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ - "Syncing run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 to Weights & Biases (docs)
" + "Tracking run with wandb version 0.16.6" ], "text/plain": [ "" @@ -513,7 +497,7 @@ { "data": { "text/html": [ - " View project at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests" + "Run data is saved locally in /home/paperspace/mats_sae_training/scripts/wandb/run-20240416_165827-vbwoyzi8" ], "text/plain": [ "" @@ -525,7 +509,7 @@ { "data": { "text/html": [ - " View run at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87" + "Syncing run 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08 to Weights & Biases (docs)
" ], "text/plain": [ "" @@ -535,91 +519,56 @@ "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Objective value: 116883.7422: 10%|█ | 10/100 [00:00<00:00, 128.72it/s]\n", - "/Users/josephbloom/GithubRepositories/mats_sae_training/sae_training/sparse_autoencoder.py:161: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", - "100%|██████████| 10/10 [00:02<00:00, 4.93it/s] 405504/20000000 [00:14<08:53, 36739.57it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 811008/20000000 [00:31<18:45, 17042.47it/s] \n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 1224704/20000000 [00:47<10:43, 29194.89it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s]| 1634304/20000000 [01:05<08:10, 37468.33it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.64it/s]| 2039808/20000000 [01:20<07:36, 39322.02it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.08it/s]| 2453504/20000000 [01:37<07:55, 36873.53it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 2863104/20000000 [01:52<07:16, 39292.24it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 3272704/20000000 [02:09<06:52, 40537.06it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 3678208/20000000 [02:26<27:40, 9829.56it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 4087808/20000000 [02:41<06:11, 42798.13it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.50it/s] | 4497408/20000000 [03:01<08:53, 29055.95it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 4911104/20000000 [03:16<06:55, 36330.89it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.57it/s] | 5316608/20000000 [03:34<06:31, 37461.30it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.87it/s] | 5726208/20000000 [03:50<05:45, 41309.20it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 6139904/20000000 [04:07<06:03, 38122.10it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s] | 6549504/20000000 [04:24<05:43, 39198.19it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.91it/s] | 6955008/20000000 [04:43<05:01, 43328.38it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.84it/s] | 7368704/20000000 [05:00<12:14, 17200.22it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 7778304/20000000 [05:14<04:44, 43005.09it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.78it/s] | 8183808/20000000 [05:32<06:31, 30153.11it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.80it/s] | 8597504/20000000 [05:47<04:22, 43375.86it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 5.00it/s] | 9007104/20000000 [06:09<05:16, 34784.52it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.55it/s] | 9416704/20000000 [06:24<04:36, 38252.78it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.75it/s] | 9822208/20000000 [06:42<03:58, 42593.01it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.99it/s] | 10235904/20000000 [06:59<19:05, 8524.91it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s] | 10645504/20000000 [07:14<03:30, 44384.65it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.89it/s] | 11055104/20000000 [07:31<05:24, 27562.66it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.83it/s] | 11464704/20000000 [07:45<03:26, 41316.56it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.81it/s] | 11870208/20000000 [08:02<03:44, 36217.25it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.89it/s] | 12279808/20000000 [08:16<02:52, 44715.52it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.85it/s] | 12693504/20000000 [08:34<03:02, 40061.41it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.02it/s] | 13103104/20000000 [08:48<02:38, 43563.35it/s]\n", - "100%|██████████| 10/10 [00:04<00:00, 2.17it/s] | 13508608/20000000 [09:05<02:34, 41937.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 13922304/20000000 [09:24<05:07, 19779.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 14327808/20000000 [09:38<02:05, 45367.15it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 14741504/20000000 [09:54<02:49, 30943.53it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.05it/s] | 15147008/20000000 [10:08<01:46, 45610.98it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.06it/s] | 15556608/20000000 [10:24<01:49, 40440.85it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 15966208/20000000 [10:38<01:29, 45251.75it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16379904/20000000 [10:55<01:22, 43941.70it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16789504/20000000 [11:11<04:30, 11859.26it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 17195008/20000000 [11:25<01:02, 44607.68it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s] | 17608704/20000000 [11:41<01:38, 24188.35it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.00it/s] | 18018304/20000000 [11:54<00:42, 46425.69it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.06it/s]▏| 18423808/20000000 [12:13<00:44, 35420.18it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s]▍| 18837504/20000000 [12:27<00:26, 43914.73it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]▌| 19243008/20000000 [12:45<00:19, 38931.67it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.95it/s]▊| 19656704/20000000 [12:59<00:07, 43804.93it/s]\n", - "4883| MSE Loss 0.000 | L1 0.000: 100%|█████████▉| 19996672/20000000 [13:14<00:00, 37714.53it/s]" - ] + "data": { + "text/html": [ + " View project at https://wandb.ai/jbloom/gpt2_small_experiments_april" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "name": "stdout", + "data": { + "text/html": [ + " View run at https://wandb.ai/jbloom/gpt2_small_experiments_april/runs/vbwoyzi8" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", "output_type": "stream", "text": [ - "Saved model to checkpoints/sf7u2imk/final_sae_group_tiny-stories-1M_blocks.1.attn.hook_z_1024.pt\n" + "Objective value: 46608928.0000: 2%|▏ | 2/100 [00:00<00:01, 55.75it/s]\n", + "/home/paperspace/mats_sae_training/sae_lens/training/sparse_autoencoder.py:176: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", + "120| MSE Loss 31.151 | L1 65.750: 0%| | 487424/300000000 [00:15<1:28:10, 56617.16it/s]/home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (4elmsny3) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n", + "2407| MSE Loss 0.070 | L1 0.027: 20%|█▉ | 9859072/50000000 [3:33:05<14:27:36, 771.10it/s]\n", + "73243| MSE Loss 1.416 | L1 1.255: : 300003328it [2:44:02, 54947.70it/s] " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1dffd84a387d4cf48100fbe143287481", + "model_id": "6111ba99afb144ae82bab7723efb2c86", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "VBox(children=(Label(value='0.053 MB of 0.569 MB uploaded\\r'), FloatProgress(value=0.0935266880101429, max=1.0…" + "VBox(children=(Label(value='721.959 MB of 721.959 MB uploaded (0.005 MB deduped)\\r'), FloatProgress(value=1.0,…" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n" - ] - }, { "data": { "text/html": [ @@ -628,7 +577,7 @@ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n", " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n", " \n", - "

Run history:


details/current_learning_rate▁▃▅▆████████████████████████████████████
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss██▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▄▅▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation▂▃▂▅▃▆▅▃▄▆▇▆▅▇▅▄▇▅▁▆▄▅▆▄█▄▅▆▄▅▅▃▂▄▄▅▅█▆▆
metrics/ce_loss_with_sae█▅▄▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▁▂▂▁▁▂▁▂▁▂▂▂
metrics/ce_loss_without_sae▄▄▁▃▄▆▅▃▆▅█▆▅▆▅▄▅▆▁▇▆▅▆▃█▆▆▆▄▇▆▃▃▆▃▆▄█▇▅
metrics/explained_variance▁▅▇▇▇███████████████████████████████████
metrics/explained_variance_std██▆▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0██▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▄▆▆▇▇▇▆▆▆▇█▇▇▆▇▇▆▇▇▇▆▇████▇▇▇▇▇▇▇▇▇█▇▇▇
metrics/l2_ratio▁▃▁▂▄▃▂▄▆▅▅▅▅▆▅▆▇▆▆▇▇▆▆▆▇▆▆▇▆▇▆▇▇▇█▆▆▇▇▇
metrics/mean_log10_feature_sparsity█▇▅▄▃▃▂▁▁
sparsity/below_1e-5▁▁▁▁▁▁▁▁▁
sparsity/below_1e-6▁▁▁▁▁▁▁▁▁
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▂▂▄▇▄▃▅▆▄▇██

Run summary:


details/current_learning_rate0.0001
details/n_training_tokens19988480
losses/ghost_grad_loss0.0
losses/l1_loss1.41017
losses/mse_loss8e-05
losses/overall_loss0.00036
metrics/CE_loss_score0.98362
metrics/ce_loss_with_ablation5.49512
metrics/ce_loss_with_sae2.71813
metrics/ce_loss_without_sae2.67199
metrics/explained_variance0.98647
metrics/explained_variance_std0.00905
metrics/l0166.02246
metrics/l2_norm1.39317
metrics/l2_ratio0.99823
metrics/mean_log10_feature_sparsity-1.53525
sparsity/below_1e-50
sparsity/below_1e-60
sparsity/dead_features0
sparsity/mean_passes_since_fired0.02051

" + "

Run history:


details/current_learning_rate▁▄▆█████████████████████████████████████
details/n_training_tokens▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▆▇▇████████████████████████████████████
metrics/ce_loss_with_ablation▅▃▄▃▅▄▄▃▄▄▂▂▂▃▄▃▂▆▃▆▃▆▃▆▅▁▁▅▂▃▃▄▃▅▄▃█▄▃▆
metrics/ce_loss_with_sae█▄▂▂▂▂▂▁▂▁▁▂▂▁▁▁▂▁▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/ce_loss_without_sae▁▅▆▂▆▄▅▁▄▄▄▇▆▄▃▄▇▄▂▆▆█▄▅▅▄▄▄▅▄▅▅▃▅▄▅▂▃▂▄
metrics/explained_variance▁▆▇▇▇▇▇█████████████████████████████████
metrics/explained_variance_std█▄▃▂▂▂▂▁▂▁▂▁▁▂▂▁▂▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▄▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▇▆▆▇▆▆▆▇▆█████████████
metrics/l2_ratio▁▄▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆█████████████
metrics/mean_log10_feature_sparsity█▅▄▄▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/below_1e-5▁▁▅██████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
sparsity/below_1e-6▁▁▁▃▅▆▇██████████████████████
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▂▂▃▄▅▆▆▇▇▇▇███████████████████
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███

Run summary:


details/current_learning_rate0.0004
details/n_training_tokens299827200
losses/ghost_grad_loss0.0
losses/l1_loss162.07342
losses/mse_loss1.42934
losses/overall_loss2.72593
metrics/CE_loss_score0.97257
metrics/ce_loss_with_ablation11.42603
metrics/ce_loss_with_sae3.61949
metrics/ce_loss_without_sae3.39944
metrics/explained_variance0.82112
metrics/explained_variance_std0.0526
metrics/l050.53198
metrics/l2_norm108.35806
metrics/l2_ratio0.94604
metrics/mean_log10_feature_sparsity-7.89094
sparsity/below_1e-518079
sparsity/below_1e-618075
sparsity/dead_features16912
sparsity/mean_passes_since_fired27024.85938

" ], "text/plain": [ "" @@ -640,7 +589,7 @@ { "data": { "text/html": [ - " View run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 at: https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87
Synced 7 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" + " View run 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08 at: https://wandb.ai/jbloom/gpt2_small_experiments_april/runs/vbwoyzi8
View project at: https://wandb.ai/jbloom/gpt2_small_experiments_april
Synced 7 W&B file(s), 0 media file(s), 15 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" @@ -652,7 +601,7 @@ { "data": { "text/html": [ - "Find logs at: ./wandb/run-20240326_191703-ec6k6v87/logs" + "Find logs at: ./wandb/run-20240416_165827-vbwoyzi8/logs" ], "text/plain": [ "" @@ -665,223 +614,56 @@ "name": "stderr", "output_type": "stream", "text": [ - "4883| MSE Loss 0.000 | L1 0.000: : 20000768it [13:29, 37714.53it/s] /Users/josephbloom/miniforge3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2171: UserWarning: Run (ec6k6v87) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + "73243| MSE Loss 1.416 | L1 1.255: : 300003328it [2:44:12, 54947.70it/s]/home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (vbwoyzi8) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", " lambda data: self._console_raw_callback(\"stderr\", data),\n" ] } ], "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "if device == \"cpu\" and torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"tiny-stories-1M\",\n", - " hook_point=\"blocks.1.attn.hook_z\",\n", - " hook_point_layer=1,\n", - " d_in=64,\n", - " # dataset_path=\"roneneldan/TinyStories\",\n", - " # is_dataset_tokenized=False,\n", - " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", - " is_dataset_tokenized=True,\n", - " # SAE Parameters\n", - " expansion_factor=16,\n", - " # Training Parameters\n", - " lr=1e-4,\n", - " lp_norm=1.0,\n", - " l1_coefficient=2e-4,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 20,\n", - " store_batch_size=32,\n", - " feature_sampling_window=500, # So we see the histograms.\n", - " dead_feature_window=250,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", - " wandb_log_frequency=10,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Toy Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sae_lens.training.toy_model_runner import (\n", - " SAEToyModelRunnerConfig,\n", - " toy_model_sae_runner,\n", - ")\n", - "\n", - "\n", - "cfg = SAEToyModelRunnerConfig(\n", - " # Model Details\n", - " n_features=200,\n", - " n_hidden=5,\n", - " n_correlated_pairs=0,\n", - " n_anticorrelated_pairs=0,\n", - " feature_probability=0.025,\n", - " model_training_steps=10_000,\n", - " # SAE Parameters\n", - " d_sae=240,\n", - " l1_coefficient=0.001,\n", - " # SAE Train Config\n", - " train_batch_size=1028,\n", - " feature_sampling_window=3_000,\n", - " dead_feature_window=1_000,\n", - " feature_reinit_scale=0.5,\n", - " total_training_tokens=4096 * 300,\n", - " # Other parameters\n", - " log_to_wandb=True,\n", - " wandb_project=\"sae-training-test\",\n", - " wandb_log_frequency=5,\n", - " device=\"mps\",\n", - ")\n", - "\n", - "trained_sae = toy_model_sae_runner(cfg)\n", - "\n", - "assert trained_sae is not None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Run caching of activations to disk" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", - "\n", - "from sae_lens.training.config import CacheActivationsRunnerConfig\n", - "from sae_lens.training.cache_activations_runner import cache_activations_runner\n", - "\n", - "cfg = CacheActivationsRunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gpt2-small\",\n", - " hook_point=\"blocks.10.attn.hook_q\",\n", - " hook_point_layer=10,\n", - " hook_point_head_index=7,\n", - " d_in=64,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " cached_activations_path=\"../activations/\",\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=16,\n", - " total_training_tokens=500_000_000,\n", - " store_batch_size=32,\n", - " # Activation caching shuffle parameters\n", - " n_shuffles_final=16,\n", - " # Misc\n", - " device=\"mps\",\n", - " seed=42,\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "cache_activations_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train an SAE using the cached activations stored on disk\n", - "Pass `use_cached_activations=True` into the config" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", "\n", "cfg = LanguageModelSAERunnerConfig(\n", " # Data Generating Function (Model + Training Distibuion)\n", " model_name=\"gpt2-small\",\n", - " hook_point=\"blocks.10.hook_resid_pre\",\n", - " hook_point_layer=11,\n", + " hook_point=\"blocks.8.hook_resid_pre\",\n", + " hook_point_layer=8,\n", " d_in=768,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " use_cached_activations=True,\n", + " dataset_path=\"apollo-research/Skylion007-openwebtext-tokenizer-gpt2\",\n", + " is_dataset_tokenized=True,\n", + " prepend_bos=True, # should experiment with turning this off.\n", " # SAE Parameters\n", - " expansion_factor=64, # determines the dimension of the SAE.\n", + " expansion_factor=32, # determines the dimension of the SAE.\n", + " b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n", + " apply_b_dec_to_input=False,\n", " # Training Parameters\n", - " lr=1e-5,\n", - " l1_coefficient=5e-4,\n", - " lr_scheduler_name=None,\n", + " adam_beta1=0,\n", + " adam_beta2=0.999,\n", + " lr=0.0004,\n", + " l1_coefficient=0.008,\n", + " lr_scheduler_name=\"constant\",\n", " train_batch_size=4096,\n", - " context_size=128,\n", + " context_size=256,\n", + " lr_warm_up_steps=5000,\n", " # Activation Store Parameters\n", - " n_batches_in_buffer=64,\n", - " total_training_tokens=200_000,\n", + " n_batches_in_buffer=128,\n", + " training_tokens=1_000_000 * 200, # 200M tokens seems doable overnight.\n", + " finetuning_method=\"decoder\",\n", + " finetuning_tokens=1_000_000 * 100,\n", " store_batch_size=32,\n", + " \n", " # Resampling protocol\n", - " feature_sampling_method=\"l2\",\n", - " feature_sampling_window=1000,\n", - " feature_reinit_scale=0.2,\n", + " use_ghost_grads=False,\n", + " feature_sampling_window=2500,\n", " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-7,\n", + " dead_feature_threshold=1e-8,\n", + " \n", " # WANDB\n", " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_gpt2_small\",\n", + " wandb_project=\"gpt2_small_experiments_april\",\n", " wandb_entity=None,\n", - " wandb_log_frequency=50,\n", + " wandb_log_frequency=100,\n", " # Misc\n", - " device=\"mps\",\n", + " device=device,\n", " seed=42,\n", " n_checkpoints=5,\n", " checkpoint_path=\"checkpoints\",\n", @@ -890,13 +672,6 @@ "\n", "sparse_autoencoder = language_model_sae_runner(cfg)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -915,7 +690,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py index e0030ce5..e1c878c0 100644 --- a/tests/benchmark/test_language_model_sae_runner.py +++ b/tests/benchmark/test_language_model_sae_runner.py @@ -30,7 +30,7 @@ def test_language_model_sae_runner(): context_size=128, # Activation Store Parameters n_batches_in_buffer=24, - total_training_tokens=1_000_000 * 10, + training_tokens=1_000_000 * 10, store_batch_size=32, # Resampling protocol use_ghost_grads=True, diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index 4d18b6ab..0f172ab5 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -32,7 +32,7 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: feature_sampling_window=50, dead_feature_threshold=1e-7, n_batches_in_buffer=2, - total_training_tokens=1_000_000, + training_tokens=1_000_000, store_batch_size=4, log_to_wandb=False, wandb_project="test_project", diff --git a/tests/unit/training/test_train_sae_on_language_model.py b/tests/unit/training/test_train_sae_on_language_model.py index eae5b35d..98898806 100644 --- a/tests/unit/training/test_train_sae_on_language_model.py +++ b/tests/unit/training/test_train_sae_on_language_model.py @@ -26,6 +26,7 @@ from tests.unit.helpers import build_sae_cfg +# TODO: Address why we have this code here rather than importing it. def build_train_ctx( sae: SparseAutoencoder, act_freq_scores: Tensor | None = None, @@ -310,11 +311,11 @@ def test_train_sae_group_on_language_model__runs( cfg = build_sae_cfg( checkpoint_path=checkpoint_dir, train_batch_size=32, - total_training_tokens=100, + training_tokens=100, context_size=8, ) # just a tiny datast which will run quickly - dataset = Dataset.from_list([{"text": "hello world"}] * 1000) + dataset = Dataset.from_list([{"text": "hello world"}] * 2000) activation_store = ActivationsStore.from_config(ts_model, cfg, dataset=dataset) sae_group = SparseAutoencoderDictionary(cfg) res = train_sae_group_on_language_model( diff --git a/tutorials/training_a_sparse_autoencoder.ipynb b/tutorials/training_a_sparse_autoencoder.ipynb index 73f4ccdd..17eed89c 100644 --- a/tutorials/training_a_sparse_autoencoder.ipynb +++ b/tutorials/training_a_sparse_autoencoder.ipynb @@ -335,7 +335,7 @@ " context_size=512, # will control the lenght of the prompts we feed to the model. Larger is better but slower.\n", " # Activation Store Parameters\n", " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", - " total_training_tokens=1_000_000\n", + " training_tokens=1_000_000\n", " * 50, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", " store_batch_size=16,\n", " # Resampling protocol\n",