From 2a676b210832e789dbb80f33b2d8f747a7209e0f Mon Sep 17 00:00:00 2001
From: jbloom-md
Date: Tue, 21 May 2024 09:35:51 +0100
Subject: [PATCH] gemma 2b sae resid post 12. fix ghost grad print

---
 sae_lens/pretrained_saes.yaml | 2 ++
 sae_lens/training/config.py   | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/sae_lens/pretrained_saes.yaml b/sae_lens/pretrained_saes.yaml
index ef3404b0..fe954981 100644
--- a/sae_lens/pretrained_saes.yaml
+++ b/sae_lens/pretrained_saes.yaml
@@ -68,3 +68,5 @@ SAE_LOOKUP:
     path: "gemma_2b_blocks.0.hook_resid_post_16384_anthropic"
   - id: "blocks.6.hook_resid_post"
     path: "gemma_2b_blocks.6.hook_resid_post_16384_anthropic_fast_lr"
+  - id: "blocks.12.hook_resid_post"
+    path: "gemma_2b_blocks.12.hook_resid_post_16384"
diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py
index c8ccb54b..1160c3ee 100644
--- a/sae_lens/training/config.py
+++ b/sae_lens/training/config.py
@@ -243,7 +243,7 @@ def __post_init__(self):
             f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
         )
 
-        if not self.use_ghost_grads:
+        if self.use_ghost_grads:
             print("Using Ghost Grads.")
 
     def get_checkpoints_by_step(self) -> tuple[dict[int, str], bool]:
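
For reference, a minimal sketch of loading the SAE this patch registers. It assumes a SAE Lens 3.x-style SAE.from_pretrained API, and it assumes the release key enclosing this lookup entry in pretrained_saes.yaml is "gemma-2b-res-jb" (the release name sits above the context shown in the hunk):

from sae_lens import SAE

# Minimal sketch. Assumptions: sae_lens 3.x, where SAE.from_pretrained
# returns (sae, cfg_dict, sparsity), and the release key wrapping the
# new lookup entry is "gemma-2b-res-jb".
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-2b-res-jb",
    sae_id="blocks.12.hook_resid_post",  # the id added by this patch
)
print(sae.cfg.hook_name)  # -> "blocks.12.hook_resid_post"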