From 758a50b073777028cd0dabcc50049798c2fcd68f Mon Sep 17 00:00:00 2001
From: tomMcGrath
Date: Wed, 8 May 2024 08:35:22 -0700
Subject: [PATCH] feat: Change eval batch size (#128)

* Surface # of eval batches and # of eval sequences

* fix formatting

* fix print statement accidentally left in
---
 sae_lens/training/activations_store.py           |  6 +++---
 sae_lens/training/config.py                      |  4 ++++
 sae_lens/training/evals.py                       | 10 +++++++---
 sae_lens/training/lm_runner.py                   |  2 ++
 sae_lens/training/train_sae_on_language_model.py |  8 ++++++++
 5 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py
index 25bcb7f8..ed95443a 100644
--- a/sae_lens/training/activations_store.py
+++ b/sae_lens/training/activations_store.py
@@ -191,12 +191,12 @@ def dataloader(self) -> Iterator[Any]:
             self._dataloader = self.get_data_loader()
         return self._dataloader
 
-    def get_batch_tokens(self):
+    def get_batch_tokens(self, batch_size: int | None = None):
         """
         Streams a batch of tokens from a dataset.
         """
-
-        batch_size = self.store_batch_size
+        if not batch_size:
+            batch_size = self.store_batch_size
         context_size = self.context_size
         device = self.device
diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py
index 35e72f0d..68f268e1 100644
--- a/sae_lens/training/config.py
+++ b/sae_lens/training/config.py
@@ -106,6 +106,10 @@ class LanguageModelSAERunnerConfig:
 
     dead_feature_threshold: float = 1e-8
 
+    # Evals
+    n_eval_batches: int = 10
+    n_eval_seqs: int | None = None  # useful if evals cause OOM
+
     # WANDB
     log_to_wandb: bool = True
     log_activations_store_to_wandb: bool = False
diff --git a/sae_lens/training/evals.py b/sae_lens/training/evals.py
index 0cd8f6f8..fb985df4 100644
--- a/sae_lens/training/evals.py
+++ b/sae_lens/training/evals.py
@@ -17,6 +17,8 @@ def run_evals(
     model: HookedRootModule,
     n_training_steps: int,
     suffix: str = "",
+    n_eval_batches: int = 10,
+    n_eval_seqs: int | None = None,
 ) -> Mapping[str, Any]:
     hook_point = sparse_autoencoder.cfg.hook_point
     hook_point_layer = sparse_autoencoder.hook_point_layer
@@ -25,14 +27,15 @@
         layer=hook_point_layer
     )
     ### Evals
-    eval_tokens = activation_store.get_batch_tokens()
+    eval_tokens = activation_store.get_batch_tokens(n_eval_seqs)
 
     # Get Reconstruction Score
     losses_df = recons_loss_batched(
         sparse_autoencoder,
         model,
         activation_store,
-        n_batches=10,
+        n_batches=n_eval_batches,
+        n_eval_seqs=n_eval_seqs,
     )
     recons_score = losses_df["score"].mean()
@@ -100,10 +103,11 @@ def recons_loss_batched(
     model: HookedRootModule,
     activation_store: ActivationsStore,
     n_batches: int = 100,
+    n_eval_seqs: int | None = None,
 ):
     losses = []
     for _ in range(n_batches):
-        batch_tokens = activation_store.get_batch_tokens()
+        batch_tokens = activation_store.get_batch_tokens(n_eval_seqs)
         score, loss, recons_loss, zero_abl_loss = get_recons_loss(
             sparse_autoencoder, model, batch_tokens
         )
diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py
index 5a7c0bd9..ed4d0dc5 100644
--- a/sae_lens/training/lm_runner.py
+++ b/sae_lens/training/lm_runner.py
@@ -85,6 +85,8 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
         wandb_log_frequency=cfg.wandb_log_frequency,
         eval_every_n_wandb_logs=cfg.eval_every_n_wandb_logs,
         autocast=cfg.autocast,
+        n_eval_batches=cfg.n_eval_batches,
+        n_eval_seqs=cfg.n_eval_seqs,
     ).sae_group
 
     if cfg.log_to_wandb:
diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py
index 639281a5..dd2a5382 100644
--- a/sae_lens/training/train_sae_on_language_model.py
+++ b/sae_lens/training/train_sae_on_language_model.py
@@ -188,6 +188,8 @@ def train_sae_on_language_model(
     wandb_log_frequency: int = 50,
     eval_every_n_wandb_logs: int = 100,
     autocast: bool = False,
+    n_eval_batches: int = 10,
+    n_eval_seqs: int | None = None,
 ) -> SparseAutoencoderDictionary:
     """
     @deprecated Use `train_sae_group_on_language_model` instead. This method is kept for backward compatibility.
@@ -203,6 +205,8 @@
         wandb_log_frequency=wandb_log_frequency,
         eval_every_n_wandb_logs=eval_every_n_wandb_logs,
         autocast=autocast,
+        n_eval_batches=n_eval_batches,
+        n_eval_seqs=n_eval_seqs,
     ).sae_group
 
 
@@ -223,6 +227,8 @@ def train_sae_group_on_language_model(
     wandb_log_frequency: int = 50,
     eval_every_n_wandb_logs: int = 100,
     autocast: bool = False,
+    n_eval_batches: int = 10,
+    n_eval_seqs: int | None = None,
 ) -> TrainSAEGroupOutput:
     total_training_tokens = get_total_training_tokens(sae_group=sae_group)
     _update_sae_lens_training_version(sae_group)
@@ -325,6 +331,8 @@ def interrupt_callback(sig_num: Any, stack_frame: Any):
                    model,
                    training_run_state.n_training_steps,
                    suffix=wandb_suffix,
+                    n_eval_batches=n_eval_batches,
+                    n_eval_seqs=n_eval_seqs,
                 )
                sparse_autoencoder.train()
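
Usage note: the two new fields are read from LanguageModelSAERunnerConfig and threaded through language_model_sae_runner into run_evals. Below is a minimal sketch of how they might be set; the values are illustrative, every other config field is left at its default, and the assignment of the runner's return value is an assumption not shown in this patch.

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(
    # ... model / dataset / SAE hyperparameters as usual ...
    n_eval_batches=5,  # fewer reconstruction-loss batches per eval than the default of 10
    n_eval_seqs=8,     # cap each eval batch at 8 sequences; None falls back to store_batch_size
)

sae_group = language_model_sae_runner(cfg)  # assumed return value; not shown in this patch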