From 19ab304485be57183a03221d163a497c97303c5d Mon Sep 17 00:00:00 2001
From: Alan <41682961+alan-cooney@users.noreply.github.com>
Date: Sat, 21 Oct 2023 19:00:00 +0800
Subject: [PATCH] Improve getting started docs for Activation Cache (#433)

---
 .vscode/cspell.json                 |   1 +
 .vscode/settings.json               |   3 +
 docs/source/conf.py                 |   6 +-
 transformer_lens/ActivationCache.py | 121 +++++++++++++++++++++++++---
 4 files changed, 119 insertions(+), 12 deletions(-)

diff --git a/.vscode/cspell.json b/.vscode/cspell.json
index 082c84373..738113240 100644
--- a/.vscode/cspell.json
+++ b/.vscode/cspell.json
@@ -93,6 +93,7 @@
     "transformerlens",
     "troitskii",
     "unembed",
+    "unembedded",
     "unembedding",
     "unigram",
     "virtualenvs",
diff --git a/.vscode/settings.json b/.vscode/settings.json
index f8da66a84..7092a6781 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -17,4 +17,7 @@
     "pylint.importStrategy": "fromEnvironment",
     "notebook.formatOnCellExecution": true,
     "notebook.formatOnSave.enabled": true,
+    "cSpell.words": [
+        "accum"
+    ],
 }
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 50e8d4985..af38914c0 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -39,7 +39,11 @@
 napoleon_include_init_with_doc = True
 napoleon_use_admonition_for_notes = True
 
-napoleon_custom_sections = ["Motivation:", "Warning:"]
+napoleon_custom_sections = [
+    "Motivation:",
+    "Warning:",
+    "Getting Started:",
+]
 
 # -- Options for HTML output -------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py
index 6427dc2e1..7ee732005 100644
--- a/transformer_lens/ActivationCache.py
+++ b/transformer_lens/ActivationCache.py
@@ -1,8 +1,14 @@
 """Activation Cache.
 
-The core functionality of TransformerLens is to cache and edit the activations of a model. The
-:class:`ActivationCache` is designed to help do the caching part of this - storing all activations
-in a single place.
+The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
+important activations from a forward pass of the model, and provides a variety of helper functions
+to investigate them.
+
+Getting Started:
+
+When reading these docs for the first time, we recommend reading the main :class:`ActivationCache`
+class first, including the examples, and then skimming the available methods. You can then refer
+back to these docs as needed.
 """
 
 from __future__ import annotations
@@ -24,13 +30,56 @@ class ActivationCache:
     """Activation Cache.
 
-    A wrapper around a dictionary of cached activations from a model run, with a variety of helper
-    functions.
+    A wrapper that stores all important activations from a forward pass of the model, and provides
+    a variety of helper functions to investigate them.
+
+    The :class:`ActivationCache` is at the core of Transformer Lens. The common way to create one
+    is to run the model with :meth:`transformer_lens.HookedTransformer.run_with_cache`, which
+    returns the cache alongside the model output.
+
+    Examples:
+
+    When investigating a particular behaviour of a model, a very common first step is to try to
+    understand which components of the model are most responsible for that behaviour.
+    For example, if you're investigating the prompt "Why did the chicken cross the" -> " road",
+    you might want to understand if there is a specific sublayer (MLP or multi-head attention)
+    that is responsible for the model predicting " road". This kind of analysis commonly falls
+    under the category of "logit attribution" or "direct logit attribution" (DLA).
+
+    >>> from transformer_lens import HookedTransformer
+    >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
+    Loaded pretrained model tiny-stories-1M into HookedTransformer
+
+    >>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
+    >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
+    >>> print(labels[0:3])
+    ['embed', 'pos_embed', '0_attn_out']
+
+    >>> answer = " road"  # Note the leading space, to match the model's tokenization
+    >>> logit_attrs = cache.logit_attrs(residual_stream, answer)
+    >>> print(logit_attrs.shape)  # components, batch, position
+    torch.Size([10, 1, 7])
 
-    This is designed to be used with :class:`transformer_lens.HookedTransformer`, and will not
-    work with other models. It's also designed to be used with all activations of
-    :class:`transformer_lens.HookedTransformer` being cached, and some internal methods will break
-    without that.
+    >>> most_important_component_idx = torch.argmax(logit_attrs)
+    >>> print(labels[most_important_component_idx])
+    3_attn_out
+
+    You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get
+    the residual stream broken down by individual component (MLP neurons and individual attention
+    heads). This creates a larger residual stack, but the approach of using :meth:`logit_attrs`
+    remains the same.
+
+    Equally, you might want to find out whether the model only settles on such excellent jokes in
+    its very last layers, or whether the task is trivial and the first few layers are enough. This
+    kind of analysis is called the "logit lens", and you can find out more about how to do it with
+    :meth:`ActivationCache.accumulated_resid`.
+
+    Warning:
+
+    :class:`ActivationCache` is designed to be used with
+    :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also
+    designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being
+    cached, and some internal methods will break without that.
 
     The biggest footgun and source of bugs in this code will be keeping track of indexes,
     dimensions, and the numbers of each. There are several kinds of activations:
 
@@ -300,14 +349,64 @@ def accumulated_resid(
         To project this into the vocabulary space, remember that there is a final layer norm in most
         decoder-only transformers. Therefore, you need to first apply the final layer norm (which
-        can be done with :meth:`apply_ln_to_stack`), and then multiply by the unembedding matrix
-        (:math:`W_U`).
+        can be done with the `apply_ln` argument), and then multiply by the unembedding matrix (:math:`W_U`).
 
         If you instead want to look at contributions to the residual stream from each component (e.g.
         for direct logit attribution), see :meth:`decompose_resid` instead, or
         :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
         MLP neuron.
 
+        Examples:
+
+        Logit lens analysis can be done as follows:
+
+        >>> from transformer_lens import HookedTransformer
+        >>> from einops import einsum
+        >>> import torch
+        >>> import pandas as pd
+
+        >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
+        Loaded pretrained model tiny-stories-1M into HookedTransformer
+
+        >>> prompt = "Why did the chicken cross the"
+        >>> answer = " road"
+        >>> logits, cache = model.run_with_cache(prompt)
+        >>> answer_token = model.to_single_token(answer)
+        >>> print(answer_token)
+        2975
+
+        >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
+        >>> last_token_accum = accum_resid[:, 0, -1, :]  # layer, batch, pos, d_model
+        >>> print(last_token_accum.shape)  # layer, d_model
+        torch.Size([9, 64])
+
+        >>> W_U = model.W_U
+        >>> print(W_U.shape)  # d_model, d_vocab
+        torch.Size([64, 50257])
+
+        >>> layers_unembedded = einsum(
+        ...     last_token_accum,
+        ...     W_U,
+        ...     "layer d_model, d_model d_vocab -> layer d_vocab"
+        ... )
+        >>> print(layers_unembedded.shape)
+        torch.Size([9, 50257])
+
+        >>> # Get the rank of the correct answer by layer
+        >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True)
+        >>> rank_answer = (sorted_indices == answer_token).nonzero(as_tuple=True)[1]
+        >>> print(pd.Series(rank_answer, index=labels))
+        0_pre         4442
+        1_pre          382
+        2_pre          982
+        3_pre         1160
+        4_pre          408
+        5_pre          145
+        6_pre           78
+        7_pre          387
+        final_post       6
+        dtype: int64
+
         Args:
             layer: The layer to take components up to - by default includes resid_pre for that layer
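As a worked companion to the more granular decomposition mentioned in the class docstring, the
following is a rough sketch (not part of the diff above) of per-head direct logit attribution with
:meth:`get_full_resid_decomposition`. It assumes the same "tiny-stories-1M" model and prompt as the
doctests, and the `expand_neurons` / `return_labels` arguments of the current TransformerLens API:

    # Rough sketch: per-component direct logit attribution using the full decomposition.
    # Assumes the same model and prompt as the doctests above; argument names may differ
    # across TransformerLens versions.
    import torch
    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
    _logits, cache = model.run_with_cache("Why did the chicken cross the")

    # Break the residual stream into embeddings, individual attention heads and MLP layers
    # (pass expand_neurons=True to get per-neuron MLP contributions instead).
    resid_stack, labels = cache.get_full_resid_decomposition(
        expand_neurons=False, return_labels=True
    )

    # Attribution of each component to the " road" logit (layer norm is applied internally).
    attrs = cache.logit_attrs(resid_stack, " road")  # [component, batch, pos]
    top_idx = torch.argmax(attrs[:, 0, -1])  # most important component at the final position
    print(labels[top_idx], attrs[top_idx, 0, -1].item())

The returned labels name individual attention heads and MLP components, so the same argmax trick
from the class docstring points at a single head rather than a whole sublayer.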