Improve getting started docs for Activation Cache (#433)
alan-cooney authored Oct 21, 2023
1 parent 8c966e6 commit 19ab304
Showing 4 changed files with 119 additions and 12 deletions.
1 change: 1 addition & 0 deletions .vscode/cspell.json
@@ -93,6 +93,7 @@
"transformerlens",
"troitskii",
"unembed",
"unembedded",
"unembedding",
"unigram",
"virtualenvs",
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -17,4 +17,7 @@
"pylint.importStrategy": "fromEnvironment",
"notebook.formatOnCellExecution": true,
"notebook.formatOnSave.enabled": true,
"cSpell.words": [
"accum"
],
}
6 changes: 5 additions & 1 deletion docs/source/conf.py
@@ -39,7 +39,11 @@

napoleon_include_init_with_doc = True
napoleon_use_admonition_for_notes = True
napoleon_custom_sections = ["Motivation:", "Warning:"]
napoleon_custom_sections = [
"Motivation:",
"Warning:",
"Getting Started:",
]

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
121 changes: 110 additions & 11 deletions transformer_lens/ActivationCache.py
@@ -1,8 +1,14 @@
"""Activation Cache.
The core functionality of TransformerLens is to cache and edit the activations of a model. The
:class:`ActivationCache` is designed to help do the caching part of this - storing all activations
in a single place.
The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
important activations from a forward pass of the model, and provides a variety of helper functions
to investigate them.
Getting Started:
When reading these docs for the first time, we recommend reading the main :class:`ActivationCache`
class first, including the examples, and then skimming the available methods. You can then refer
back to these docs depending on what you need to do.
"""
from __future__ import annotations

@@ -24,13 +30,56 @@
class ActivationCache:
"""Activation Cache.
A wrapper around a dictionary of cached activations from a model run, with a variety of helper
functions.
A wrapper that stores all important activations from a forward pass of the model, and provides a
variety of helper functions to investigate them.
The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
important activations from a forward pass of the model, and provides a variety of helper
functions to investigate them. The common way to access it is to run the model with
:meth:`transformer_lens.HookedTransformer.run_with_cache`.
Examples:
When investigating a particular behaviour of a model, a very common first step is to try and
understand which components of the model are most responsible for that behaviour. For example,
if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to
understand if there is a specific sublayer (MLP or multi-head attention) that is responsible for
the model predicting "road". This kind of analysis commonly falls under the category of "logit
attribution" or "direct logit attribution" (DLA).
>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
>>> print(labels[0:3])
['embed', 'pos_embed', '0_attn_out']
>>> answer = " road" # Note the preceding space to match the model's tokenization
>>> logit_attrs = cache.logit_attrs(residual_stream, answer)
>>> print(logit_attrs.shape) # Attention layers
torch.Size([10, 1, 7])
This is designed to be used with :class:`transformer_lens.HookedTransformer`, and will not
work with other models. It's also designed to be used with all activations of
:class:`transformer_lens.HookedTransformer` being cached, and some internal methods will break
without that.
>>> import torch
>>> most_important_component_idx = torch.argmax(logit_attrs)
>>> print(labels[most_important_component_idx])
3_attn_out
You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the
residual stream by individual component (MLP neurons and individual attention heads). This
creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.
Equally you might want to find out if the model struggles to construct such excellent jokes
until the very last layers, or if it is trivial and the first few layers are enough. This kind
of analysis is called "logit lens", and you can find out more about how to do that with
:meth:`ActivationCache.accumulated_resid`.
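
The direct logit attribution idea in the example above can be sketched without loading a model. This is a toy illustration only: the component vectors and the unembedding direction below are made-up numbers, and it ignores the final layer norm that a real model applies before unembedding. It shows that each component's attribution is its dot product with the answer token's unembedding direction, and that these attributions sum to the logit computed from the full residual stream.

```python
def dot(u, v):
    """Dot product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


# Hypothetical per-component contributions to the residual stream (d_model = 3).
components = {
    "embed": [0.5, 0.0, 0.25],
    "0_attn_out": [1.0, -0.5, 0.0],
    "0_mlp_out": [-0.25, 0.5, 0.5],
}

# Hypothetical unembedding direction for the answer token (one column of W_U).
answer_direction = [2.0, 1.0, 0.0]

# Attribution of each component: its projection onto the answer direction.
attributions = {name: dot(vec, answer_direction) for name, vec in components.items()}

# The full residual stream is the sum of all component contributions.
residual = [sum(values) for values in zip(*components.values())]
total_logit = dot(residual, answer_direction)

# The component with the largest attribution is the "most responsible" one.
top_component = max(attributions, key=attributions.get)

print(attributions)
print(total_logit, top_component)
```

Because the projection is linear, the per-component values add up to the total. That additivity is what makes ranking components by attribution meaningful in this simplified setting.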
Warning:
:class:`ActivationCache` is designed to be used with
:class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also
designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being
cached, and some internal methods will break without that.
The biggest footgun and source of bugs in this code will be keeping track of indexes,
dimensions, and the numbers of each. There are several kinds of activations:
@@ -300,14 +349,64 @@ def accumulated_resid(
To project this into the vocabulary space, remember that there is a final layer norm in most
decoder-only transformers. Therefore, you need to first apply the final layer norm (which
can be done with :meth:`apply_ln_to_stack`), and then multiply by the unembedding matrix
(:math:`W_U`).
can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`).
If you want to look at contributions to the residual stream from each component instead
(e.g. for direct logit attribution), see :meth:`decompose_resid`, or
:meth:`get_full_resid_decomposition` if you want contributions broken down further into each
MLP neuron.
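
The relationship between the accumulated and the decomposed view of the residual stream can be sketched with plain lists. This is a toy illustration with made-up numbers, not the library's implementation: it takes a running sum after every component, which is finer-grained than the per-layer snapshots described here. The labels and vectors are hypothetical.

```python
from itertools import accumulate

# Hypothetical component outputs (d_model = 2), in the order they are
# added to the residual stream.
labels = ["embed", "0_attn_out", "0_mlp_out"]
decomposed = [[1.0, 0.0], [0.5, 0.5], [-0.5, 1.0]]


def add(u, v):
    """Element-wise sum of two equal-length vectors."""
    return [a + b for a, b in zip(u, v)]


# Running sum: entry i is the residual stream after components 0..i.
accumulated = list(accumulate(decomposed, add))

for label, state in zip(labels, accumulated):
    print(f"after {label}: {state}")
```

The last accumulated entry equals the sum of all decomposed components. The decomposed view is therefore suited to attributing a prediction to individual parts, while the accumulated view shows what the model "knows" up to a given point.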
Examples:
Logit Lens analysis can be done as follows:
>>> from transformer_lens import HookedTransformer
>>> from einops import einsum
>>> import torch
>>> import pandas as pd
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> prompt = "Why did the chicken cross the"
>>> answer = " road"
>>> logits, cache = model.run_with_cache(prompt)
>>> answer_token = model.to_single_token(answer)
>>> print(answer_token)
2975
>>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
>>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model
>>> print(last_token_accum.shape) # layer, d_model
torch.Size([9, 64])
>>> W_U = model.W_U
>>> print(W_U.shape)
torch.Size([64, 50257])
>>> layers_unembedded = einsum(
... last_token_accum,
... W_U,
... "layer d_model, d_model d_vocab -> layer d_vocab"
... )
>>> print(layers_unembedded.shape)
torch.Size([9, 50257])
>>> # Get the rank of the correct answer by layer
>>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True)
>>> rank_answer = (sorted_indices == answer_token).nonzero(as_tuple=True)[1]
>>> print(pd.Series(rank_answer, index=labels))
0_pre 4442
1_pre 382
2_pre 982
3_pre 1160
4_pre 408
5_pre 145
6_pre 78
7_pre 387
final_post 6
dtype: int64
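
The ranking step at the end of the example above can be checked by hand on a tiny input. The sketch below uses plain Python lists in place of tensors; the helper name `rank_of_token` and the scores are hypothetical. For each layer it sorts the vocabulary indices by descending score and reports the position of the answer token, where rank 0 means the top prediction.

```python
def rank_of_token(scores_by_layer, token_id):
    """Return, per layer, the 0-based rank of token_id (0 = highest score)."""
    ranks = []
    for scores in scores_by_layer:
        # Vocabulary indices ordered from the highest score to the lowest.
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        ranks.append(order.index(token_id))
    return ranks


# Hypothetical unembedded scores: 2 layers, vocabulary of 4 tokens.
toy_scores = [
    [0.1, 2.0, 0.5, 1.5],  # layer 0: token 2 has the third-highest score
    [0.1, 0.2, 0.9, 0.3],  # layer 1: token 2 has the highest score
]

ranks = rank_of_token(toy_scores, token_id=2)
print(ranks)  # [2, 0]
```

A rank that falls towards 0 in later layers, as in the table above, suggests the prediction is refined progressively rather than fixed early on.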
Args:
layer:
The layer to take components up to - by default includes resid_pre for that layer
