diff --git a/sae_lens/training/sparse_autoencoder.py b/sae_lens/training/sparse_autoencoder.py index ed8648aa..36fb61dc 100644 --- a/sae_lens/training/sparse_autoencoder.py +++ b/sae_lens/training/sparse_autoencoder.py @@ -377,4 +377,6 @@ def _per_item_mse_loss_with_target_norm( """ target_centered = target - target.mean(dim=0, keepdim=True) normalization = target_centered.norm(dim=-1, keepdim=True) - return torch.nn.functional.mse_loss(preds, target, reduction="none") / normalization + return torch.nn.functional.mse_loss(preds, target, reduction="none") / ( + normalization + 1e-6 + )