diff --git a/sae_training/evals.py b/sae_training/evals.py
index 276608e1..c06017fa 100644
--- a/sae_training/evals.py
+++ b/sae_training/evals.py
@@ -26,10 +26,15 @@ def run_evals(
     eval_tokens = activation_store.get_batch_tokens()
 
     # Get Reconstruction Score
-    recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(
-        sparse_autoencoder, model, activation_store
+    losses_df = recons_loss_batched(
+        sparse_autoencoder, model, activation_store, n_batches = 10,
     )
+    recons_score = losses_df["score"].mean()
+    ntp_loss = losses_df["loss"].mean()
+    recons_loss = losses_df["recons_loss"].mean()
+    zero_abl_loss = losses_df["zero_abl_loss"].mean()
+
 
     # get cache
     _, cache = model.run_with_cache(
         eval_tokens,
@@ -144,8 +149,9 @@ def head_replacement_hook(activations, hook):
 def recons_loss_batched(sparse_autoencoder, model, activation_store, n_batches=100):
     losses = []
     for _ in tqdm(range(n_batches)):
+        batch_tokens = activation_store.get_batch_tokens()
         score, loss, recons_loss, zero_abl_loss = get_recons_loss(
-            sparse_autoencoder, model, activation_store
+            sparse_autoencoder, model, batch_tokens
         )
         losses.append(
             (
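
The second hunk is truncated just inside `recons_loss_batched`, so below is a minimal sketch (an assumption, not the repository's actual implementation) of how the per-batch tuples could be collected into the `losses_df` that `run_evals` consumes. Returning a pandas `DataFrame` with `"score"`, `"loss"`, `"recons_loss"`, and `"zero_abl_loss"` columns is inferred from the `losses_df["score"].mean()` style calls in the first hunk.

```python
# Hedged sketch only: assumes get_recons_loss (defined elsewhere in
# sae_training/evals.py) returns (score, loss, recons_loss, zero_abl_loss)
# per batch, and that losses_df is a pandas DataFrame, as implied by the
# .mean() column reductions in run_evals above.
import pandas as pd
from tqdm import tqdm


def recons_loss_batched(sparse_autoencoder, model, activation_store, n_batches=100):
    losses = []
    for _ in tqdm(range(n_batches)):
        # Draw a fresh token batch each iteration, as in the diff above.
        batch_tokens = activation_store.get_batch_tokens()
        score, loss, recons_loss, zero_abl_loss = get_recons_loss(
            sparse_autoencoder, model, batch_tokens
        )
        losses.append((score, loss, recons_loss, zero_abl_loss))

    # Column names match the keys run_evals averages.
    return pd.DataFrame(
        losses, columns=["score", "loss", "recons_loss", "zero_abl_loss"]
    )
```

With this shape, the change in `run_evals` averages the reconstruction metrics over 10 sampled batches instead of reporting a single-batch estimate.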