From ba41f327364901c40c6613a300e1ceaafe67e8fa Mon Sep 17 00:00:00 2001 From: David Chanin Date: Fri, 5 Apr 2024 18:07:45 +0100 Subject: [PATCH] setting up sae_lens package and auto-deploy with semantic-release --- .flake8 | 2 +- .github/workflows/{tests.yml => build.yml} | 36 +- README.md | 16 +- docs/about/citation.md | 2 +- docs/index.md | 18 +- docs/installation.md | 4 +- docs/reference/language_models.md | 8 +- docs/reference/misc.md | 6 +- docs/reference/runners.md | 2 +- docs/reference/toy_models.md | 6 +- docs/training_saes.md | 3 +- makefile | 4 +- mkdocs.yml | 5 +- pyproject.toml | 17 +- sae_lens/__init__.py | 24 + .../analysis}/__init__.py | 0 .../analysis}/dashboard_runner.py | 2 +- .../analysis}/feature_statistics.py | 2 +- .../analysis}/neuronpedia_runner.py | 2 +- .../analysis}/toolkit.py | 2 +- {sae_analysis => sae_lens/analysis}/tsea.py | 0 .../training}/__init__.py | 0 .../training}/activations_store.py | 0 .../training}/cache_activations_runner.py | 6 +- {sae_training => sae_lens/training}/config.py | 0 {sae_training => sae_lens/training}/evals.py | 4 +- .../training}/geometric_median.py | 0 .../training}/lm_runner.py | 8 +- {sae_training => sae_lens/training}/optim.py | 0 .../training}/sae_group.py | 28 +- .../training/session_loader.py | 36 +- .../training}/sparse_autoencoder.py | 2 +- .../training}/toy_model_runner.py | 8 +- .../training}/toy_models.py | 0 .../training}/train_sae_on_language_model.py | 12 +- .../training}/train_sae_on_toy_model.py | 2 +- sae_lens/training/utils.py | 43 + scripts/run.ipynb | 1722 ++++++++--------- .../test_language_model_sae_runner.py | 4 +- tests/benchmark/test_toy_model_sae_runner.py | 5 +- tests/unit/helpers.py | 2 +- .../{ => training}/test_activations_store.py | 2 +- tests/unit/{ => training}/test_optim.py | 2 +- .../test_session_loader.py} | 28 +- .../{ => training}/test_sparse_autoencoder.py | 4 +- .../test_train_sae_on_language_model.py | 10 +- tutorials/evaluating_your_sae.ipynb | 734 +++---- 
tutorials/generating_sae_dashboards.ipynb | 266 +-- tutorials/logits_lens_with_features.ipynb | 1608 +++++++-------- .../generating_neuronpedia_outputs.ipynb | 2 +- tutorials/neuronpedia/np_runner_batch.py | 2 +- 51 files changed, 2407 insertions(+), 2294 deletions(-) rename .github/workflows/{tests.yml => build.yml} (55%) create mode 100644 sae_lens/__init__.py rename {sae_analysis => sae_lens/analysis}/__init__.py (100%) rename {sae_analysis => sae_lens/analysis}/dashboard_runner.py (99%) rename {sae_analysis => sae_lens/analysis}/feature_statistics.py (98%) rename {sae_analysis => sae_lens/analysis}/neuronpedia_runner.py (99%) rename {sae_analysis => sae_lens/analysis}/toolkit.py (96%) rename {sae_analysis => sae_lens/analysis}/tsea.py (100%) rename {sae_training => sae_lens/training}/__init__.py (100%) rename {sae_training => sae_lens/training}/activations_store.py (100%) rename {sae_training => sae_lens/training}/cache_activations_runner.py (92%) rename {sae_training => sae_lens/training}/config.py (100%) rename {sae_training => sae_lens/training}/evals.py (97%) rename {sae_training => sae_lens/training}/geometric_median.py (100%) rename {sae_training => sae_lens/training}/lm_runner.py (80%) rename {sae_training => sae_lens/training}/optim.py (100%) rename {sae_training => sae_lens/training}/sae_group.py (82%) rename sae_training/utils.py => sae_lens/training/session_loader.py (69%) rename {sae_training => sae_lens/training}/sparse_autoencoder.py (99%) rename {sae_training => sae_lens/training}/toy_model_runner.py (92%) rename {sae_training => sae_lens/training}/toy_models.py (100%) rename {sae_training => sae_lens/training}/train_sae_on_language_model.py (97%) rename {sae_training => sae_lens/training}/train_sae_on_toy_model.py (98%) create mode 100644 sae_lens/training/utils.py rename tests/unit/{ => training}/test_activations_store.py (99%) rename tests/unit/{ => training}/test_optim.py (99%) rename tests/unit/{test_utils.py => 
training/test_session_loader.py} (78%) rename tests/unit/{ => training}/test_sparse_autoencoder.py (98%) rename tests/unit/{ => training}/test_train_sae_on_language_model.py (97%) diff --git a/.flake8 b/.flake8 index 6c0d64e4..6ed7fa91 100644 --- a/.flake8 +++ b/.flake8 @@ -5,4 +5,4 @@ max-complexity = 25 extend-select = E9, F63, F7, F82 show-source = true statistics = true -exclude = ./sae_training/geom_median/, ./wandb/*, ./research/wandb/* +exclude = ./wandb/*, ./research/wandb/* diff --git a/.github/workflows/tests.yml b/.github/workflows/build.yml similarity index 55% rename from .github/workflows/tests.yml rename to .github/workflows/build.yml index 6f34f6ad..598686d8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/build.yml @@ -54,9 +54,43 @@ jobs: run: poetry run pyright - name: Run Unit Tests # Would use make, but want cov report in xml format - run: poetry run pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/unit --cov-report=xml + run: poetry run pytest -v --cov=sae_lens/training/ --cov-report=term-missing --cov-branch tests/unit --cov-report=xml - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 with: token: ${{ secrets.CODECOV_TOKEN }} slug: jbloomAus/mats_sae_training + + release: + + needs: build + permissions: + contents: write + id-token: write + # https://github.community/t/how-do-i-specify-job-dependency-running-in-another-workflow/16482 + if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):') + runs-on: ubuntu-latest + concurrency: release + environment: + name: pypi + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Semantic Release + id: release + uses: python-semantic-release/python-semantic-release@v8.0.7 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + - name: Publish package distributions to PyPI + 
uses: pypa/gh-action-pypi-publish@release/v1 + if: steps.release.outputs.released == 'true' + - name: Publish package distributions to GitHub Releases + uses: python-semantic-release/upload-to-gh-release@main + if: steps.release.outputs.released == 'true' + with: + github_token: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/README.md b/README.md index eba80ae1..bd244faa 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ Screenshot 2024-03-21 at 3 08 28 pm -# MATS SAE Training +# SAELens Training [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![build](/~https://github.com/jbloomAus/mats_sae_training/actions/workflows/tests.yml/badge.svg)](/~https://github.com/jbloomAus/mats_sae_training/actions/workflows/tests.yml) [![Deploy Docs](/~https://github.com/jbloomAus/mats_sae_training/actions/workflows/deploy_docs.yml/badge.svg)](/~https://github.com/jbloomAus/mats_sae_training/actions/workflows/deploy_docs.yml) [![codecov](https://codecov.io/gh/jbloomAus/mats_sae_training/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/mats_sae_training) -The MATS SAE training codebase (we'll rename it soon) exists to help researchers: +The SAELens training codebase exists to help researchers: - Train sparse autoencoders. - Analyse sparse autoencoders and neural network internals. - Generate insights which make it easier to create safe and aligned AI systems. @@ -27,7 +27,7 @@ poetry install ```python import torch -from sae_training.utils import LMSparseAutoencoderSessionloader +from sae_lens import LMSparseAutoencoderSessionloader from huggingface_hub import hf_hub_download layer = 8 # pick a layer you want. @@ -88,8 +88,8 @@ Making the code accessible: This involves tasks like turning the code base into The codebase contains 2 folders worth caring about: -- sae_training: The main body of the code is here. Everything required for training SAEs. 
-- sae_analysis: This code is mainly house the feature visualizer code we use to generate dashboards. It was written by Callum McDougal but I've ported it here with permission and edited it to work with a few different activation types. +- training: The main body of the code is here. Everything required for training SAEs. +- analysis: This code mainly houses the feature visualizer code we use to generate dashboards. It was written by Callum McDougal but I've ported it here with permission and edited it to work with a few different activation types. Some other folders: @@ -123,8 +123,8 @@ import sys os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["WANDB__SERVICE_WAIT"] = "300" -from sae_training.config import LanguageModelSAERunnerConfig -from sae_training.lm_runner import language_model_sae_runner +from sae_lens.training.config import LanguageModelSAERunnerConfig +from sae_lens.training.lm_runner import language_model_sae_runner cfg = LanguageModelSAERunnerConfig( @@ -186,7 +186,7 @@ Once your SAE is trained, the final SAE weights will be saved to wandb and are l - An activations loader: from which you can get randomly sampled activations or batches of tokens from the dataset you used to train the SAE. 
(more on this in the tutorial) ```python -from sae_training.utils import LMSparseAutoencoderSessionloader +from sae_lens import LMSparseAutoencoderSessionloader path ="path/to/sparse_autoencoder.pt" model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained( diff --git a/docs/about/citation.md b/docs/about/citation.md index f9d81c27..8eb8035c 100644 --- a/docs/about/citation.md +++ b/docs/about/citation.md @@ -2,7 +2,7 @@ ``` @misc{bloom2024saetrainingcodebase, - title = {MATS SAE Training + title = {SAELens Training}, author = {Joseph Bloom}, year = {2024}, howpublished = {\url{}}, diff --git a/docs/index.md b/docs/index.md index 286a22b0..e6bcb2a8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,12 +1,12 @@ Screenshot 2024-03-21 at 3 08 28 pm -# MATS SAE Training +# SAELens [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![build](/~https://github.com/jbloomAus/mats_sae_training/actions/workflows/tests.yml/badge.svg)](/~https://github.com/jbloomAus/mats_sae_training/actions/workflows/tests.yml) [![Deploy Docs](/~https://github.com/jbloomAus/mats_sae_training/actions/workflows/deploy_docs.yml/badge.svg)](/~https://github.com/jbloomAus/mats_sae_training/actions/workflows/deploy_docs.yml) [![codecov](https://codecov.io/gh/jbloomAus/mats_sae_training/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/mats_sae_training) -The MATS SAE training codebase (we'll rename it soon) exists to help researchers: +The SAELens training codebase exists to help researchers: - Train sparse autoencoders. - Analyse sparse autoencoders and neural network internals. @@ -16,12 +16,10 @@ The MATS SAE training codebase (we'll rename it soon) exists to help researchers ## Quick Start -### Set Up - -This project uses [Poetry](https://python-poetry.org/) for dependency management. 
Ensure Poetry is installed, then to install the dependencies, run: +### Installation ``` -poetry install +pip install sae-lens ``` ### Loading Sparse Autoencoders from Huggingface @@ -30,7 +28,7 @@ poetry install ```python import torch -from sae_training.utils import LMSparseAutoencoderSessionloader +from sae_lens import LMSparseAutoencoderSessionloader from huggingface_hub import hf_hub_download layer = 8 # pick a layer you want. @@ -61,8 +59,8 @@ We highly recommend this [tutorial](https://www.lesswrong.com/posts/LnHowHgmrMbW The codebase contains 2 folders worth caring about: -- sae_training: The main body of the code is here. Everything required for training SAEs. -- sae_analysis: This code is mainly house the feature visualizer code we use to generate dashboards. It was written by Callum McDougal but I've ported it here with permission and edited it to work with a few different activation types. +- training: The main body of the code is here. Everything required for training SAEs. +- analysis: This code mainly houses the feature visualizer code we use to generate dashboards. It was written by Callum McDougal but I've ported it here with permission and edited it to work with a few different activation types. Some other folders: @@ -78,7 +76,7 @@ Once your SAE is trained, the final SAE weights will be saved to wandb and are l - An activations loader: from which you can get randomly sampled activations or batches of tokens from the dataset you used to train the SAE. 
(more on this in the tutorial) ```python -from sae_training.utils import LMSparseAutoencoderSessionloader +from sae_lens import LMSparseAutoencoderSessionloader path ="path/to/sparse_autoencoder.pt" model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained( diff --git a/docs/installation.md b/docs/installation.md index 72d98ac1..878ce88a 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,7 +1,7 @@ # Installation -This project uses [Poetry](https://python-poetry.org/) for dependency management. Ensure Poetry is installed, then to install the dependencies, run: +This package is available on PyPI. You can install it via pip: ``` -poetry install +pip install sae-lens ``` diff --git a/docs/reference/language_models.md b/docs/reference/language_models.md index 0d4c1583..4be4dcc6 100644 --- a/docs/reference/language_models.md +++ b/docs/reference/language_models.md @@ -1,9 +1,9 @@ # Language Models -::: sae_training.lm_runner +::: sae_lens.training.lm_runner -::: sae_training.train_sae_on_language_model +::: sae_lens.training.train_sae_on_language_model -::: sae_training.sparse_autoencoder +::: sae_lens.training.sparse_autoencoder -::: sae_training.activations_store +::: sae_lens.training.activations_store diff --git a/docs/reference/misc.md b/docs/reference/misc.md index 215aa6dc..2f88fb62 100644 --- a/docs/reference/misc.md +++ b/docs/reference/misc.md @@ -1,7 +1,7 @@ # Misc -::: sae_training.config +::: sae_lens.training.config -::: sae_training.utils +::: sae_lens.training.session_loader -::: sae_training.optim +::: sae_lens.training.optim diff --git a/docs/reference/runners.md b/docs/reference/runners.md index 3be77f12..2116d1cf 100644 --- a/docs/reference/runners.md +++ b/docs/reference/runners.md @@ -1,3 +1,3 @@ # Runners -::: sae_training.lm_runner +::: sae_lens.training.lm_runner diff --git a/docs/reference/toy_models.md b/docs/reference/toy_models.md index 5c0dc419..df4770c4 100644 --- 
a/docs/reference/toy_models.md +++ b/docs/reference/toy_models.md @@ -1,6 +1,6 @@ -::: sae_training.train_sae_on_toy_model +::: sae_lens.training.train_sae_on_toy_model -::: sae_training.toy_model_runner +::: sae_lens.training.toy_model_runner -::: sae_training.toy_models +::: sae_lens.training.toy_models diff --git a/docs/training_saes.md b/docs/training_saes.md index f0e5962f..632eefbd 100644 --- a/docs/training_saes.md +++ b/docs/training_saes.md @@ -23,8 +23,7 @@ import sys os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["WANDB__SERVICE_WAIT"] = "300" -from sae_training.config import LanguageModelSAERunnerConfig -from sae_training.lm_runner import language_model_sae_runner +from sae_lens import LanguageModelSAERunnerConfig, language_model_sae_runner cfg = LanguageModelSAERunnerConfig( diff --git a/makefile b/makefile index cc45a8f3..5689dd01 100644 --- a/makefile +++ b/makefile @@ -15,10 +15,10 @@ test: make acceptance-test unit-test: - poetry run pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/unit + poetry run pytest -v --cov=sae_lens/training/ --cov-report=term-missing --cov-branch tests/unit acceptance-test: - poetry run pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/acceptance + poetry run pytest -v --cov=sae_lens/training/ --cov-report=term-missing --cov-branch tests/acceptance check-ci: make check-format diff --git a/mkdocs.yml b/mkdocs.yml index dda52371..a7e7d0f4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,4 +1,4 @@ -site_name: MATS SAE Training +site_name: SAELens Training site_description: Docs for Sparse Autoencoder Training Library site_author: Joseph Bloom repo_url: http://github.com/jbloomAus/mats_sae_training/ @@ -63,8 +63,7 @@ plugins: - mkdocstrings: custom_templates: null watch: - - sae_training/ # Replace with the path to your Python code - - sae_analysis/ # Replace with the path to your Python code + - sae_lens/ # Replace with the path to your Python code 
markdown_extensions: diff --git a/pyproject.toml b/pyproject.toml index 0d65e28f..ddba75ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ [tool.poetry] -name = "mats_sae_training" +name = "sae-lens" version = "0.1.0" -description = "Training Sparse Autoencoders (SAEs)" +description = "Training and Analyzing Sparse Autoencoders (SAEs)" authors = ["Joseph Bloom"] readme = "README.md" -packages = [{include = "sae_analysis"}, {include = "sae_training"}] +packages = [{include = "sae_lens"}] [tool.poetry.dependencies] python = "^3.10" @@ -42,8 +42,6 @@ pyright = "^1.1.351" profile = "black" [tool.pyright] -exclude = ["./sae_training/geom_median/"] - typeCheckingMode = "strict" reportMissingTypeStubs = "none" reportUnknownMemberType = "none" @@ -59,3 +57,12 @@ reportPrivateUsage = "none" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + + +[tool.semantic_release] +version_variables = [ + "sae_lens/__init__.py:__version__", + "pyproject.toml:version", +] +branch = "main" +build_command = "pip install poetry && poetry build" \ No newline at end of file diff --git a/sae_lens/__init__.py b/sae_lens/__init__.py new file mode 100644 index 00000000..060c8cdf --- /dev/null +++ b/sae_lens/__init__.py @@ -0,0 +1,24 @@ +__version__ = "0.1.0" + +from .training.activations_store import ActivationsStore +from .training.cache_activations_runner import cache_activations_runner +from .training.config import CacheActivationsRunnerConfig, LanguageModelSAERunnerConfig +from .training.evals import run_evals +from .training.lm_runner import language_model_sae_runner +from .training.sae_group import SAEGroup +from .training.session_loader import LMSparseAutoencoderSessionloader +from .training.sparse_autoencoder import SparseAutoencoder +from .training.train_sae_on_language_model import train_sae_group_on_language_model + +__all__ = [ + "LanguageModelSAERunnerConfig", + "CacheActivationsRunnerConfig", + "LMSparseAutoencoderSessionloader", 
+ "SparseAutoencoder", + "SAEGroup", + "run_evals", + "language_model_sae_runner", + "cache_activations_runner", + "ActivationsStore", + "train_sae_group_on_language_model", +] diff --git a/sae_analysis/__init__.py b/sae_lens/analysis/__init__.py similarity index 100% rename from sae_analysis/__init__.py rename to sae_lens/analysis/__init__.py diff --git a/sae_analysis/dashboard_runner.py b/sae_lens/analysis/dashboard_runner.py similarity index 99% rename from sae_analysis/dashboard_runner.py rename to sae_lens/analysis/dashboard_runner.py index 7fd94d62..76bd995d 100644 --- a/sae_analysis/dashboard_runner.py +++ b/sae_lens/analysis/dashboard_runner.py @@ -17,7 +17,7 @@ from tqdm import tqdm import wandb -from sae_training.utils import LMSparseAutoencoderSessionloader +from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader class DashboardRunner: diff --git a/sae_analysis/feature_statistics.py b/sae_lens/analysis/feature_statistics.py similarity index 98% rename from sae_analysis/feature_statistics.py rename to sae_lens/analysis/feature_statistics.py index ecf1b516..59fa8b93 100644 --- a/sae_analysis/feature_statistics.py +++ b/sae_lens/analysis/feature_statistics.py @@ -3,7 +3,7 @@ from tqdm import tqdm from transformer_lens import HookedTransformer -from sae_training.sparse_autoencoder import SparseAutoencoder +from sae_lens.training.sparse_autoencoder import SparseAutoencoder @torch.no_grad() diff --git a/sae_analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py similarity index 99% rename from sae_analysis/neuronpedia_runner.py rename to sae_lens/analysis/neuronpedia_runner.py index 46f441a6..6b3e8d3e 100644 --- a/sae_analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -13,7 +13,7 @@ from sae_vis.data_storing_fns import FeatureVisParams from tqdm import tqdm -from sae_training.utils import LMSparseAutoencoderSessionloader +from sae_lens.training.session_loader import 
LMSparseAutoencoderSessionloader OUT_OF_RANGE_TOKEN = "<|outofrange|>" diff --git a/sae_analysis/toolkit.py b/sae_lens/analysis/toolkit.py similarity index 96% rename from sae_analysis/toolkit.py rename to sae_lens/analysis/toolkit.py index 5a8842a4..4f12025b 100644 --- a/sae_analysis/toolkit.py +++ b/sae_lens/analysis/toolkit.py @@ -3,7 +3,7 @@ import torch from huggingface_hub import hf_hub_download -from sae_training.sparse_autoencoder import SparseAutoencoder +from sae_lens.training.sparse_autoencoder import SparseAutoencoder def get_all_gpt2_small_saes() -> ( diff --git a/sae_analysis/tsea.py b/sae_lens/analysis/tsea.py similarity index 100% rename from sae_analysis/tsea.py rename to sae_lens/analysis/tsea.py diff --git a/sae_training/__init__.py b/sae_lens/training/__init__.py similarity index 100% rename from sae_training/__init__.py rename to sae_lens/training/__init__.py diff --git a/sae_training/activations_store.py b/sae_lens/training/activations_store.py similarity index 100% rename from sae_training/activations_store.py rename to sae_lens/training/activations_store.py diff --git a/sae_training/cache_activations_runner.py b/sae_lens/training/cache_activations_runner.py similarity index 92% rename from sae_training/cache_activations_runner.py rename to sae_lens/training/cache_activations_runner.py index a61dec48..5e192148 100644 --- a/sae_training/cache_activations_runner.py +++ b/sae_lens/training/cache_activations_runner.py @@ -5,9 +5,9 @@ from tqdm import tqdm from transformer_lens import HookedTransformer -from sae_training.activations_store import ActivationsStore -from sae_training.config import CacheActivationsRunnerConfig -from sae_training.utils import shuffle_activations_pairwise +from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.config import CacheActivationsRunnerConfig +from sae_lens.training.utils import shuffle_activations_pairwise def cache_activations_runner(cfg: CacheActivationsRunnerConfig): diff 
--git a/sae_training/config.py b/sae_lens/training/config.py similarity index 100% rename from sae_training/config.py rename to sae_lens/training/config.py diff --git a/sae_training/evals.py b/sae_lens/training/evals.py similarity index 97% rename from sae_training/evals.py rename to sae_lens/training/evals.py index 932b6a9b..b1b346e1 100644 --- a/sae_training/evals.py +++ b/sae_lens/training/evals.py @@ -8,8 +8,8 @@ from transformer_lens.utils import get_act_name import wandb -from sae_training.activations_store import ActivationsStore -from sae_training.sparse_autoencoder import SparseAutoencoder +from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.sparse_autoencoder import SparseAutoencoder @torch.no_grad() diff --git a/sae_training/geometric_median.py b/sae_lens/training/geometric_median.py similarity index 100% rename from sae_training/geometric_median.py rename to sae_lens/training/geometric_median.py diff --git a/sae_training/lm_runner.py b/sae_lens/training/lm_runner.py similarity index 80% rename from sae_training/lm_runner.py rename to sae_lens/training/lm_runner.py index 334f1d45..daeeb02e 100644 --- a/sae_training/lm_runner.py +++ b/sae_lens/training/lm_runner.py @@ -1,11 +1,11 @@ from typing import Any, cast import wandb -from sae_training.config import LanguageModelSAERunnerConfig +from sae_lens.training.config import LanguageModelSAERunnerConfig +from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader -# from sae_training.activation_store import ActivationStore -from sae_training.train_sae_on_language_model import train_sae_on_language_model -from sae_training.utils import LMSparseAutoencoderSessionloader +# from sae_lens.training.activation_store import ActivationStore +from sae_lens.training.train_sae_on_language_model import train_sae_on_language_model def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig): diff --git a/sae_training/optim.py b/sae_lens/training/optim.py 
similarity index 100% rename from sae_training/optim.py rename to sae_lens/training/optim.py diff --git a/sae_training/sae_group.py b/sae_lens/training/sae_group.py similarity index 82% rename from sae_training/sae_group.py rename to sae_lens/training/sae_group.py index 568367f2..14721906 100644 --- a/sae_training/sae_group.py +++ b/sae_lens/training/sae_group.py @@ -3,11 +3,13 @@ import os import pickle from itertools import product +from types import SimpleNamespace from typing import Any, Iterator import torch -from sae_training.sparse_autoencoder import SparseAutoencoder +from sae_lens.training.sparse_autoencoder import SparseAutoencoder +from sae_lens.training.utils import BackwardsCompatibleUnpickler class SAEGroup: @@ -54,7 +56,7 @@ def to(self, device: torch.device | str): ae.to(device) @classmethod - def load_from_pretrained(cls, path: str): + def load_from_pretrained(cls, path: str) -> Any: """ Load function for the model. Loads the model's state_dict and the config used to train it. This method can be called directly on the class, without needing an instance. 
@@ -67,11 +69,25 @@ def load_from_pretrained(cls, path: str): # Load the state dictionary if path.endswith(".pt"): try: - if torch.backends.mps.is_available(): - group = torch.load(path, map_location="mps") - group["cfg"].device = "mps" + # this is hacky, but can't figure out how else to get torch to use our custom unpickler + fake_pickle = SimpleNamespace() + fake_pickle.Unpickler = BackwardsCompatibleUnpickler + fake_pickle.__name__ = pickle.__name__ + + if torch.cuda.is_available(): + group = torch.load( + path, + pickle_module=fake_pickle, + ) else: - group = torch.load(path) + map_loc = "mps" if torch.backends.mps.is_available() else "cpu" + group = torch.load( + path, pickle_module=fake_pickle, map_location=map_loc + ) + if isinstance(group, dict): + group["cfg"].device = map_loc + else: + group.cfg.device = map_loc except Exception as e: raise IOError(f"Error loading the state dictionary from .pt file: {e}") diff --git a/sae_training/utils.py b/sae_lens/training/session_loader.py similarity index 69% rename from sae_training/utils.py rename to sae_lens/training/session_loader.py index f407b92a..0aaf3b6d 100644 --- a/sae_training/utils.py +++ b/sae_lens/training/session_loader.py @@ -1,11 +1,10 @@ from typing import Any, Tuple -import torch from transformer_lens import HookedTransformer -from sae_training.activations_store import ActivationsStore -from sae_training.sae_group import SAEGroup -from sae_training.sparse_autoencoder import SparseAutoencoder +from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.sae_group import SAEGroup +from sae_lens.training.sparse_autoencoder import SparseAutoencoder class LMSparseAutoencoderSessionloader: @@ -98,32 +97,3 @@ def get_activations_loader(self, cfg: Any, model: HookedTransformer): ) return activations_loader - - -def shuffle_activations_pairwise(datapath: str, buffer_idx_range: Tuple[int, int]): - """ - Shuffles two buffers on disk. 
- """ - assert ( - buffer_idx_range[0] < buffer_idx_range[1] - 1 - ), "buffer_idx_range[0] must be smaller than buffer_idx_range[1] by at least 1" - - buffer_idx1 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item() - buffer_idx2 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item() - while buffer_idx1 == buffer_idx2: # Make sure they're not the same - buffer_idx2 = torch.randint( - buffer_idx_range[0], buffer_idx_range[1], (1,) - ).item() - - buffer1 = torch.load(f"{datapath}/{buffer_idx1}.pt") - buffer2 = torch.load(f"{datapath}/{buffer_idx2}.pt") - joint_buffer = torch.cat([buffer1, buffer2]) - - # Shuffle them - joint_buffer = joint_buffer[torch.randperm(joint_buffer.shape[0])] - shuffled_buffer1 = joint_buffer[: buffer1.shape[0]] - shuffled_buffer2 = joint_buffer[buffer1.shape[0] :] - - # Save them back - torch.save(shuffled_buffer1, f"{datapath}/{buffer_idx1}.pt") - torch.save(shuffled_buffer2, f"{datapath}/{buffer_idx2}.pt") diff --git a/sae_training/sparse_autoencoder.py b/sae_lens/training/sparse_autoencoder.py similarity index 99% rename from sae_training/sparse_autoencoder.py rename to sae_lens/training/sparse_autoencoder.py index 3a15d60c..5a805fbd 100644 --- a/sae_training/sparse_autoencoder.py +++ b/sae_lens/training/sparse_autoencoder.py @@ -12,7 +12,7 @@ from torch import nn from transformer_lens.hook_points import HookedRootModule, HookPoint -from sae_training.config import LanguageModelSAERunnerConfig +from sae_lens.training.config import LanguageModelSAERunnerConfig class ForwardOutput(NamedTuple): diff --git a/sae_training/toy_model_runner.py b/sae_lens/training/toy_model_runner.py similarity index 92% rename from sae_training/toy_model_runner.py rename to sae_lens/training/toy_model_runner.py index de13dd13..29b436a6 100644 --- a/sae_training/toy_model_runner.py +++ b/sae_lens/training/toy_model_runner.py @@ -5,10 +5,10 @@ import torch import wandb -from sae_training.sparse_autoencoder import SparseAutoencoder 
-from sae_training.toy_models import Config as ToyConfig -from sae_training.toy_models import Model as ToyModel -from sae_training.train_sae_on_toy_model import train_toy_sae +from sae_lens.training.sparse_autoencoder import SparseAutoencoder +from sae_lens.training.toy_models import Config as ToyConfig +from sae_lens.training.toy_models import Model as ToyModel +from sae_lens.training.train_sae_on_toy_model import train_toy_sae @dataclass diff --git a/sae_training/toy_models.py b/sae_lens/training/toy_models.py similarity index 100% rename from sae_training/toy_models.py rename to sae_lens/training/toy_models.py diff --git a/sae_training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py similarity index 97% rename from sae_training/train_sae_on_language_model.py rename to sae_lens/training/train_sae_on_language_model.py index 00e5113f..b2c8b112 100644 --- a/sae_training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -8,12 +8,12 @@ from transformer_lens import HookedTransformer import wandb -from sae_training.activations_store import ActivationsStore -from sae_training.evals import run_evals -from sae_training.geometric_median import compute_geometric_median -from sae_training.optim import get_scheduler -from sae_training.sae_group import SAEGroup -from sae_training.sparse_autoencoder import SparseAutoencoder +from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.evals import run_evals +from sae_lens.training.geometric_median import compute_geometric_median +from sae_lens.training.optim import get_scheduler +from sae_lens.training.sae_group import SAEGroup +from sae_lens.training.sparse_autoencoder import SparseAutoencoder @dataclass diff --git a/sae_training/train_sae_on_toy_model.py b/sae_lens/training/train_sae_on_toy_model.py similarity index 98% rename from sae_training/train_sae_on_toy_model.py rename to sae_lens/training/train_sae_on_toy_model.py 
index 749a6ae5..e5227da9 100644 --- a/sae_training/train_sae_on_toy_model.py +++ b/sae_lens/training/train_sae_on_toy_model.py @@ -5,7 +5,7 @@ from tqdm import tqdm import wandb -from sae_training.sparse_autoencoder import SparseAutoencoder +from sae_lens.training.sparse_autoencoder import SparseAutoencoder def train_toy_sae( diff --git a/sae_lens/training/utils.py b/sae_lens/training/utils.py new file mode 100644 index 00000000..b9dcf515 --- /dev/null +++ b/sae_lens/training/utils.py @@ -0,0 +1,43 @@ +import pickle +from typing import Tuple + +import torch + + +class BackwardsCompatibleUnpickler(pickle.Unpickler): + """ + An Unpickler that can load files saved before the "sae_lens" package namechange + """ + + def find_class(self, module: str, name: str): + module = module.replace("sae_training", "sae_lens.training") + return super().find_class(module, name) + + +def shuffle_activations_pairwise(datapath: str, buffer_idx_range: Tuple[int, int]): + """ + Shuffles two buffers on disk. + """ + assert ( + buffer_idx_range[0] < buffer_idx_range[1] - 1 + ), "buffer_idx_range[0] must be smaller than buffer_idx_range[1] by at least 1" + + buffer_idx1 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item() + buffer_idx2 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item() + while buffer_idx1 == buffer_idx2: # Make sure they're not the same + buffer_idx2 = torch.randint( + buffer_idx_range[0], buffer_idx_range[1], (1,) + ).item() + + buffer1 = torch.load(f"{datapath}/{buffer_idx1}.pt") + buffer2 = torch.load(f"{datapath}/{buffer_idx2}.pt") + joint_buffer = torch.cat([buffer1, buffer2]) + + # Shuffle them + joint_buffer = joint_buffer[torch.randperm(joint_buffer.shape[0])] + shuffled_buffer1 = joint_buffer[: buffer1.shape[0]] + shuffled_buffer2 = joint_buffer[buffer1.shape[0] :] + + # Save them back + torch.save(shuffled_buffer1, f"{datapath}/{buffer_idx1}.pt") + torch.save(shuffled_buffer2, f"{datapath}/{buffer_idx2}.pt") diff --git 
a/scripts/run.ipynb b/scripts/run.ipynb index 7db4e82f..3ce81234 100644 --- a/scripts/run.ipynb +++ b/scripts/run.ipynb @@ -1,912 +1,912 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Notebook with Example Config for Different Models / Hooks" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using device: mps\n" - ] - } - ], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "\n", - "from sae_training.config import LanguageModelSAERunnerConfig\n", - "from sae_training.lm_runner import language_model_sae_runner\n", - "\n", - "if torch.cuda.is_available():\n", - " device = \"cuda\"\n", - "elif torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "else:\n", - " device = \"cpu\"\n", - "\n", - "print(\"Using device:\", device)\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Gelu-2L\n", - "\n", - "An example of a toy language model we're able to train on." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### MLP Out" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gelu-2l\",\n", - " hook_point=\"blocks.0.hook_mlp_out\",\n", - " hook_point_layer=0,\n", - " d_in=512,\n", - " dataset_path=\"NeelNanda/c4-tokenized-2b\",\n", - " is_dataset_tokenized=True,\n", - " # SAE Parameters\n", - " expansion_factor=[16, 32, 64],\n", - " b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n", - " # Training Parameters\n", - " lr=0.0012,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " l1_coefficient=0.00016,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 100,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " use_ghost_grads=True,\n", - " feature_sampling_window=5000,\n", - " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-4,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_models_gelu_2l_test\",\n", - " wandb_log_frequency=10,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## GPT2 - Small" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Residual Stream" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sae_training.config import LanguageModelSAERunnerConfig\n", - "from sae_training.lm_runner import 
language_model_sae_runner\n", - "\n", - "layer = 3\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gpt2-small\",\n", - " hook_point=f\"blocks.{layer}.hook_resid_pre\",\n", - " hook_point_layer=layer,\n", - " d_in=768,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=32, # determines the dimension of the SAE.\n", - " b_dec_init_method=\"mean\", # geometric median is better but slower to get started\n", - " # Training Parameters\n", - " lr=0.0004,\n", - " l1_coefficient=0.00008,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " lr_warm_up_steps=5000,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 300, # 200M tokens seems doable overnight.\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " use_ghost_grads=True,\n", - " feature_sampling_window=2500,\n", - " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-8,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_models_resid_pre_test\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=100,\n", - " # Misc\n", - " device=\"cuda\",\n", - " seed=42,\n", - " n_checkpoints=10,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pythia 70-M" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "from sae_training.config import LanguageModelSAERunnerConfig\n", - "from sae_training.lm_runner import language_model_sae_runner\n", - "\n", - "import 
cProfile\n", - "\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"pythia-70m-deduped\",\n", - " hook_point=\"blocks.0.hook_mlp_out\",\n", - " hook_point_layer=0,\n", - " d_in=512,\n", - " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=64,\n", - " # Training Parameters\n", - " lr=3e-4,\n", - " l1_coefficient=4e-5,\n", - " train_batch_size=8192,\n", - " context_size=128,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " lr_warm_up_steps=10_000,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=64,\n", - " total_training_tokens=1_000_000 * 800,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " feature_sampling_window=2000, # Doesn't currently matter.\n", - " dead_feature_window=40000,\n", - " dead_feature_threshold=1e-8,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=20,\n", - " # Misc\n", - " device=\"cuda\",\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pythia 70M Hook Q" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"../\")\n", - "\n", - "from sae_training.config import LanguageModelSAERunnerConfig\n", - "from sae_training.lm_runner import language_model_sae_runner\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating 
Function (Model + Training Distibuion)\n", - " model_name=\"pythia-70m-deduped\",\n", - " hook_point=\"blocks.2.attn.hook_q\",\n", - " hook_point_layer=2,\n", - " hook_point_head_index=7,\n", - " d_in=64,\n", - " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=16,\n", - " # Training Parameters\n", - " lr=0.0012,\n", - " l1_coefficient=0.003,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " lr_warm_up_steps=1000, # about 4 million tokens.\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 1500,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " feature_sampling_method=\"anthropic\",\n", - " feature_sampling_window=1000, # doesn't do anything currently.\n", - " feature_reinit_scale=0.2,\n", - " resample_batches=8,\n", - " dead_feature_window=60000,\n", - " dead_feature_threshold=1e-5,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_pythia_70M_hook_q_L2H7\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=100,\n", - " # Misc\n", - " device=\"mps\",\n", - " seed=42,\n", - " n_checkpoints=15,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Tiny Stories" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## MLP Out" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_training.config import LanguageModelSAERunnerConfig\n", - "from sae_training.lm_runner import language_model_sae_runner\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "if device == \"cpu\" 
and torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"tiny-stories-1M\",\n", - " hook_point=\"blocks.1.mlp.hook_post\",\n", - " hook_point_layer=1,\n", - " d_in=256,\n", - " # dataset_path=\"roneneldan/TinyStories\",\n", - " # is_dataset_tokenized=False,\n", - " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", - " is_dataset_tokenized=True,\n", - " # SAE Parameters\n", - " expansion_factor=16,\n", - " # Training Parameters\n", - " lr=1e-4,\n", - " lp_norm=1.0,\n", - " l1_coefficient=2e-4,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 20,\n", - " store_batch_size=32,\n", - " feature_sampling_window=500, # So we see the histograms.\n", - " dead_feature_window=250,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", - " wandb_log_frequency=10,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Hook Z\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Notebook with Example Config for Different Models / Hooks" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", - "Lower bound: 
n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 4882\n", - "Total wandb updates: 488\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 9 times.\n", - "Number tokens in sparsity calculation window: 2.05e+06\n", - "Loaded pretrained model tiny-stories-1M into HookedTransformer\n", - "Moving model to device: mps\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" - ] + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: mps\n" + ] + } + ], + "source": [ + "import torch\n", + "import os\n", + "import sys\n", + "\n", + "sys.path.append(\"..\")\n", + "\n", + "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", + "from sae_lens.training.lm_runner import language_model_sae_runner\n", + "\n", + "if torch.cuda.is_available():\n", + " device = \"cuda\"\n", + "elif torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cpu\"\n", + "\n", + "print(\"Using device:\", device)\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", - "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 4882\n", - "Total wandb updates: 488\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 9 times.\n", - "Number tokens in sparsity 
calculation window: 2.05e+06\n", - "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", - "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 4882\n", - "Total wandb updates: 488\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 9 times.\n", - "Number tokens in sparsity calculation window: 2.05e+06\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gelu-2L\n", + "\n", + "An example of a toy language model we're able to train on." + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mjbloom\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### MLP Out" + ] }, { - "data": { - "text/html": [ - "wandb version 0.16.5 is available! 
To upgrade, please run:\n", - " $ pip install wandb --upgrade" - ], - "text/plain": [ - "" + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distibuion)\n", + " model_name=\"gelu-2l\",\n", + " hook_point=\"blocks.0.hook_mlp_out\",\n", + " hook_point_layer=0,\n", + " d_in=512,\n", + " dataset_path=\"NeelNanda/c4-tokenized-2b\",\n", + " is_dataset_tokenized=True,\n", + " # SAE Parameters\n", + " expansion_factor=[16, 32, 64],\n", + " b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n", + " # Training Parameters\n", + " lr=0.0012,\n", + " lr_scheduler_name=\"constantwithwarmup\",\n", + " l1_coefficient=0.00016,\n", + " train_batch_size=4096,\n", + " context_size=128,\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 100,\n", + " store_batch_size=32,\n", + " # Resampling protocol\n", + " use_ghost_grads=True,\n", + " feature_sampling_window=5000,\n", + " dead_feature_window=5000,\n", + " dead_feature_threshold=1e-4,\n", + " # WANDB\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_models_gelu_2l_test\",\n", + " wandb_log_frequency=10,\n", + " # Misc\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "Tracking run with wandb version 0.16.3" - ], - "text/plain": [ - "" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## GPT2 - Small" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "Run data is saved locally in 
/Users/josephbloom/GithubRepositories/mats_sae_training/scripts/wandb/run-20240326_191703-ec6k6v87" - ], - "text/plain": [ - "" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Residual Stream" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "Syncing run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 to Weights & Biases (docs)
" - ], - "text/plain": [ - "" + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", + "from sae_lens.training.lm_runner import language_model_sae_runner\n", + "\n", + "layer = 3\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distibuion)\n", + " model_name=\"gpt2-small\",\n", + " hook_point=f\"blocks.{layer}.hook_resid_pre\",\n", + " hook_point_layer=layer,\n", + " d_in=768,\n", + " dataset_path=\"Skylion007/openwebtext\",\n", + " is_dataset_tokenized=False,\n", + " # SAE Parameters\n", + " expansion_factor=32, # determines the dimension of the SAE.\n", + " b_dec_init_method=\"mean\", # geometric median is better but slower to get started\n", + " # Training Parameters\n", + " lr=0.0004,\n", + " l1_coefficient=0.00008,\n", + " lr_scheduler_name=\"constantwithwarmup\",\n", + " train_batch_size=4096,\n", + " context_size=128,\n", + " lr_warm_up_steps=5000,\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 300, # 200M tokens seems doable overnight.\n", + " store_batch_size=32,\n", + " # Resampling protocol\n", + " use_ghost_grads=True,\n", + " feature_sampling_window=2500,\n", + " dead_feature_window=5000,\n", + " dead_feature_threshold=1e-8,\n", + " # WANDB\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_models_resid_pre_test\",\n", + " wandb_entity=None,\n", + " wandb_log_frequency=100,\n", + " # Misc\n", + " device=\"cuda\",\n", + " seed=42,\n", + " n_checkpoints=10,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - " View project at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests" - ], - "text/plain": [ - 
"" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pythia 70-M" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - " View run at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87" - ], - "text/plain": [ - "" + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "import sys\n", + "\n", + "sys.path.append(\"..\")\n", + "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", + "from sae_lens.training.lm_runner import language_model_sae_runner\n", + "\n", + "import cProfile\n", + "\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distibuion)\n", + " model_name=\"pythia-70m-deduped\",\n", + " hook_point=\"blocks.0.hook_mlp_out\",\n", + " hook_point_layer=0,\n", + " d_in=512,\n", + " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", + " is_dataset_tokenized=False,\n", + " # SAE Parameters\n", + " expansion_factor=64,\n", + " # Training Parameters\n", + " lr=3e-4,\n", + " l1_coefficient=4e-5,\n", + " train_batch_size=8192,\n", + " context_size=128,\n", + " lr_scheduler_name=\"constantwithwarmup\",\n", + " lr_warm_up_steps=10_000,\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=64,\n", + " total_training_tokens=1_000_000 * 800,\n", + " store_batch_size=32,\n", + " # Resampling protocol\n", + " feature_sampling_window=2000, # Doesn't currently matter.\n", + " dead_feature_window=40000,\n", + " dead_feature_threshold=1e-8,\n", + " # WANDB\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", + " wandb_entity=None,\n", + " wandb_log_frequency=20,\n", + " # Misc\n", + " device=\"cuda\",\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + 
"\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Objective value: 116883.7422: 10%|█ | 10/100 [00:00<00:00, 128.72it/s]\n", - "/Users/josephbloom/GithubRepositories/mats_sae_training/sae_training/sparse_autoencoder.py:161: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", - "100%|██████████| 10/10 [00:02<00:00, 4.93it/s] 405504/20000000 [00:14<08:53, 36739.57it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 811008/20000000 [00:31<18:45, 17042.47it/s] \n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 1224704/20000000 [00:47<10:43, 29194.89it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s]| 1634304/20000000 [01:05<08:10, 37468.33it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.64it/s]| 2039808/20000000 [01:20<07:36, 39322.02it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.08it/s]| 2453504/20000000 [01:37<07:55, 36873.53it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 2863104/20000000 [01:52<07:16, 39292.24it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 3272704/20000000 [02:09<06:52, 40537.06it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 3678208/20000000 [02:26<27:40, 9829.56it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 4087808/20000000 [02:41<06:11, 42798.13it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.50it/s] | 4497408/20000000 [03:01<08:53, 29055.95it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 4911104/20000000 [03:16<06:55, 36330.89it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.57it/s] | 5316608/20000000 [03:34<06:31, 37461.30it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.87it/s] | 
5726208/20000000 [03:50<05:45, 41309.20it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 6139904/20000000 [04:07<06:03, 38122.10it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s] | 6549504/20000000 [04:24<05:43, 39198.19it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.91it/s] | 6955008/20000000 [04:43<05:01, 43328.38it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.84it/s] | 7368704/20000000 [05:00<12:14, 17200.22it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 7778304/20000000 [05:14<04:44, 43005.09it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.78it/s] | 8183808/20000000 [05:32<06:31, 30153.11it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.80it/s] | 8597504/20000000 [05:47<04:22, 43375.86it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 5.00it/s] | 9007104/20000000 [06:09<05:16, 34784.52it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.55it/s] | 9416704/20000000 [06:24<04:36, 38252.78it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.75it/s] | 9822208/20000000 [06:42<03:58, 42593.01it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.99it/s] | 10235904/20000000 [06:59<19:05, 8524.91it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s] | 10645504/20000000 [07:14<03:30, 44384.65it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.89it/s] | 11055104/20000000 [07:31<05:24, 27562.66it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.83it/s] | 11464704/20000000 [07:45<03:26, 41316.56it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.81it/s] | 11870208/20000000 [08:02<03:44, 36217.25it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.89it/s] | 12279808/20000000 [08:16<02:52, 44715.52it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.85it/s] | 12693504/20000000 [08:34<03:02, 40061.41it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.02it/s] | 13103104/20000000 [08:48<02:38, 43563.35it/s]\n", - "100%|██████████| 10/10 [00:04<00:00, 2.17it/s] | 13508608/20000000 [09:05<02:34, 41937.09it/s]\n", - "100%|██████████| 
10/10 [00:01<00:00, 5.03it/s] | 13922304/20000000 [09:24<05:07, 19779.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 14327808/20000000 [09:38<02:05, 45367.15it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 14741504/20000000 [09:54<02:49, 30943.53it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.05it/s] | 15147008/20000000 [10:08<01:46, 45610.98it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.06it/s] | 15556608/20000000 [10:24<01:49, 40440.85it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 15966208/20000000 [10:38<01:29, 45251.75it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16379904/20000000 [10:55<01:22, 43941.70it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16789504/20000000 [11:11<04:30, 11859.26it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 17195008/20000000 [11:25<01:02, 44607.68it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s] | 17608704/20000000 [11:41<01:38, 24188.35it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.00it/s] | 18018304/20000000 [11:54<00:42, 46425.69it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.06it/s]▏| 18423808/20000000 [12:13<00:44, 35420.18it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s]▍| 18837504/20000000 [12:27<00:26, 43914.73it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]▌| 19243008/20000000 [12:45<00:19, 38931.67it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.95it/s]▊| 19656704/20000000 [12:59<00:07, 43804.93it/s]\n", - "4883| MSE Loss 0.000 | L1 0.000: 100%|█████████▉| 19996672/20000000 [13:14<00:00, 37714.53it/s]" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pythia 70M Hook Q" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Saved model to checkpoints/sf7u2imk/final_sae_group_tiny-stories-1M_blocks.1.attn.hook_z_1024.pt\n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + 
"import os\n", + "import sys\n", + "\n", + "sys.path.append(\"../\")\n", + "\n", + "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", + "from sae_lens.training.lm_runner import language_model_sae_runner\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distibuion)\n", + " model_name=\"pythia-70m-deduped\",\n", + " hook_point=\"blocks.2.attn.hook_q\",\n", + " hook_point_layer=2,\n", + " hook_point_head_index=7,\n", + " d_in=64,\n", + " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", + " is_dataset_tokenized=False,\n", + " # SAE Parameters\n", + " expansion_factor=16,\n", + " # Training Parameters\n", + " lr=0.0012,\n", + " l1_coefficient=0.003,\n", + " lr_scheduler_name=\"constantwithwarmup\",\n", + " lr_warm_up_steps=1000, # about 4 million tokens.\n", + " train_batch_size=4096,\n", + " context_size=128,\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 1500,\n", + " store_batch_size=32,\n", + " # Resampling protocol\n", + " feature_sampling_method=\"anthropic\",\n", + " feature_sampling_window=1000, # doesn't do anything currently.\n", + " feature_reinit_scale=0.2,\n", + " resample_batches=8,\n", + " dead_feature_window=60000,\n", + " dead_feature_threshold=1e-5,\n", + " # WANDB\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_pythia_70M_hook_q_L2H7\",\n", + " wandb_entity=None,\n", + " wandb_log_frequency=100,\n", + " # Misc\n", + " device=\"mps\",\n", + " seed=42,\n", + " n_checkpoints=15,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1dffd84a387d4cf48100fbe143287481", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - 
"VBox(children=(Label(value='0.053 MB of 0.569 MB uploaded\\r'), FloatProgress(value=0.0935266880101429, max=1.0…" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tiny Stories" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MLP Out" + ] }, { - "data": { - "text/html": [ - "\n", - "

Run history:


details/current_learning_rate▁▃▅▆████████████████████████████████████
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss██▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▄▅▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation▂▃▂▅▃▆▅▃▄▆▇▆▅▇▅▄▇▅▁▆▄▅▆▄█▄▅▆▄▅▅▃▂▄▄▅▅█▆▆
metrics/ce_loss_with_sae█▅▄▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▁▂▂▁▁▂▁▂▁▂▂▂
metrics/ce_loss_without_sae▄▄▁▃▄▆▅▃▆▅█▆▅▆▅▄▅▆▁▇▆▅▆▃█▆▆▆▄▇▆▃▃▆▃▆▄█▇▅
metrics/explained_variance▁▅▇▇▇███████████████████████████████████
metrics/explained_variance_std██▆▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0██▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▄▆▆▇▇▇▆▆▆▇█▇▇▆▇▇▆▇▇▇▆▇████▇▇▇▇▇▇▇▇▇█▇▇▇
metrics/l2_ratio▁▃▁▂▄▃▂▄▆▅▅▅▅▆▅▆▇▆▆▇▇▆▆▆▇▆▆▇▆▇▆▇▇▇█▆▆▇▇▇
metrics/mean_log10_feature_sparsity█▇▅▄▃▃▂▁▁
sparsity/below_1e-5▁▁▁▁▁▁▁▁▁
sparsity/below_1e-6▁▁▁▁▁▁▁▁▁
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▂▂▄▇▄▃▅▆▄▇██

Run summary:


details/current_learning_rate0.0001
details/n_training_tokens19988480
losses/ghost_grad_loss0.0
losses/l1_loss1.41017
losses/mse_loss8e-05
losses/overall_loss0.00036
metrics/CE_loss_score0.98362
metrics/ce_loss_with_ablation5.49512
metrics/ce_loss_with_sae2.71813
metrics/ce_loss_without_sae2.67199
metrics/explained_variance0.98647
metrics/explained_variance_std0.00905
metrics/l0166.02246
metrics/l2_norm1.39317
metrics/l2_ratio0.99823
metrics/mean_log10_feature_sparsity-1.53525
sparsity/below_1e-50
sparsity/below_1e-60
sparsity/dead_features0
sparsity/mean_passes_since_fired0.02051

" - ], - "text/plain": [ - "" + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "\n", + "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", + "from sae_lens.training.lm_runner import language_model_sae_runner\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "if device == \"cpu\" and torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distibuion)\n", + " model_name=\"tiny-stories-1M\",\n", + " hook_point=\"blocks.1.mlp.hook_post\",\n", + " hook_point_layer=1,\n", + " d_in=256,\n", + " # dataset_path=\"roneneldan/TinyStories\",\n", + " # is_dataset_tokenized=False,\n", + " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", + " is_dataset_tokenized=True,\n", + " # SAE Parameters\n", + " expansion_factor=16,\n", + " # Training Parameters\n", + " lr=1e-4,\n", + " lp_norm=1.0,\n", + " l1_coefficient=2e-4,\n", + " train_batch_size=4096,\n", + " context_size=128,\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 20,\n", + " store_batch_size=32,\n", + " feature_sampling_window=500, # So we see the histograms.\n", + " dead_feature_window=250,\n", + " # WANDB\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", + " wandb_log_frequency=10,\n", + " # Misc\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - " 
View run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 at: https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87
Synced 7 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" - ], - "text/plain": [ - "" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hook Z\n", + "\n" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "Find logs at: ./wandb/run-20240326_191703-ec6k6v87/logs" + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", + "n_tokens_per_buffer (millions): 0.524288\n", + "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", + "Total training steps: 4882\n", + "Total wandb updates: 488\n", + "n_tokens_per_feature_sampling_window (millions): 262.144\n", + "n_tokens_per_dead_feature_window (millions): 131.072\n", + "We will reset the sparsity calculation 9 times.\n", + "Number tokens in sparsity calculation window: 2.05e+06\n", + "Loaded pretrained model tiny-stories-1M into HookedTransformer\n", + "Moving model to device: mps\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", + "n_tokens_per_buffer (millions): 0.524288\n", + "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", + "Total training steps: 4882\n", + "Total wandb updates: 488\n", + "n_tokens_per_feature_sampling_window (millions): 262.144\n", + "n_tokens_per_dead_feature_window (millions): 131.072\n", + "We will reset the sparsity calculation 9 times.\n", + "Number tokens in sparsity calculation window: 2.05e+06\n", + "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", + "n_tokens_per_buffer (millions): 0.524288\n", + "Lower bound: n_contexts_per_buffer (millions): 
0.004096\n", + "Total training steps: 4882\n", + "Total wandb updates: 488\n", + "n_tokens_per_feature_sampling_window (millions): 262.144\n", + "n_tokens_per_dead_feature_window (millions): 131.072\n", + "We will reset the sparsity calculation 9 times.\n", + "Number tokens in sparsity calculation window: 2.05e+06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mjbloom\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "wandb version 0.16.5 is available! To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.3" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /Users/josephbloom/GithubRepositories/mats_sae_training/scripts/wandb/run-20240326_191703-ec6k6v87" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Objective value: 116883.7422: 10%|█ | 10/100 [00:00<00:00, 128.72it/s]\n", + "/Users/josephbloom/GithubRepositories/mats_sae_training/sae_training/sparse_autoencoder.py:161: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", + "100%|██████████| 10/10 [00:02<00:00, 4.93it/s] 405504/20000000 [00:14<08:53, 36739.57it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 811008/20000000 [00:31<18:45, 17042.47it/s] \n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 1224704/20000000 [00:47<10:43, 29194.89it/s] \n", + "100%|██████████| 10/10 [00:02<00:00, 4.98it/s]| 1634304/20000000 [01:05<08:10, 37468.33it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.64it/s]| 2039808/20000000 [01:20<07:36, 39322.02it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.08it/s]| 2453504/20000000 [01:37<07:55, 36873.53it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 2863104/20000000 [01:52<07:16, 39292.24it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 3272704/20000000 [02:09<06:52, 40537.06it/s] \n", + "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 3678208/20000000 [02:26<27:40, 9829.56it/s] \n", + "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 
4087808/20000000 [02:41<06:11, 42798.13it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.50it/s] | 4497408/20000000 [03:01<08:53, 29055.95it/s] \n", + "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 4911104/20000000 [03:16<06:55, 36330.89it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.57it/s] | 5316608/20000000 [03:34<06:31, 37461.30it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.87it/s] | 5726208/20000000 [03:50<05:45, 41309.20it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 6139904/20000000 [04:07<06:03, 38122.10it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.90it/s] | 6549504/20000000 [04:24<05:43, 39198.19it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.91it/s] | 6955008/20000000 [04:43<05:01, 43328.38it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.84it/s] | 7368704/20000000 [05:00<12:14, 17200.22it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 7778304/20000000 [05:14<04:44, 43005.09it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.78it/s] | 8183808/20000000 [05:32<06:31, 30153.11it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.80it/s] | 8597504/20000000 [05:47<04:22, 43375.86it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 5.00it/s] | 9007104/20000000 [06:09<05:16, 34784.52it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.55it/s] | 9416704/20000000 [06:24<04:36, 38252.78it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.75it/s] | 9822208/20000000 [06:42<03:58, 42593.01it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.99it/s] | 10235904/20000000 [06:59<19:05, 8524.91it/s] \n", + "100%|██████████| 10/10 [00:02<00:00, 4.98it/s] | 10645504/20000000 [07:14<03:30, 44384.65it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.89it/s] | 11055104/20000000 [07:31<05:24, 27562.66it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.83it/s] | 11464704/20000000 [07:45<03:26, 41316.56it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.81it/s] | 11870208/20000000 [08:02<03:44, 36217.25it/s]\n", + "100%|██████████| 10/10 
[00:02<00:00, 4.89it/s] | 12279808/20000000 [08:16<02:52, 44715.52it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.85it/s] | 12693504/20000000 [08:34<03:02, 40061.41it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.02it/s] | 13103104/20000000 [08:48<02:38, 43563.35it/s]\n", + "100%|██████████| 10/10 [00:04<00:00, 2.17it/s] | 13508608/20000000 [09:05<02:34, 41937.09it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 13922304/20000000 [09:24<05:07, 19779.09it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 14327808/20000000 [09:38<02:05, 45367.15it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 14741504/20000000 [09:54<02:49, 30943.53it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.05it/s] | 15147008/20000000 [10:08<01:46, 45610.98it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.06it/s] | 15556608/20000000 [10:24<01:49, 40440.85it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 15966208/20000000 [10:38<01:29, 45251.75it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16379904/20000000 [10:55<01:22, 43941.70it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16789504/20000000 [11:11<04:30, 11859.26it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 17195008/20000000 [11:25<01:02, 44607.68it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.97it/s] | 17608704/20000000 [11:41<01:38, 24188.35it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.00it/s] | 18018304/20000000 [11:54<00:42, 46425.69it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.06it/s]▏| 18423808/20000000 [12:13<00:44, 35420.18it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.97it/s]▍| 18837504/20000000 [12:27<00:26, 43914.73it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]▌| 19243008/20000000 [12:45<00:19, 38931.67it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.95it/s]▊| 19656704/20000000 [12:59<00:07, 43804.93it/s]\n", + "4883| MSE Loss 0.000 | L1 0.000: 100%|█████████▉| 19996672/20000000 [13:14<00:00, 
37714.53it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved model to checkpoints/sf7u2imk/final_sae_group_tiny-stories-1M_blocks.1.attn.hook_z_1024.pt\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1dffd84a387d4cf48100fbe143287481", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='0.053 MB of 0.569 MB uploaded\\r'), FloatProgress(value=0.0935266880101429, max=1.0…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_learning_rate▁▃▅▆████████████████████████████████████
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss██▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▄▅▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation▂▃▂▅▃▆▅▃▄▆▇▆▅▇▅▄▇▅▁▆▄▅▆▄█▄▅▆▄▅▅▃▂▄▄▅▅█▆▆
metrics/ce_loss_with_sae█▅▄▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▁▂▂▁▁▂▁▂▁▂▂▂
metrics/ce_loss_without_sae▄▄▁▃▄▆▅▃▆▅█▆▅▆▅▄▅▆▁▇▆▅▆▃█▆▆▆▄▇▆▃▃▆▃▆▄█▇▅
metrics/explained_variance▁▅▇▇▇███████████████████████████████████
metrics/explained_variance_std██▆▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0██▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▄▆▆▇▇▇▆▆▆▇█▇▇▆▇▇▆▇▇▇▆▇████▇▇▇▇▇▇▇▇▇█▇▇▇
metrics/l2_ratio▁▃▁▂▄▃▂▄▆▅▅▅▅▆▅▆▇▆▆▇▇▆▆▆▇▆▆▇▆▇▆▇▇▇█▆▆▇▇▇
metrics/mean_log10_feature_sparsity█▇▅▄▃▃▂▁▁
sparsity/below_1e-5▁▁▁▁▁▁▁▁▁
sparsity/below_1e-6▁▁▁▁▁▁▁▁▁
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▂▂▄▇▄▃▅▆▄▇██

Run summary:


details/current_learning_rate0.0001
details/n_training_tokens19988480
losses/ghost_grad_loss0.0
losses/l1_loss1.41017
losses/mse_loss8e-05
losses/overall_loss0.00036
metrics/CE_loss_score0.98362
metrics/ce_loss_with_ablation5.49512
metrics/ce_loss_with_sae2.71813
metrics/ce_loss_without_sae2.67199
metrics/explained_variance0.98647
metrics/explained_variance_std0.00905
metrics/l0166.02246
metrics/l2_norm1.39317
metrics/l2_ratio0.99823
metrics/mean_log10_feature_sparsity-1.53525
sparsity/below_1e-50
sparsity/below_1e-60
sparsity/dead_features0
sparsity/mean_passes_since_fired0.02051

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 at: https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87
Synced 7 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240326_191703-ec6k6v87/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "4883| MSE Loss 0.000 | L1 0.000: : 20000768it [13:29, 37714.53it/s] /Users/josephbloom/miniforge3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2171: UserWarning: Run (ec6k6v87) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "import torch\n", + "import os\n", + "\n", + "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", + "from sae_lens.training.lm_runner import language_model_sae_runner\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "if device == \"cpu\" and torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distibuion)\n", + " model_name=\"tiny-stories-1M\",\n", + " hook_point=\"blocks.1.attn.hook_z\",\n", + " hook_point_layer=1,\n", + " d_in=64,\n", + " # dataset_path=\"roneneldan/TinyStories\",\n", + " # is_dataset_tokenized=False,\n", + " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", + " is_dataset_tokenized=True,\n", + " # SAE Parameters\n", + " expansion_factor=16,\n", + " # Training Parameters\n", + " lr=1e-4,\n", + " lp_norm=1.0,\n", + " l1_coefficient=2e-4,\n", + " train_batch_size=4096,\n", 
+ " context_size=128,\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 20,\n", + " store_batch_size=32,\n", + " feature_sampling_window=500, # So we see the histograms.\n", + " dead_feature_window=250,\n", + " # WANDB\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", + " wandb_log_frequency=10,\n", + " # Misc\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Toy Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sae_lens.training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner\n", + "\n", + "\n", + "cfg = SAEToyModelRunnerConfig(\n", + " # Model Details\n", + " n_features=200,\n", + " n_hidden=5,\n", + " n_correlated_pairs=0,\n", + " n_anticorrelated_pairs=0,\n", + " feature_probability=0.025,\n", + " model_training_steps=10_000,\n", + " # SAE Parameters\n", + " d_sae=240,\n", + " l1_coefficient=0.001,\n", + " # SAE Train Config\n", + " train_batch_size=1028,\n", + " feature_sampling_window=3_000,\n", + " dead_feature_window=1_000,\n", + " feature_reinit_scale=0.5,\n", + " total_training_tokens=4096 * 300,\n", + " # Other parameters\n", + " log_to_wandb=True,\n", + " wandb_project=\"sae-training-test\",\n", + " wandb_log_frequency=5,\n", + " device=\"mps\",\n", + ")\n", + "\n", + "trained_sae = toy_model_sae_runner(cfg)\n", + "\n", + "assert trained_sae is not None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run caching of activations to disk" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "import sys\n", + "\n", + "sys.path.append(\"..\")\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", + "\n", + "from sae_lens.training.config import CacheActivationsRunnerConfig\n", + "from sae_lens.training.cache_activations_runner import cache_activations_runner\n", + "\n", + "cfg = CacheActivationsRunnerConfig(\n", + " # Data Generating Function (Model + Training Distibuion)\n", + " model_name=\"gpt2-small\",\n", + " hook_point=\"blocks.10.attn.hook_q\",\n", + " hook_point_layer=10,\n", + " hook_point_head_index=7,\n", + " d_in=64,\n", + " dataset_path=\"Skylion007/openwebtext\",\n", + " is_dataset_tokenized=False,\n", + " cached_activations_path=\"../activations/\",\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=16,\n", + " total_training_tokens=500_000_000,\n", + " store_batch_size=32,\n", + " # Activation caching shuffle parameters\n", + " n_shuffles_final=16,\n", + " # Misc\n", + " device=\"mps\",\n", + " seed=42,\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "cache_activations_runner(cfg)" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "4883| MSE Loss 0.000 | L1 0.000: : 20000768it [13:29, 37714.53it/s] /Users/josephbloom/miniforge3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2171: UserWarning: Run (ec6k6v87) is finished. The call to `_console_raw_callback` will be ignored. 
Please make sure that you are using an active run.\n", - " lambda data: self._console_raw_callback(\"stderr\", data),\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train an SAE using the cached activations stored on disk\n", + "Pass `use_cached_activations=True` into the config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", + "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", + "from sae_lens.training.lm_runner import language_model_sae_runner\n", + "\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distibuion)\n", + " model_name=\"gpt2-small\",\n", + " hook_point=\"blocks.10.hook_resid_pre\",\n", + " hook_point_layer=11,\n", + " d_in=768,\n", + " dataset_path=\"Skylion007/openwebtext\",\n", + " is_dataset_tokenized=False,\n", + " use_cached_activations=True,\n", + " # SAE Parameters\n", + " expansion_factor=64, # determines the dimension of the SAE.\n", + " # Training Parameters\n", + " lr=1e-5,\n", + " l1_coefficient=5e-4,\n", + " lr_scheduler_name=None,\n", + " train_batch_size=4096,\n", + " context_size=128,\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=64,\n", + " total_training_tokens=200_000,\n", + " store_batch_size=32,\n", + " # Resampling protocol\n", + " feature_sampling_method=\"l2\",\n", + " feature_sampling_window=1000,\n", + " feature_reinit_scale=0.2,\n", + " dead_feature_window=5000,\n", + " dead_feature_threshold=1e-7,\n", + " # WANDB\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_gpt2_small\",\n", + " wandb_entity=None,\n", + " wandb_log_frequency=50,\n", + " # Misc\n", + " device=\"mps\",\n", + " seed=42,\n", + " n_checkpoints=5,\n", + " checkpoint_path=\"checkpoints\",\n", + " 
dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mats_sae_training", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" } - ], - "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_training.config import LanguageModelSAERunnerConfig\n", - "from sae_training.lm_runner import language_model_sae_runner\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "if device == \"cpu\" and torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"tiny-stories-1M\",\n", - " hook_point=\"blocks.1.attn.hook_z\",\n", - " hook_point_layer=1,\n", - " d_in=64,\n", - " # dataset_path=\"roneneldan/TinyStories\",\n", - " # is_dataset_tokenized=False,\n", - " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", - " is_dataset_tokenized=True,\n", - " # SAE Parameters\n", - " expansion_factor=16,\n", - " # Training Parameters\n", - " lr=1e-4,\n", - " lp_norm=1.0,\n", - " l1_coefficient=2e-4,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 20,\n", - " store_batch_size=32,\n", - " feature_sampling_window=500, # So we see the histograms.\n", - " 
dead_feature_window=250,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", - " wandb_log_frequency=10,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Toy Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner\n", - "\n", - "\n", - "cfg = SAEToyModelRunnerConfig(\n", - " # Model Details\n", - " n_features=200,\n", - " n_hidden=5,\n", - " n_correlated_pairs=0,\n", - " n_anticorrelated_pairs=0,\n", - " feature_probability=0.025,\n", - " model_training_steps=10_000,\n", - " # SAE Parameters\n", - " d_sae=240,\n", - " l1_coefficient=0.001,\n", - " # SAE Train Config\n", - " train_batch_size=1028,\n", - " feature_sampling_window=3_000,\n", - " dead_feature_window=1_000,\n", - " feature_reinit_scale=0.5,\n", - " total_training_tokens=4096 * 300,\n", - " # Other parameters\n", - " log_to_wandb=True,\n", - " wandb_project=\"sae-training-test\",\n", - " wandb_log_frequency=5,\n", - " device=\"mps\",\n", - ")\n", - "\n", - "trained_sae = toy_model_sae_runner(cfg)\n", - "\n", - "assert trained_sae is not None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Run caching of activations to disk" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - 
"os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", - "\n", - "from sae_training.config import CacheActivationsRunnerConfig\n", - "from sae_training.cache_activations_runner import cache_activations_runner\n", - "\n", - "cfg = CacheActivationsRunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gpt2-small\",\n", - " hook_point=\"blocks.10.attn.hook_q\",\n", - " hook_point_layer=10,\n", - " hook_point_head_index=7,\n", - " d_in=64,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " cached_activations_path=\"../activations/\",\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=16,\n", - " total_training_tokens=500_000_000,\n", - " store_batch_size=32,\n", - " # Activation caching shuffle parameters\n", - " n_shuffles_final=16,\n", - " # Misc\n", - " device=\"mps\",\n", - " seed=42,\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "cache_activations_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train an SAE using the cached activations stored on disk\n", - "Pass `use_cached_activations=True` into the config" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", - "from sae_training.config import LanguageModelSAERunnerConfig\n", - "from sae_training.lm_runner import language_model_sae_runner\n", - "\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gpt2-small\",\n", - " hook_point=\"blocks.10.hook_resid_pre\",\n", - " hook_point_layer=11,\n", - " d_in=768,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " use_cached_activations=True,\n", - " # SAE Parameters\n", - " expansion_factor=64, # 
determines the dimension of the SAE.\n", - " # Training Parameters\n", - " lr=1e-5,\n", - " l1_coefficient=5e-4,\n", - " lr_scheduler_name=None,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=64,\n", - " total_training_tokens=200_000,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " feature_sampling_method=\"l2\",\n", - " feature_sampling_window=1000,\n", - " feature_reinit_scale=0.2,\n", - " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-7,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_gpt2_small\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=50,\n", - " # Misc\n", - " device=\"mps\",\n", - " seed=42,\n", - " n_checkpoints=5,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mats_sae_training", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py index ade45557..e0030ce5 100644 --- a/tests/benchmark/test_language_model_sae_runner.py +++ b/tests/benchmark/test_language_model_sae_runner.py @@ -1,7 +1,7 @@ import torch -from sae_training.config import LanguageModelSAERunnerConfig -from sae_training.lm_runner import language_model_sae_runner +from sae_lens.training.config import LanguageModelSAERunnerConfig 
+from sae_lens.training.lm_runner import language_model_sae_runner def test_language_model_sae_runner(): diff --git a/tests/benchmark/test_toy_model_sae_runner.py b/tests/benchmark/test_toy_model_sae_runner.py index ed904d9f..dd2a6252 100644 --- a/tests/benchmark/test_toy_model_sae_runner.py +++ b/tests/benchmark/test_toy_model_sae_runner.py @@ -1,6 +1,9 @@ import torch -from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner +from sae_lens.training.toy_model_runner import ( + SAEToyModelRunnerConfig, + toy_model_sae_runner, +) # @pytest.mark.skip(reason="I (joseph) broke this at some point, on my to do list to fix.") diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index 200fe225..5d06f2ea 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -2,7 +2,7 @@ import torch -from sae_training.config import LanguageModelSAERunnerConfig +from sae_lens.training.config import LanguageModelSAERunnerConfig TINYSTORIES_MODEL = "tiny-stories-1M" TINYSTORIES_DATASET = "roneneldan/TinyStories" diff --git a/tests/unit/test_activations_store.py b/tests/unit/training/test_activations_store.py similarity index 99% rename from tests/unit/test_activations_store.py rename to tests/unit/training/test_activations_store.py index b746c37f..423f9225 100644 --- a/tests/unit/test_activations_store.py +++ b/tests/unit/training/test_activations_store.py @@ -7,7 +7,7 @@ from datasets import Dataset, IterableDataset from transformer_lens import HookedTransformer -from sae_training.activations_store import ActivationsStore +from sae_lens.training.activations_store import ActivationsStore from tests.unit.helpers import build_sae_cfg diff --git a/tests/unit/test_optim.py b/tests/unit/training/test_optim.py similarity index 99% rename from tests/unit/test_optim.py rename to tests/unit/training/test_optim.py index 2c89a96a..9fbe83c4 100644 --- a/tests/unit/test_optim.py +++ b/tests/unit/training/test_optim.py @@ -9,7 +9,7 @@ LRScheduler, ) -from 
sae_training.optim import get_scheduler +from sae_lens.training.optim import get_scheduler LR = 0.1 diff --git a/tests/unit/test_utils.py b/tests/unit/training/test_session_loader.py similarity index 78% rename from tests/unit/test_utils.py rename to tests/unit/training/test_session_loader.py index c6b2772c..8969237e 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/training/test_session_loader.py @@ -3,12 +3,14 @@ import pytest import torch +from huggingface_hub import hf_hub_download from transformer_lens import HookedTransformer -from sae_training.activations_store import ActivationsStore -from sae_training.config import LanguageModelSAERunnerConfig -from sae_training.sparse_autoencoder import SparseAutoencoder -from sae_training.utils import LMSparseAutoencoderSessionloader +from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.config import LanguageModelSAERunnerConfig +from sae_lens.training.sae_group import SAEGroup +from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader +from sae_lens.training.sparse_autoencoder import SparseAutoencoder TEST_MODEL = "tiny-stories-1M" TEST_DATASET = "roneneldan/TinyStories" @@ -117,3 +119,21 @@ def test_LMSparseAutoencoderSessionloader_load_session_from_trained(cfg: Any): new_parameters = dict(new_sae_group.autoencoders[0].named_parameters()) for name, param in sae_group.autoencoders[0].named_parameters(): assert torch.allclose(param, new_parameters[name]) + + +def test_load_pretrained_sae_from_huggingface(): + layer = 8 # pick a layer you want. 
+ REPO_ID = "jbloom/GPT2-Small-SAEs" + FILENAME = ( + f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt" + ) + path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME) + model, sae_group, activation_store = ( + LMSparseAutoencoderSessionloader.load_session_from_pretrained(path=path) + ) + assert isinstance(model, HookedTransformer) + assert isinstance(sae_group, SAEGroup) + assert isinstance(activation_store, ActivationsStore) + assert len(sae_group) == 1 + assert sae_group.cfg.hook_point_layer == layer + assert sae_group.cfg.model_name == "gpt2-small" diff --git a/tests/unit/test_sparse_autoencoder.py b/tests/unit/training/test_sparse_autoencoder.py similarity index 98% rename from tests/unit/test_sparse_autoencoder.py rename to tests/unit/training/test_sparse_autoencoder.py index 9cf059f5..b1db88eb 100644 --- a/tests/unit/test_sparse_autoencoder.py +++ b/tests/unit/training/test_sparse_autoencoder.py @@ -7,8 +7,8 @@ import torch from transformer_lens import HookedTransformer -from sae_training.config import LanguageModelSAERunnerConfig -from sae_training.sparse_autoencoder import ( +from sae_lens.training.config import LanguageModelSAERunnerConfig +from sae_lens.training.sparse_autoencoder import ( SparseAutoencoder, _per_item_mse_loss_with_target_norm, ) diff --git a/tests/unit/test_train_sae_on_language_model.py b/tests/unit/training/test_train_sae_on_language_model.py similarity index 97% rename from tests/unit/test_train_sae_on_language_model.py rename to tests/unit/training/test_train_sae_on_language_model.py index 3f5b4e58..13c4ca8a 100644 --- a/tests/unit/test_train_sae_on_language_model.py +++ b/tests/unit/training/test_train_sae_on_language_model.py @@ -8,11 +8,11 @@ from torch import Tensor from transformer_lens import HookedTransformer -from sae_training.activations_store import ActivationsStore -from sae_training.optim import get_scheduler -from sae_training.sae_group import SAEGroup -from sae_training.sparse_autoencoder 
import ForwardOutput, SparseAutoencoder -from sae_training.train_sae_on_language_model import ( +from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.optim import get_scheduler +from sae_lens.training.sae_group import SAEGroup +from sae_lens.training.sparse_autoencoder import ForwardOutput, SparseAutoencoder +from sae_lens.training.train_sae_on_language_model import ( SAETrainContext, TrainStepOutput, _build_train_step_log_dict, diff --git a/tutorials/evaluating_your_sae.ipynb b/tutorials/evaluating_your_sae.ipynb index 7f71efc8..f563007d 100644 --- a/tutorials/evaluating_your_sae.ipynb +++ b/tutorials/evaluating_your_sae.ipynb @@ -1,370 +1,370 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Evaluating your SAE" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Evaluating your SAE" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set Up" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import torch\n", + "import json\n", + "import plotly.express as px\n", + "from transformer_lens import utils\n", + "from datasets import load_dataset\n", + "from typing import Dict\n", + "from pathlib import Path\n", + "\n", + "from functools import partial\n", + "\n", + "sys.path.append(\"..\")\n", + "\n", + "from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader\n", + "from sae_lens.analysis.visualizer.data_fns import get_feature_data, FeatureData\n", + "\n", + "if torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "torch.set_grad_enabled(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load your Autoencoder\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [ + "# Start by downloading them from huggingface\n", + "from huggingface_hub import hf_hub_download\n", + "\n", + "REPO_ID = \"jbloom/GPT2-Small-SAEs\"\n", + "\n", + "\n", + "layer = 8 # any layer from 0 - 11 works here\n", + "FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt\"\n", + "\n", + "# this is great because if you've already downloaded the SAE it won't download it twice!\n", + "path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We can then load the SAE, dataset and model using the session loader\n", + "model, sparse_autoencoders, activation_store = (\n", + " LMSparseAutoencoderSessionloader.load_session_from_pretrained(path=path)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i, sae in enumerate(sparse_autoencoders):\n", + " hyp = sae.cfg\n", + " print(\n", + " f\"{i}: Layer {hyp.hook_point_layer}, p_norm {hyp.lp_norm}, alpha {hyp.l1_coefficient}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# pick which sae you wnat to evaluate. 
Default is 0\n", + "sparse_autoencoder = sparse_autoencoders[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test the Autoencoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### L0 Test and Reconstruction Test" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sparse_autoencoder.eval() # prevents error if we're expecting a dead neuron mask for who grads\n", + "with torch.no_grad():\n", + " batch_tokens = activation_store.get_batch_tokens()\n", + " _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)\n", + " sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(\n", + " cache[sparse_autoencoder.cfg.hook_point]\n", + " )\n", + " del cache\n", + "\n", + " # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position\n", + " l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()\n", + " print(\"average l0\", l0.mean().item())\n", + " px.histogram(l0.flatten().cpu().numpy()).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# next we want to do a reconstruction test.\n", + "def reconstr_hook(activation, hook, sae_out):\n", + " return sae_out\n", + "\n", + "\n", + "def zero_abl_hook(activation, hook):\n", + " return torch.zeros_like(activation)\n", + "\n", + "\n", + "print(\"Orig\", model(batch_tokens, return_type=\"loss\").item())\n", + "print(\n", + " \"reconstr\",\n", + " model.run_with_hooks(\n", + " batch_tokens,\n", + " fwd_hooks=[\n", + " (\n", + " utils.get_act_name(\"resid_pre\", 10),\n", + " partial(reconstr_hook, sae_out=sae_out),\n", + " )\n", + " ],\n", + " return_type=\"loss\",\n", + " ).item(),\n", + ")\n", + "print(\n", + " \"Zero\",\n", + " model.run_with_hooks(\n", + " batch_tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[(utils.get_act_name(\"resid_pre\", 10), 
zero_abl_hook)],\n", + " ).item(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specific Capability Test\n", + "\n", + "Validating model performance on specific tasks when using the reconstructed activation is quite important when studying specific tasks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "example_prompt = \"When John and Mary went to the shops, John gave the bag to\"\n", + "example_answer = \" Mary\"\n", + "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)\n", + "\n", + "logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)\n", + "tokens = model.to_tokens(example_prompt)\n", + "sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(\n", + " cache[sparse_autoencoder.cfg.hook_point]\n", + ")\n", + "\n", + "\n", + "def reconstr_hook(activations, hook, sae_out):\n", + " return sae_out\n", + "\n", + "\n", + "def zero_abl_hook(mlp_out, hook):\n", + " return torch.zeros_like(mlp_out)\n", + "\n", + "\n", + "hook_point = sparse_autoencoder.cfg.hook_point\n", + "\n", + "print(\"Orig\", model(tokens, return_type=\"loss\").item())\n", + "print(\n", + " \"reconstr\",\n", + " model.run_with_hooks(\n", + " tokens,\n", + " fwd_hooks=[\n", + " (\n", + " hook_point,\n", + " partial(reconstr_hook, sae_out=sae_out),\n", + " )\n", + " ],\n", + " return_type=\"loss\",\n", + " ).item(),\n", + ")\n", + "print(\n", + " \"Zero\",\n", + " model.run_with_hooks(\n", + " tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[(hook_point, zero_abl_hook)],\n", + " ).item(),\n", + ")\n", + "\n", + "\n", + "with model.hooks(\n", + " fwd_hooks=[\n", + " (\n", + " hook_point,\n", + " partial(reconstr_hook, sae_out=sae_out),\n", + " )\n", + " ]\n", + "):\n", + " utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 
Generating Feature Interfaces" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vals, inds = torch.topk(feature_acts[0, -1].detach().cpu(), 10)\n", + "px.bar(x=[str(i) for i in inds], y=vals).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vocab_dict = model.tokenizer.vocab\n", + "vocab_dict = {\n", + " v: k.replace(\"Ġ\", \" \").replace(\"\\n\", \"\\\\n\") for k, v in vocab_dict.items()\n", + "}\n", + "\n", + "vocab_dict_filepath = Path(os.getcwd()) / \"vocab_dict.json\"\n", + "if not vocab_dict_filepath.exists():\n", + " with open(vocab_dict_filepath, \"w\") as f:\n", + " json.dump(vocab_dict, f)\n", + "\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "data = load_dataset(\n", + " \"NeelNanda/c4-code-20k\", split=\"train\"\n", + ") # currently use this dataset to avoid deal with tokenization while streaming\n", + "tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)\n", + "tokenized_data = tokenized_data.shuffle(42)\n", + "all_tokens = tokenized_data[\"tokens\"]\n", + "\n", + "\n", + "# Currently, don't think much more time can be squeezed out of it. 
Maybe the best saving would be to\n", + "# make the entire sequence indexing parallelized, but that's possibly not worth it right now.\n", + "\n", + "max_batch_size = 512\n", + "total_batch_size = 4096 * 5\n", + "feature_idx = list(inds.flatten().cpu().numpy())\n", + "# max_batch_size = 512\n", + "# total_batch_size = 16384\n", + "# feature_idx = list(range(1000))\n", + "\n", + "tokens = all_tokens[:total_batch_size]\n", + "\n", + "feature_data: Dict[int, FeatureData] = get_feature_data(\n", + " encoder=sparse_autoencoder,\n", + " # encoder_B=sparse_autoencoder,\n", + " model=model,\n", + " hook_point=sparse_autoencoder.cfg.hook_point,\n", + " hook_point_layer=sparse_autoencoder.cfg.hook_point_layer,\n", + " tokens=tokens,\n", + " feature_idx=feature_idx,\n", + " max_batch_size=max_batch_size,\n", + " left_hand_k=3,\n", + " buffer=(5, 5),\n", + " n_groups=10,\n", + " first_group_size=20,\n", + " other_groups_size=5,\n", + " verbose=True,\n", + ")\n", + "\n", + "\n", + "for test_idx in list(inds.flatten().cpu().numpy()):\n", + " html_str = feature_data[test_idx].get_all_html()\n", + " with open(f\"data_{test_idx:04}.html\", \"w\") as f:\n", + " f.write(html_str)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This will produce a number of html files which each contain a dashboard showing feature activation on the sample data. It currently doesn't process that much data so it isn't that useful. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mats_sae_training", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Set Up" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "import torch\n", - "import json\n", - "import plotly.express as px\n", - "from transformer_lens import utils\n", - "from datasets import load_dataset\n", - "from typing import Dict\n", - "from pathlib import Path\n", - "\n", - "from functools import partial\n", - "\n", - "sys.path.append(\"..\")\n", - "\n", - "from sae_training.utils import LMSparseAutoencoderSessionloader\n", - "from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData\n", - "\n", - "if torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "else:\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "\n", - "torch.set_grad_enabled(False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load your Autoencoder\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Start by downloading them from huggingface\n", - "from huggingface_hub import hf_hub_download\n", - "\n", - "REPO_ID = \"jbloom/GPT2-Small-SAEs\"\n", - "\n", - "\n", - "layer = 8 # any layer from 0 - 11 works here\n", - "FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt\"\n", - "\n", - "# this is great because if you've already 
downloaded the SAE it won't download it twice!\n", - "path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# We can then load the SAE, dataset and model using the session loader\n", - "model, sparse_autoencoders, activation_store = (\n", - " LMSparseAutoencoderSessionloader.load_session_from_pretrained(path=path)\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i, sae in enumerate(sparse_autoencoders):\n", - " hyp = sae.cfg\n", - " print(\n", - " f\"{i}: Layer {hyp.hook_point_layer}, p_norm {hyp.lp_norm}, alpha {hyp.l1_coefficient}\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# pick which sae you wnat to evaluate. Default is 0\n", - "sparse_autoencoder = sparse_autoencoders[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Test the Autoencoder" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### L0 Test and Reconstruction Test" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sparse_autoencoder.eval() # prevents error if we're expecting a dead neuron mask for who grads\n", - "with torch.no_grad():\n", - " batch_tokens = activation_store.get_batch_tokens()\n", - " _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)\n", - " sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(\n", - " cache[sparse_autoencoder.cfg.hook_point]\n", - " )\n", - " del cache\n", - "\n", - " # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position\n", - " l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()\n", - " print(\"average l0\", l0.mean().item())\n", - " 
px.histogram(l0.flatten().cpu().numpy()).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# next we want to do a reconstruction test.\n", - "def reconstr_hook(activation, hook, sae_out):\n", - " return sae_out\n", - "\n", - "\n", - "def zero_abl_hook(activation, hook):\n", - " return torch.zeros_like(activation)\n", - "\n", - "\n", - "print(\"Orig\", model(batch_tokens, return_type=\"loss\").item())\n", - "print(\n", - " \"reconstr\",\n", - " model.run_with_hooks(\n", - " batch_tokens,\n", - " fwd_hooks=[\n", - " (\n", - " utils.get_act_name(\"resid_pre\", 10),\n", - " partial(reconstr_hook, sae_out=sae_out),\n", - " )\n", - " ],\n", - " return_type=\"loss\",\n", - " ).item(),\n", - ")\n", - "print(\n", - " \"Zero\",\n", - " model.run_with_hooks(\n", - " batch_tokens,\n", - " return_type=\"loss\",\n", - " fwd_hooks=[(utils.get_act_name(\"resid_pre\", 10), zero_abl_hook)],\n", - " ).item(),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Specific Capability Test\n", - "\n", - "Validating model performance on specific tasks when using the reconstructed activation is quite important when studying specific tasks." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "example_prompt = \"When John and Mary went to the shops, John gave the bag to\"\n", - "example_answer = \" Mary\"\n", - "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)\n", - "\n", - "logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)\n", - "tokens = model.to_tokens(example_prompt)\n", - "sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(\n", - " cache[sparse_autoencoder.cfg.hook_point]\n", - ")\n", - "\n", - "\n", - "def reconstr_hook(activations, hook, sae_out):\n", - " return sae_out\n", - "\n", - "\n", - "def zero_abl_hook(mlp_out, hook):\n", - " return torch.zeros_like(mlp_out)\n", - "\n", - "\n", - "hook_point = sparse_autoencoder.cfg.hook_point\n", - "\n", - "print(\"Orig\", model(tokens, return_type=\"loss\").item())\n", - "print(\n", - " \"reconstr\",\n", - " model.run_with_hooks(\n", - " tokens,\n", - " fwd_hooks=[\n", - " (\n", - " hook_point,\n", - " partial(reconstr_hook, sae_out=sae_out),\n", - " )\n", - " ],\n", - " return_type=\"loss\",\n", - " ).item(),\n", - ")\n", - "print(\n", - " \"Zero\",\n", - " model.run_with_hooks(\n", - " tokens,\n", - " return_type=\"loss\",\n", - " fwd_hooks=[(hook_point, zero_abl_hook)],\n", - " ).item(),\n", - ")\n", - "\n", - "\n", - "with model.hooks(\n", - " fwd_hooks=[\n", - " (\n", - " hook_point,\n", - " partial(reconstr_hook, sae_out=sae_out),\n", - " )\n", - " ]\n", - "):\n", - " utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Generating Feature Interfaces" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vals, inds = torch.topk(feature_acts[0, -1].detach().cpu(), 10)\n", - "px.bar(x=[str(i) for i in inds], y=vals).show()" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vocab_dict = model.tokenizer.vocab\n", - "vocab_dict = {\n", - " v: k.replace(\"Ġ\", \" \").replace(\"\\n\", \"\\\\n\") for k, v in vocab_dict.items()\n", - "}\n", - "\n", - "vocab_dict_filepath = Path(os.getcwd()) / \"vocab_dict.json\"\n", - "if not vocab_dict_filepath.exists():\n", - " with open(vocab_dict_filepath, \"w\") as f:\n", - " json.dump(vocab_dict, f)\n", - "\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "data = load_dataset(\n", - " \"NeelNanda/c4-code-20k\", split=\"train\"\n", - ") # currently use this dataset to avoid deal with tokenization while streaming\n", - "tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)\n", - "tokenized_data = tokenized_data.shuffle(42)\n", - "all_tokens = tokenized_data[\"tokens\"]\n", - "\n", - "\n", - "# Currently, don't think much more time can be squeezed out of it. Maybe the best saving would be to\n", - "# make the entire sequence indexing parallelized, but that's possibly not worth it right now.\n", - "\n", - "max_batch_size = 512\n", - "total_batch_size = 4096 * 5\n", - "feature_idx = list(inds.flatten().cpu().numpy())\n", - "# max_batch_size = 512\n", - "# total_batch_size = 16384\n", - "# feature_idx = list(range(1000))\n", - "\n", - "tokens = all_tokens[:total_batch_size]\n", - "\n", - "feature_data: Dict[int, FeatureData] = get_feature_data(\n", - " encoder=sparse_autoencoder,\n", - " # encoder_B=sparse_autoencoder,\n", - " model=model,\n", - " hook_point=sparse_autoencoder.cfg.hook_point,\n", - " hook_point_layer=sparse_autoencoder.cfg.hook_point_layer,\n", - " tokens=tokens,\n", - " feature_idx=feature_idx,\n", - " max_batch_size=max_batch_size,\n", - " left_hand_k=3,\n", - " buffer=(5, 5),\n", - " n_groups=10,\n", - " first_group_size=20,\n", - " other_groups_size=5,\n", - " verbose=True,\n", - ")\n", - "\n", - "\n", - "for test_idx in 
list(inds.flatten().cpu().numpy()):\n", - " html_str = feature_data[test_idx].get_all_html()\n", - " with open(f\"data_{test_idx:04}.html\", \"w\") as f:\n", - " f.write(html_str)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This will produce a number of html files which each contain a dashboard showing feature activation on the sample data. It currently doesn't process that much data so it isn't that useful. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mats_sae_training", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tutorials/generating_sae_dashboards.ipynb b/tutorials/generating_sae_dashboards.ipynb index 9470fb9c..8dbcd43a 100644 --- a/tutorials/generating_sae_dashboards.ipynb +++ b/tutorials/generating_sae_dashboards.ipynb @@ -1,136 +1,136 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Generating Dashboards\n", - "\n", - "We use Callum McDougall's `sae_viz` library for generating feature dashboards. \n", - "\n", - "We've written a runner that will wrap Callum's code and log artefacts to wandb / pick-up where it left off if needed." - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Generating Dashboards\n", + "\n", + "We use Callum McDougall's `sae_viz` library for generating feature dashboards. \n", + "\n", + "We've written a runner that will wrap Callum's code and log artefacts to wandb / pick-up where it left off if needed." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set Up" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import webbrowser\n", + "import os\n", + "import sys\n", + "from huggingface_hub import hf_hub_download\n", + "from sae_vis.model_fns import AutoEncoder, AutoEncoderConfig\n", + "from sae_vis.utils_fns import get_device\n", + "from sae_lens.analysis.dashboard_runner import DashboardRunner\n", + "\n", + "device = get_device()\n", + "print(device)\n", + "torch.set_grad_enabled(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use Runner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "layer = 8\n", + "REPO_ID = \"jbloom/GPT2-Small-SAEs\"\n", + "FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt\"\n", + "path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n", + "\n", + "obj = torch.load(path, map_location=device)\n", + "state_dict = obj[\"state_dict\"]\n", + "assert set(state_dict.keys()) == {\"W_enc\", \"b_enc\", \"W_dec\", \"b_dec\"}\n", + "\n", + "\n", + "# Since Callum's library has it's own autoencoder class, it's important to check\n", + "# that we don't diverge from it in the future. 
For now, it should be fine\n", + "# with the SAE above.\n", + "cfg = AutoEncoderConfig(\n", + " d_in=obj[\"cfg\"].d_in,\n", + " dict_mult=obj[\"cfg\"].expansion_factor,\n", + " device=device,\n", + ")\n", + "gpt2_sae = AutoEncoder(cfg)\n", + "gpt2_sae.load_state_dict(state_dict)\n", + "\n", + "\n", + "runner = DashboardRunner(\n", + " sae_path=path, # this will handle a local path.\n", + " dashboard_parent_folder=\"../feature_dashboards\",\n", + " init_session=True,\n", + " n_batches_to_sample_from=2\n", + " ** 12, # sampling more batches helps us get a more diverse text sample.\n", + " n_prompts_to_select=4096 * 6, # more prompts are important for sparser features.\n", + " n_features_at_a_time=128,\n", + " max_batch_size=256,\n", + " buffer_tokens=8,\n", + " use_wandb=False,\n", + " continue_existing_dashboard=True,\n", + ")\n", + "runner.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Dashboards" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feature_files = os.listdir(runner.dashboard_folder)\n", + "# pick 3 random feature files and open them in the web browser\n", + "for i in range(3):\n", + " feature_file = feature_files[i]\n", + " url = f\"file://{os.path.abspath(runner.dashboard_folder)}/{feature_file}\"\n", + " webbrowser.open(url)\n", + " print(url)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mats_sae_training", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Set Up" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - 
"import webbrowser\n", - "import os\n", - "import sys\n", - "from huggingface_hub import hf_hub_download\n", - "from sae_vis.model_fns import AutoEncoder, AutoEncoderConfig\n", - "from sae_vis.utils_fns import get_device\n", - "from sae_analysis.dashboard_runner import DashboardRunner\n", - "\n", - "device = get_device()\n", - "print(device)\n", - "torch.set_grad_enabled(False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Use Runner" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "layer = 8\n", - "REPO_ID = \"jbloom/GPT2-Small-SAEs\"\n", - "FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt\"\n", - "path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n", - "\n", - "obj = torch.load(path, map_location=device)\n", - "state_dict = obj[\"state_dict\"]\n", - "assert set(state_dict.keys()) == {\"W_enc\", \"b_enc\", \"W_dec\", \"b_dec\"}\n", - "\n", - "\n", - "# Since Callum's library has it's own autoencoder class, it's important to check\n", - "# that we don't diverge from it in the future. 
For now, it should be fine\n", - "# with the SAE above.\n", - "cfg = AutoEncoderConfig(\n", - " d_in=obj[\"cfg\"].d_in,\n", - " dict_mult=obj[\"cfg\"].expansion_factor,\n", - " device=device,\n", - ")\n", - "gpt2_sae = AutoEncoder(cfg)\n", - "gpt2_sae.load_state_dict(state_dict)\n", - "\n", - "\n", - "runner = DashboardRunner(\n", - " sae_path=path, # this will handle a local path.\n", - " dashboard_parent_folder=\"../feature_dashboards\",\n", - " init_session=True,\n", - " n_batches_to_sample_from=2\n", - " ** 12, # sampling more batches helps us get a more diverse text sample.\n", - " n_prompts_to_select=4096 * 6, # more prompts are important for sparser features.\n", - " n_features_at_a_time=128,\n", - " max_batch_size=256,\n", - " buffer_tokens=8,\n", - " use_wandb=False,\n", - " continue_existing_dashboard=True,\n", - ")\n", - "runner.run()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Visualize Dashboards" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "feature_files = os.listdir(runner.dashboard_folder)\n", - "# pick 3 random feature files and open them in the web browser\n", - "for i in range(3):\n", - " feature_file = feature_files[i]\n", - " url = f\"file://{os.path.abspath(runner.dashboard_folder)}/{feature_file}\"\n", - " webbrowser.open(url)\n", - " print(url)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mats_sae_training", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tutorials/logits_lens_with_features.ipynb b/tutorials/logits_lens_with_features.ipynb index f1077301..409ffebb 
100644 --- a/tutorials/logits_lens_with_features.ipynb +++ b/tutorials/logits_lens_with_features.ipynb @@ -1,806 +1,806 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Understanding SAE Features with the Logit Lens\n", - "\n", - "This notebook demonstrates how to use the mats_sae_training library to perform the analysis documented the post \"[Understanding SAE Features with the Logit Lens](https://www.alignmentforum.org/posts/qykrYY6rXXM7EEs8Q/understanding-sae-features-with-the-logit-lens)\". \n", - "\n", - "As such, the notebook will include sections for:\n", - "- Loading in GPT2-Small Residual Stream SAEs from Huggingface. \n", - "- Performing Virtual Weight Based Analysis of features (specifically looking at the logit weight distributions).\n", - "- Programmatically opening neuronpedia tabs to engage with public dashboards on [neuronpedia](https://www.neuronpedia.org/).\n", - "- Performing Token Set Enrichment Analysis (based on Gene Set Enrichment Analysis). " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Set Up\n", - "\n", - "Here we'll load various functions for things like:\n", - "- downloading and loading our SAEs from huggingface. \n", - "- opening neuronpedia from a jupyter cell. \n", - "- calculating statistics of the logit weight distributions. \n", - "- performing Token Set Enrichment Analysis (TSEA) and plotting the results." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import plotly_express as px\n", - "\n", - "from transformer_lens import HookedTransformer\n", - "\n", - "# Model Loading\n", - "from sae_analysis.toolkit import get_all_gpt2_small_saes, open_neuronpedia\n", - "\n", - "# Virtual Weight / Feature Statistics Functions\n", - "from sae_analysis.feature_statistics import get_all_stats_dfs, get_W_U_W_dec_stats_df\n", - "\n", - "# Enrichment Analysis Functions\n", - "from sae_analysis.tsea import (\n", - " get_enrichment_df,\n", - " manhattan_plot_enrichment_scores,\n", - " plot_top_k_feature_projections_by_token_and_category,\n", - ")\n", - "from sae_analysis.tsea import (\n", - " get_baby_name_sets,\n", - " get_letter_gene_sets,\n", - " generate_pos_sets,\n", - " get_test_gene_sets,\n", - " get_gene_set_from_regex,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Loading GPT2 Small and SAE Weights\n", - "\n", - "This will take a while the first time you run it, but will be quick thereafter." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = HookedTransformer.from_pretrained(\"gpt2-small\")\n", - "gpt2_small_sparse_autoencoders, gpt2_small_sae_sparsities = get_all_gpt2_small_saes()\n", - "W_dec_all_layers = torch.stack(\n", - " [\n", - " gpt2_small_sparse_autoencoders[i].W_dec.detach().cpu()\n", - " for i in gpt2_small_sparse_autoencoders.keys()\n", - " ],\n", - " dim=0,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Statistical Properties of Feature Logit Distributions\n", - "\n", - "In the post I study layer 8 (for no particular reason). 
At the end of this notebook is code for visualizing these statistics across all layers. Feel free to change the layer here and explore different layers. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# In the post, I focus on layer 8\n", - "layer = 8\n", - "\n", - "# get the corresponding SAE and feature sparsities.\n", - "sparse_autoencoder = gpt2_small_sparse_autoencoders[f\"blocks.{layer}.hook_resid_pre\"]\n", - "log_feature_sparsity = gpt2_small_sae_sparsities[f\"blocks.{layer}.hook_resid_pre\"].cpu()\n", - "W_dec = sparse_autoencoder.W_dec.detach().cpu()\n", - "\n", - "# calculate the statistics of the logit weight distributions\n", - "W_U_stats_df_dec, dec_projection_onto_W_U = get_W_U_W_dec_stats_df(\n", - " W_dec, model, cosine_sim=False\n", - ")\n", - "W_U_stats_df_dec[\"sparsity\"] = (\n", - " log_feature_sparsity # add feature sparsity since it is often interesting.\n", - ")\n", - "display(W_U_stats_df_dec)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Let's look at the distribution of the 3rd / 4th moments. 
I found these aren't as useful on their own as joint distributions can be.\n", - "px.histogram(\n", - " W_U_stats_df_dec,\n", - " x=\"skewness\",\n", - " width=800,\n", - " height=300,\n", - " nbins=1000,\n", - " title=\"Skewness of the Logit Weight Distributions\",\n", - ").show()\n", - "\n", - "px.histogram(\n", - " W_U_stats_df_dec,\n", - " x=np.log10(W_U_stats_df_dec[\"kurtosis\"]),\n", - " width=800,\n", - " height=300,\n", - " nbins=1000,\n", - " title=\"Kurtosis of the Logit Weight Distributions\",\n", - ").show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.scatter(\n", - " W_U_stats_df_dec,\n", - " x=\"skewness\",\n", - " y=\"kurtosis\",\n", - " color=\"std\",\n", - " color_continuous_scale=\"Portland\",\n", - " hover_name=\"feature\",\n", - " width=800,\n", - " height=500,\n", - " log_y=True, # Kurtosis has larger outliers so logging creates a nicer scale.\n", - " labels={\"x\": \"Skewness\", \"y\": \"Kurtosis\", \"color\": \"Standard Deviation\"},\n", - " title=f\"Layer {8}: Skewness vs Kurtosis of the Logit Weight Distributions\",\n", - ")\n", - "\n", - "# decrease point size\n", - "fig.update_traces(marker=dict(size=3))\n", - "\n", - "\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# then you can query accross combinations of the statistics to find features of interest and open them in neuronpedia.\n", - "tmp_df = W_U_stats_df_dec[[\"feature\", \"skewness\", \"kurtosis\", \"std\"]]\n", - "# tmp_df = tmp_df[(tmp_df[\"std\"] > 0.04)]\n", - "# tmp_df = tmp_df[(tmp_df[\"skewness\"] > 0.65)]\n", - "tmp_df = tmp_df[(tmp_df[\"skewness\"] > 3)]\n", - "tmp_df = tmp_df.sort_values(\"skewness\", ascending=False).head(10)\n", - "display(tmp_df)\n", - "\n", - "# if desired, open the features in neuronpedia\n", - "for feature in tmp_df.feature:\n", - " open_neuronpedia(feature, layer=layer)" - ] - }, - 
{ - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Token Set Enrichment Analysis\n", - "\n", - "We now proceed to token set enrichment analysis. I highly recommend reading my AlignmentForum post (espeically the case studies) before reading too much into any of these results. \n", - "Also read this [post](https://transformer-circuits.pub/2024/qualitative-essay/index.html) for good general perspectives on statistics here. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Defining Our Token Sets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# get the vocab we need to filter to formulate token sets.\n", - "vocab = model.tokenizer.get_vocab() # type: ignore\n", - "\n", - "# make a regex dictionary to specify more sets.\n", - "regex_dict = {\n", - " \"starts_with_space\": r\"Ġ.*\",\n", - " \"starts_with_capital\": r\"^Ġ*[A-Z].*\",\n", - " \"starts_with_lower\": r\"^Ġ*[a-z].*\",\n", - " \"all_digits\": r\"^Ġ*\\d+$\",\n", - " \"is_punctuation\": r\"^[^\\w\\s]+$\",\n", - " \"contains_close_bracket\": r\".*\\).*\",\n", - " \"contains_open_bracket\": r\".*\\(.*\",\n", - " \"all_caps\": r\"Ġ*[A-Z]+$\",\n", - " \"1 digit\": r\"Ġ*\\d{1}$\",\n", - " \"2 digits\": r\"Ġ*\\d{2}$\",\n", - " \"3 digits\": r\"Ġ*\\d{3}$\",\n", - " \"4 digits\": r\"Ġ*\\d{4}$\",\n", - " \"length_1\": r\"^Ġ*\\w{1}$\",\n", - " \"length_2\": r\"^Ġ*\\w{2}$\",\n", - " \"length_3\": r\"^Ġ*\\w{3}$\",\n", - " \"length_4\": r\"^Ġ*\\w{4}$\",\n", - " \"length_5\": r\"^Ġ*\\w{5}$\",\n", - "}\n", - "\n", - "# print size of gene sets\n", - "all_token_sets = get_letter_gene_sets(vocab)\n", - "for key, value in regex_dict.items():\n", - " gene_set = get_gene_set_from_regex(vocab, value)\n", - " all_token_sets[key] = gene_set\n", - "\n", - "# some other sets that can be interesting\n", - "baby_name_sets = get_baby_name_sets(vocab)\n", - "pos_sets = generate_pos_sets(vocab)\n", - "arbitrary_sets = 
get_test_gene_sets(model)\n", - "\n", - "all_token_sets = {**all_token_sets, **pos_sets}\n", - "all_token_sets = {**all_token_sets, **arbitrary_sets}\n", - "all_token_sets = {**all_token_sets, **baby_name_sets}\n", - "\n", - "# for each gene set, convert to string and print the first 5 tokens\n", - "for token_set_name, gene_set in sorted(\n", - " all_token_sets.items(), key=lambda x: len(x[1]), reverse=True\n", - "):\n", - " tokens = [model.to_string(id) for id in list(gene_set)][:10] # type: ignore\n", - " print(f\"{token_set_name}, has {len(gene_set)} genes\")\n", - " print(tokens)\n", - " print(\"----\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Performing Token Set Enrichment Analysis\n", - "\n", - "Below we perform token set enrichment analysis on various token sets. In practice, we'd likely perform tests accross all tokens and large libraries of sets simultaneously but to make it easier to run, we look at features with higher skew and select of a few token sets at a time to consider." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "features_ordered_by_skew = (\n", - " W_U_stats_df_dec[\"skewness\"].sort_values(ascending=False).head(5000).index.to_list()\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# filter our list.\n", - "token_sets_index = [\n", - " \"starts_with_space\",\n", - " \"starts_with_capital\",\n", - " \"all_digits\",\n", - " \"is_punctuation\",\n", - " \"all_caps\",\n", - "]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "\n", - "# calculate the enrichment scores\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, # use the logit weight values as our rankings over tokens.\n", - " features_ordered_by_skew, # subset by these features\n", - " token_set_selected, # use token_sets\n", - ")\n", - "\n", - "manhattan_plot_enrichment_scores(\n", - " df_enrichment_scores, label_threshold=0, top_n=3 # use our enrichment scores\n", - ").show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.scatter(\n", - " df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x)).T,\n", - " x=\"starts_with_space\",\n", - " y=\"starts_with_capital\",\n", - " marginal_x=\"histogram\",\n", - " marginal_y=\"histogram\",\n", - " labels={\n", - " \"starts_with_space\": \"Starts with Space\",\n", - " \"starts_with_capital\": \"Starts with Capital\",\n", - " },\n", - " title=\"Enrichment Scores for Starts with Space vs Starts with Capital\",\n", - " height=800,\n", - " width=800,\n", - ")\n", - "# reduce point size on the scatter only\n", - "fig.update_traces(marker=dict(size=2), selector=dict(mode=\"markers\"))\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - 
"token_sets_index = [\"1 digit\", \"2 digits\", \"3 digits\", \"4 digits\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "token_sets_index = [\"nltk_pos_PRP\", \"nltk_pos_VBZ\", \"nltk_pos_NNP\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "token_sets_index = [\"nltk_pos_VBN\", \"nltk_pos_VBG\", \"nltk_pos_VB\", \"nltk_pos_VBD\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "token_sets_index = [\"nltk_pos_WP\", \"nltk_pos_RBR\", \"nltk_pos_WDT\", \"nltk_pos_RB\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": 
"code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "token_sets_index = [\"a\", \"e\", \"i\", \"o\", \"u\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "token_sets_index = [\"negative_words\", \"positive_words\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = px.scatter(\n", - " df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x))\n", - " .T.reset_index()\n", - " .rename(columns={\"index\": \"feature\"}),\n", - " x=\"negative_words\",\n", - " y=\"positive_words\",\n", - " marginal_x=\"histogram\",\n", - " marginal_y=\"histogram\",\n", - " labels={\n", - " \"starts_with_space\": \"Starts with Space\",\n", - " \"starts_with_capital\": \"Starts with Capital\",\n", - " },\n", - " title=\"Enrichment Scores for Starts with Space vs Starts with Capital\",\n", - " height=800,\n", - " width=800,\n", - " hover_name=\"feature\",\n", - ")\n", - "# reduce point size on the scatter only\n", - "fig.update_traces(marker=dict(size=2), selector=dict(mode=\"markers\"))\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "token_sets_index = [\"contains_close_bracket\", 
\"contains_open_bracket\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "token_sets_index = [\n", - " \"1910's\",\n", - " \"1920's\",\n", - " \"1930's\",\n", - " \"1940's\",\n", - " \"1950's\",\n", - " \"1960's\",\n", - " \"1970's\",\n", - " \"1980's\",\n", - " \"1990's\",\n", - " \"2000's\",\n", - " \"2010's\",\n", - "]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "token_sets_index = [\"positive_words\", \"negative_words\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores, label_threshold=0.98).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "token_sets_index = [\"boys_names\", \"girls_names\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - 
"manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tmp_df = df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x)).T\n", - "color = (\n", - " W_U_stats_df_dec.sort_values(\"skewness\", ascending=False)\n", - " .head(5000)[\"skewness\"]\n", - " .values\n", - ")\n", - "fig = px.scatter(\n", - " tmp_df.reset_index().rename(columns={\"index\": \"feature\"}),\n", - " x=\"boys_names\",\n", - " y=\"girls_names\",\n", - " marginal_x=\"histogram\",\n", - " marginal_y=\"histogram\",\n", - " # color = color,\n", - " labels={\n", - " \"boys_names\": \"Enrichment Score (Boys Names)\",\n", - " \"girls_names\": \"Enrichment Score (Girls Names)\",\n", - " },\n", - " height=600,\n", - " width=800,\n", - " hover_name=\"feature\",\n", - ")\n", - "# reduce point size on the scatter only\n", - "fig.update_traces(marker=dict(size=3), selector=dict(mode=\"markers\"))\n", - "# annotate any features where the absolute distance between boys names and girls names > 3\n", - "for feature in df_enrichment_scores.columns:\n", - " if abs(tmp_df[\"boys_names\"][feature] - tmp_df[\"girls_names\"][feature]) > 2.9:\n", - " fig.add_annotation(\n", - " x=tmp_df[\"boys_names\"][feature] - 0.4,\n", - " y=tmp_df[\"girls_names\"][feature] + 0.1,\n", - " text=f\"{feature}\",\n", - " showarrow=False,\n", - " )\n", - "\n", - "\n", - "fig.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Digging into Particular Features\n", - "\n", - "When we do these enrichments, I generate the logit weight histograms by category using the following function. It's important to make sure the categories you group by are in the columns of df_enrichment_scores." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for category in [\"starts_with_space\"]:\n", - " plot_top_k_feature_projections_by_token_and_category(\n", - " token_set_selected,\n", - " df_enrichment_scores,\n", - " category=category,\n", - " dec_projection_onto_W_U=dec_projection_onto_W_U,\n", - " model=model,\n", - " log_y=False,\n", - " histnorm=None,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Appendix Results: Logit Weight distribution Statistics Accross All Layers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "W_U_stats_df_dec_all_layers = get_all_stats_dfs(\n", - " gpt2_small_sparse_autoencoders, gpt2_small_sae_sparsities, model, cosine_sim=True\n", - ")\n", - "\n", - "display(W_U_stats_df_dec_all_layers.shape)\n", - "display(W_U_stats_df_dec_all_layers.head())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Let's plot the percentiles of the skewness and kurtosis by layer\n", - "tmp_df = W_U_stats_df_dec_all_layers.groupby(\"layer\")[\"skewness\"].describe(\n", - " percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]\n", - ")\n", - "tmp_df = tmp_df[[\"1%\", \"5%\", \"10%\", \"25%\", \"50%\", \"75%\", \"90%\", \"95%\", \"99%\"]]\n", - "\n", - "fig = px.area(\n", - " tmp_df,\n", - " title=\"Kurtosis by Layer\",\n", - " width=800,\n", - " height=600,\n", - " color_discrete_sequence=px.colors.sequential.Turbo,\n", - ").show()\n", - "\n", - "\n", - "tmp_df = W_U_stats_df_dec_all_layers.groupby(\"layer\")[\"kurtosis\"].describe(\n", - " percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]\n", - ")\n", - "tmp_df = tmp_df[[\"1%\", \"5%\", \"10%\", \"25%\", \"50%\", \"75%\", \"90%\", \"95%\", \"99%\"]]\n", - "\n", - "fig = px.area(\n", - " tmp_df,\n", - " title=\"Kurtosis by Layer\",\n", - " 
width=800,\n", - " height=600,\n", - " color_discrete_sequence=px.colors.sequential.Turbo,\n", - ")\n", - "\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# let's make a pretty color scheme\n", - "from plotly.colors import n_colors\n", - "\n", - "colors = n_colors(\"rgb(5, 200, 200)\", \"rgb(200, 10, 10)\", 13, colortype=\"rgb\")\n", - "\n", - "# Make a box plot of the skewness by layer\n", - "fig = px.box(\n", - " W_U_stats_df_dec_all_layers,\n", - " x=\"layer\",\n", - " y=\"skewness\",\n", - " color=\"layer\",\n", - " color_discrete_sequence=colors,\n", - " height=600,\n", - " width=1200,\n", - " title=\"Skewness cos(W_U,W_dec) by Layer in GPT2 Small Residual Stream SAEs\",\n", - " labels={\"layer\": \"Layer\", \"skewnss\": \"Skewness\"},\n", - ")\n", - "fig.update_xaxes(showticklabels=True, dtick=1)\n", - "\n", - "# increase font size\n", - "fig.update_layout(font=dict(size=16))\n", - "fig.show()\n", - "\n", - "# Make a box plot of the skewness by layer\n", - "fig = px.box(\n", - " W_U_stats_df_dec_all_layers,\n", - " x=\"layer\",\n", - " y=\"kurtosis\",\n", - " color=\"layer\",\n", - " color_discrete_sequence=colors,\n", - " height=600,\n", - " width=1200,\n", - " log_y=True,\n", - " title=\"log kurtosis cos(W_U,W_dec) by Layer in GPT2 Small Residual Stream SAEs\",\n", - " labels={\"layer\": \"Layer\", \"kurtosis\": \"Log Kurtosis\"},\n", - ")\n", - "fig.update_xaxes(showticklabels=True, dtick=1)\n", - "\n", - "# increase font size\n", - "fig.update_layout(font=dict(size=16))\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# scatter\n", - "fig = px.scatter(\n", - " W_U_stats_df_dec_all_layers[W_U_stats_df_dec_all_layers.log_feature_sparsity >= -9],\n", - " # W_U_stats_df_dec_all_layers[W_U_stats_df_dec_all_layers.layer == 8],\n", - " x=\"skewness\",\n", - " y=\"kurtosis\",\n", - " 
color=\"std\",\n", - " color_continuous_scale=\"Portland\",\n", - " hover_name=\"feature\",\n", - " # color_continuous_midpoint = 0,\n", - " # range_color = [-4,-1],\n", - " log_y=True,\n", - " height=800,\n", - " # width = 2000,\n", - " # facet_col=\"layer\",\n", - " # facet_col_wrap=5,\n", - " animation_frame=\"layer\",\n", - ")\n", - "fig.update_yaxes(matches=None)\n", - "fig.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))\n", - "\n", - "# decrease point size\n", - "fig.update_traces(marker=dict(size=5))\n", - "fig.show()\n", - "fig.write_html(\"skewness_kurtosis_scatter_all_layers.html\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mats_sae_training", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Understanding SAE Features with the Logit Lens\n", + "\n", + "This notebook demonstrates how to use the mats_sae_training library to perform the analysis documented the post \"[Understanding SAE Features with the Logit Lens](https://www.alignmentforum.org/posts/qykrYY6rXXM7EEs8Q/understanding-sae-features-with-the-logit-lens)\". \n", + "\n", + "As such, the notebook will include sections for:\n", + "- Loading in GPT2-Small Residual Stream SAEs from Huggingface. 
\n", + "- Performing Virtual Weight Based Analysis of features (specifically looking at the logit weight distributions).\n", + "- Programmatically opening neuronpedia tabs to engage with public dashboards on [neuronpedia](https://www.neuronpedia.org/).\n", + "- Performing Token Set Enrichment Analysis (based on Gene Set Enrichment Analysis). " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set Up\n", + "\n", + "Here we'll load various functions for things like:\n", + "- downloading and loading our SAEs from huggingface. \n", + "- opening neuronpedia from a jupyter cell. \n", + "- calculating statistics of the logit weight distributions. \n", + "- performing Token Set Enrichment Analysis (TSEA) and plotting the results." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "sys.path.append(\"..\")\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import plotly_express as px\n", + "\n", + "from transformer_lens import HookedTransformer\n", + "\n", + "# Model Loading\n", + "from sae_lens.analysis.toolkit import get_all_gpt2_small_saes, open_neuronpedia\n", + "\n", + "# Virtual Weight / Feature Statistics Functions\n", + "from sae_lens.analysis.feature_statistics import get_all_stats_dfs, get_W_U_W_dec_stats_df\n", + "\n", + "# Enrichment Analysis Functions\n", + "from sae_lens.analysis.tsea import (\n", + " get_enrichment_df,\n", + " manhattan_plot_enrichment_scores,\n", + " plot_top_k_feature_projections_by_token_and_category,\n", + ")\n", + "from sae_lens.analysis.tsea import (\n", + " get_baby_name_sets,\n", + " get_letter_gene_sets,\n", + " generate_pos_sets,\n", + " get_test_gene_sets,\n", + " get_gene_set_from_regex,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Loading GPT2 Small and SAE Weights\n", + 
"\n", + "This will take a while the first time you run it, but will be quick thereafter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = HookedTransformer.from_pretrained(\"gpt2-small\")\n", + "gpt2_small_sparse_autoencoders, gpt2_small_sae_sparsities = get_all_gpt2_small_saes()\n", + "W_dec_all_layers = torch.stack(\n", + " [\n", + " gpt2_small_sparse_autoencoders[i].W_dec.detach().cpu()\n", + " for i in gpt2_small_sparse_autoencoders.keys()\n", + " ],\n", + " dim=0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Statistical Properties of Feature Logit Distributions\n", + "\n", + "In the post I study layer 8 (for no particular reason). At the end of this notebook is code for visualizing these statistics across all layers. Feel free to change the layer here and explore different layers. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# In the post, I focus on layer 8\n", + "layer = 8\n", + "\n", + "# get the corresponding SAE and feature sparsities.\n", + "sparse_autoencoder = gpt2_small_sparse_autoencoders[f\"blocks.{layer}.hook_resid_pre\"]\n", + "log_feature_sparsity = gpt2_small_sae_sparsities[f\"blocks.{layer}.hook_resid_pre\"].cpu()\n", + "W_dec = sparse_autoencoder.W_dec.detach().cpu()\n", + "\n", + "# calculate the statistics of the logit weight distributions\n", + "W_U_stats_df_dec, dec_projection_onto_W_U = get_W_U_W_dec_stats_df(\n", + " W_dec, model, cosine_sim=False\n", + ")\n", + "W_U_stats_df_dec[\"sparsity\"] = (\n", + " log_feature_sparsity # add feature sparsity since it is often interesting.\n", + ")\n", + "display(W_U_stats_df_dec)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's look at the distribution of the 3rd / 4th moments. 
I found these aren't as useful on their own as joint distributions can be.\n", + "px.histogram(\n", + " W_U_stats_df_dec,\n", + " x=\"skewness\",\n", + " width=800,\n", + " height=300,\n", + " nbins=1000,\n", + " title=\"Skewness of the Logit Weight Distributions\",\n", + ").show()\n", + "\n", + "px.histogram(\n", + " W_U_stats_df_dec,\n", + " x=np.log10(W_U_stats_df_dec[\"kurtosis\"]),\n", + " width=800,\n", + " height=300,\n", + " nbins=1000,\n", + " title=\"Kurtosis of the Logit Weight Distributions\",\n", + ").show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " W_U_stats_df_dec,\n", + " x=\"skewness\",\n", + " y=\"kurtosis\",\n", + " color=\"std\",\n", + " color_continuous_scale=\"Portland\",\n", + " hover_name=\"feature\",\n", + " width=800,\n", + " height=500,\n", + " log_y=True, # Kurtosis has larger outliers so logging creates a nicer scale.\n", + " labels={\"x\": \"Skewness\", \"y\": \"Kurtosis\", \"color\": \"Standard Deviation\"},\n", + " title=f\"Layer {8}: Skewness vs Kurtosis of the Logit Weight Distributions\",\n", + ")\n", + "\n", + "# decrease point size\n", + "fig.update_traces(marker=dict(size=3))\n", + "\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# then you can query across combinations of the statistics to find features of interest and open them in neuronpedia.\n", + "tmp_df = W_U_stats_df_dec[[\"feature\", \"skewness\", \"kurtosis\", \"std\"]]\n", + "# tmp_df = tmp_df[(tmp_df[\"std\"] > 0.04)]\n", + "# tmp_df = tmp_df[(tmp_df[\"skewness\"] > 0.65)]\n", + "tmp_df = tmp_df[(tmp_df[\"skewness\"] > 3)]\n", + "tmp_df = tmp_df.sort_values(\"skewness\", ascending=False).head(10)\n", + "display(tmp_df)\n", + "\n", + "# if desired, open the features in neuronpedia\n", + "for feature in tmp_df.feature:\n", + " open_neuronpedia(feature, layer=layer)" + ] + }, + 
{ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Token Set Enrichment Analysis\n", + "\n", + "We now proceed to token set enrichment analysis. I highly recommend reading my AlignmentForum post (especially the case studies) before reading too much into any of these results. \n", + "Also read this [post](https://transformer-circuits.pub/2024/qualitative-essay/index.html) for good general perspectives on statistics here. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defining Our Token Sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# get the vocab we need to filter to formulate token sets.\n", + "vocab = model.tokenizer.get_vocab() # type: ignore\n", + "\n", + "# make a regex dictionary to specify more sets.\n", + "regex_dict = {\n", + " \"starts_with_space\": r\"Ġ.*\",\n", + " \"starts_with_capital\": r\"^Ġ*[A-Z].*\",\n", + " \"starts_with_lower\": r\"^Ġ*[a-z].*\",\n", + " \"all_digits\": r\"^Ġ*\\d+$\",\n", + " \"is_punctuation\": r\"^[^\\w\\s]+$\",\n", + " \"contains_close_bracket\": r\".*\\).*\",\n", + " \"contains_open_bracket\": r\".*\\(.*\",\n", + " \"all_caps\": r\"Ġ*[A-Z]+$\",\n", + " \"1 digit\": r\"Ġ*\\d{1}$\",\n", + " \"2 digits\": r\"Ġ*\\d{2}$\",\n", + " \"3 digits\": r\"Ġ*\\d{3}$\",\n", + " \"4 digits\": r\"Ġ*\\d{4}$\",\n", + " \"length_1\": r\"^Ġ*\\w{1}$\",\n", + " \"length_2\": r\"^Ġ*\\w{2}$\",\n", + " \"length_3\": r\"^Ġ*\\w{3}$\",\n", + " \"length_4\": r\"^Ġ*\\w{4}$\",\n", + " \"length_5\": r\"^Ġ*\\w{5}$\",\n", + "}\n", + "\n", + "# print size of gene sets\n", + "all_token_sets = get_letter_gene_sets(vocab)\n", + "for key, value in regex_dict.items():\n", + " gene_set = get_gene_set_from_regex(vocab, value)\n", + " all_token_sets[key] = gene_set\n", + "\n", + "# some other sets that can be interesting\n", + "baby_name_sets = get_baby_name_sets(vocab)\n", + "pos_sets = generate_pos_sets(vocab)\n", + "arbitrary_sets = 
get_test_gene_sets(model)\n", + "\n", + "all_token_sets = {**all_token_sets, **pos_sets}\n", + "all_token_sets = {**all_token_sets, **arbitrary_sets}\n", + "all_token_sets = {**all_token_sets, **baby_name_sets}\n", + "\n", + "# for each gene set, convert to string and print the first 10 tokens\n", + "for token_set_name, gene_set in sorted(\n", + " all_token_sets.items(), key=lambda x: len(x[1]), reverse=True\n", + "):\n", + " tokens = [model.to_string(id) for id in list(gene_set)][:10] # type: ignore\n", + " print(f\"{token_set_name}, has {len(gene_set)} genes\")\n", + " print(tokens)\n", + " print(\"----\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Performing Token Set Enrichment Analysis\n", + "\n", + "Below we perform token set enrichment analysis on various token sets. In practice, we'd likely perform tests across all tokens and large libraries of sets simultaneously but to make it easier to run, we look at features with higher skew and select a few token sets at a time to consider." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "features_ordered_by_skew = (\n", + " W_U_stats_df_dec[\"skewness\"].sort_values(ascending=False).head(5000).index.to_list()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# filter our list.\n", + "token_sets_index = [\n", + " \"starts_with_space\",\n", + " \"starts_with_capital\",\n", + " \"all_digits\",\n", + " \"is_punctuation\",\n", + " \"all_caps\",\n", + "]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "\n", + "# calculate the enrichment scores\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, # use the logit weight values as our rankings over tokens.\n", + " features_ordered_by_skew, # subset by these features\n", + " token_set_selected, # use token_sets\n", + ")\n", + "\n", + "manhattan_plot_enrichment_scores(\n", + " df_enrichment_scores, label_threshold=0, top_n=3 # use our enrichment scores\n", + ").show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x)).T,\n", + " x=\"starts_with_space\",\n", + " y=\"starts_with_capital\",\n", + " marginal_x=\"histogram\",\n", + " marginal_y=\"histogram\",\n", + " labels={\n", + " \"starts_with_space\": \"Starts with Space\",\n", + " \"starts_with_capital\": \"Starts with Capital\",\n", + " },\n", + " title=\"Enrichment Scores for Starts with Space vs Starts with Capital\",\n", + " height=800,\n", + " width=800,\n", + ")\n", + "# reduce point size on the scatter only\n", + "fig.update_traces(marker=dict(size=2), selector=dict(mode=\"markers\"))\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"token_sets_index = [\"1 digit\", \"2 digits\", \"3 digits\", \"4 digits\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "token_sets_index = [\"nltk_pos_PRP\", \"nltk_pos_VBZ\", \"nltk_pos_NNP\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "token_sets_index = [\"nltk_pos_VBN\", \"nltk_pos_VBG\", \"nltk_pos_VB\", \"nltk_pos_VBD\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "token_sets_index = [\"nltk_pos_WP\", \"nltk_pos_RBR\", \"nltk_pos_WDT\", \"nltk_pos_RB\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "token_sets_index = [\"a\", \"e\", \"i\", \"o\", \"u\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "token_sets_index = [\"negative_words\", \"positive_words\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x))\n", + " .T.reset_index()\n", + " .rename(columns={\"index\": \"feature\"}),\n", + " x=\"negative_words\",\n", + " y=\"positive_words\",\n", + " marginal_x=\"histogram\",\n", + " marginal_y=\"histogram\",\n", + " labels={\n", + " \"starts_with_space\": \"Starts with Space\",\n", + " \"starts_with_capital\": \"Starts with Capital\",\n", + " },\n", + " title=\"Enrichment Scores for Starts with Space vs Starts with Capital\",\n", + " height=800,\n", + " width=800,\n", + " hover_name=\"feature\",\n", + ")\n", + "# reduce point size on the scatter only\n", + "fig.update_traces(marker=dict(size=2), selector=dict(mode=\"markers\"))\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "token_sets_index = [\"contains_close_bracket\", 
\"contains_open_bracket\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "token_sets_index = [\n", + " \"1910's\",\n", + " \"1920's\",\n", + " \"1930's\",\n", + " \"1940's\",\n", + " \"1950's\",\n", + " \"1960's\",\n", + " \"1970's\",\n", + " \"1980's\",\n", + " \"1990's\",\n", + " \"2000's\",\n", + " \"2010's\",\n", + "]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "token_sets_index = [\"positive_words\", \"negative_words\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores, label_threshold=0.98).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "token_sets_index = [\"boys_names\", \"girls_names\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + 
"manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tmp_df = df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x)).T\n", + "color = (\n", + " W_U_stats_df_dec.sort_values(\"skewness\", ascending=False)\n", + " .head(5000)[\"skewness\"]\n", + " .values\n", + ")\n", + "fig = px.scatter(\n", + " tmp_df.reset_index().rename(columns={\"index\": \"feature\"}),\n", + " x=\"boys_names\",\n", + " y=\"girls_names\",\n", + " marginal_x=\"histogram\",\n", + " marginal_y=\"histogram\",\n", + " # color = color,\n", + " labels={\n", + " \"boys_names\": \"Enrichment Score (Boys Names)\",\n", + " \"girls_names\": \"Enrichment Score (Girls Names)\",\n", + " },\n", + " height=600,\n", + " width=800,\n", + " hover_name=\"feature\",\n", + ")\n", + "# reduce point size on the scatter only\n", + "fig.update_traces(marker=dict(size=3), selector=dict(mode=\"markers\"))\n", + "# annotate any features where the absolute distance between boys names and girls names > 3\n", + "for feature in df_enrichment_scores.columns:\n", + " if abs(tmp_df[\"boys_names\"][feature] - tmp_df[\"girls_names\"][feature]) > 2.9:\n", + " fig.add_annotation(\n", + " x=tmp_df[\"boys_names\"][feature] - 0.4,\n", + " y=tmp_df[\"girls_names\"][feature] + 0.1,\n", + " text=f\"{feature}\",\n", + " showarrow=False,\n", + " )\n", + "\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Digging into Particular Features\n", + "\n", + "When we do these enrichments, I generate the logit weight histograms by category using the following function. It's important to make sure the categories you group by are in the columns of df_enrichment_scores." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for category in [\"starts_with_space\"]:\n", + " plot_top_k_feature_projections_by_token_and_category(\n", + " token_set_selected,\n", + " df_enrichment_scores,\n", + " category=category,\n", + " dec_projection_onto_W_U=dec_projection_onto_W_U,\n", + " model=model,\n", + " log_y=False,\n", + " histnorm=None,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Appendix Results: Logit Weight Distribution Statistics Across All Layers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "W_U_stats_df_dec_all_layers = get_all_stats_dfs(\n", + " gpt2_small_sparse_autoencoders, gpt2_small_sae_sparsities, model, cosine_sim=True\n", + ")\n", + "\n", + "display(W_U_stats_df_dec_all_layers.shape)\n", + "display(W_U_stats_df_dec_all_layers.head())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's plot the percentiles of the skewness and kurtosis by layer\n", + "tmp_df = W_U_stats_df_dec_all_layers.groupby(\"layer\")[\"skewness\"].describe(\n", + " percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]\n", + ")\n", + "tmp_df = tmp_df[[\"1%\", \"5%\", \"10%\", \"25%\", \"50%\", \"75%\", \"90%\", \"95%\", \"99%\"]]\n", + "\n", + "fig = px.area(\n", + " tmp_df,\n", + " title=\"Kurtosis by Layer\",\n", + " width=800,\n", + " height=600,\n", + " color_discrete_sequence=px.colors.sequential.Turbo,\n", + ").show()\n", + "\n", + "\n", + "tmp_df = W_U_stats_df_dec_all_layers.groupby(\"layer\")[\"kurtosis\"].describe(\n", + " percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]\n", + ")\n", + "tmp_df = tmp_df[[\"1%\", \"5%\", \"10%\", \"25%\", \"50%\", \"75%\", \"90%\", \"95%\", \"99%\"]]\n", + "\n", + "fig = px.area(\n", + " tmp_df,\n", + " title=\"Kurtosis by Layer\",\n", + " 
width=800,\n", + " height=600,\n", + " color_discrete_sequence=px.colors.sequential.Turbo,\n", + ")\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# let's make a pretty color scheme\n", + "from plotly.colors import n_colors\n", + "\n", + "colors = n_colors(\"rgb(5, 200, 200)\", \"rgb(200, 10, 10)\", 13, colortype=\"rgb\")\n", + "\n", + "# Make a box plot of the skewness by layer\n", + "fig = px.box(\n", + " W_U_stats_df_dec_all_layers,\n", + " x=\"layer\",\n", + " y=\"skewness\",\n", + " color=\"layer\",\n", + " color_discrete_sequence=colors,\n", + " height=600,\n", + " width=1200,\n", + " title=\"Skewness cos(W_U,W_dec) by Layer in GPT2 Small Residual Stream SAEs\",\n", + " labels={\"layer\": \"Layer\", \"skewnss\": \"Skewness\"},\n", + ")\n", + "fig.update_xaxes(showticklabels=True, dtick=1)\n", + "\n", + "# increase font size\n", + "fig.update_layout(font=dict(size=16))\n", + "fig.show()\n", + "\n", + "# Make a box plot of the skewness by layer\n", + "fig = px.box(\n", + " W_U_stats_df_dec_all_layers,\n", + " x=\"layer\",\n", + " y=\"kurtosis\",\n", + " color=\"layer\",\n", + " color_discrete_sequence=colors,\n", + " height=600,\n", + " width=1200,\n", + " log_y=True,\n", + " title=\"log kurtosis cos(W_U,W_dec) by Layer in GPT2 Small Residual Stream SAEs\",\n", + " labels={\"layer\": \"Layer\", \"kurtosis\": \"Log Kurtosis\"},\n", + ")\n", + "fig.update_xaxes(showticklabels=True, dtick=1)\n", + "\n", + "# increase font size\n", + "fig.update_layout(font=dict(size=16))\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# scatter\n", + "fig = px.scatter(\n", + " W_U_stats_df_dec_all_layers[W_U_stats_df_dec_all_layers.log_feature_sparsity >= -9],\n", + " # W_U_stats_df_dec_all_layers[W_U_stats_df_dec_all_layers.layer == 8],\n", + " x=\"skewness\",\n", + " y=\"kurtosis\",\n", + " 
color=\"std\",\n", + " color_continuous_scale=\"Portland\",\n", + " hover_name=\"feature\",\n", + " # color_continuous_midpoint = 0,\n", + " # range_color = [-4,-1],\n", + " log_y=True,\n", + " height=800,\n", + " # width = 2000,\n", + " # facet_col=\"layer\",\n", + " # facet_col_wrap=5,\n", + " animation_frame=\"layer\",\n", + ")\n", + "fig.update_yaxes(matches=None)\n", + "fig.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))\n", + "\n", + "# decrease point size\n", + "fig.update_traces(marker=dict(size=5))\n", + "fig.show()\n", + "fig.write_html(\"skewness_kurtosis_scatter_all_layers.html\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mats_sae_training", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb index fe95fd67..79474377 100644 --- a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb +++ b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb @@ -58,7 +58,7 @@ "metadata": {}, "outputs": [], "source": [ - "from sae_analysis.neuronpedia_runner import NeuronpediaRunner\n", + "from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner\n", "\n", "NP_OUTPUT_FOLDER = \"../neuronpedia_outputs\"\n", "\n", diff --git a/tutorials/neuronpedia/np_runner_batch.py 
b/tutorials/neuronpedia/np_runner_batch.py index 2a6c52ce..23a72ec3 100644 --- a/tutorials/neuronpedia/np_runner_batch.py +++ b/tutorials/neuronpedia/np_runner_batch.py @@ -15,7 +15,7 @@ f"../../data/{SOURCE_AUTHOR_SUFFIX}/feature_sparsity_{LAYER}_{TYPE}.pt" ) -from sae_analysis.neuronpedia_runner import NeuronpediaRunner +from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner NP_OUTPUT_FOLDER = "../../neuronpedia_outputs"