diff --git a/tests/unit/training/test_session_loader.py b/tests/unit/training/test_session_loader.py index cb3c2102..9ae1e48d 100644 --- a/tests/unit/training/test_session_loader.py +++ b/tests/unit/training/test_session_loader.py @@ -9,6 +9,7 @@ from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.config import LanguageModelSAERunnerConfig +from sae_lens.training.sae_group import SparseAutoencoderDictionary # from sae_lens.training.sae_group import SAETrainingGroup from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader @@ -152,7 +153,7 @@ def test_load_pretrained_sae_from_huggingface(): path=folder_path ) assert isinstance(model, HookedTransformer) - assert isinstance(sae, SparseAutoencoder) + assert isinstance(sae, SparseAutoencoderDictionary) assert isinstance(activation_store, ActivationsStore) assert sae.cfg.hook_point_layer == layer assert sae.cfg.model_name == "gpt2-small"