Commit
jbloom-md committed on Dec 31, 2023 (1 parent: 048d267, commit: 4c7f6f2)
Showing 44 changed files with 3,206,120 additions and 1,394 deletions.
@@ -0,0 +1,190 @@
import os
import warnings
from functools import partial
from typing import List, Optional

import pandas as pd
import torch
from jaxtyping import Float
from torch import Tensor
from transformer_lens import HookedTransformer, utils
from transformer_lens.components import HookPoint

from research.joseph.utils import make_token_df
from sae_training.sparse_autoencoder import SparseAutoencoder

os.environ["TOKENIZERS_PARALLELISM"] = "false"

warnings.filterwarnings("ignore")


# HOOKS
# LAYER_IDX, HEAD_IDX = (10, 7)
LAYER_IDX, HEAD_IDX = (1, 11)
# W_U = model.W_U.clone()
# HEAD_HOOK_RESULT_NAME = utils.get_act_name("result", LAYER_IDX)
HEAD_HOOK_RESULT_NAME = utils.get_act_name("z", LAYER_IDX)
HEAD_HOOK_QUERY_NAME = utils.get_act_name("q", LAYER_IDX)
HEAD_HOOK_RESID_NAME = utils.get_act_name("resid_pre", LAYER_IDX)
# BATCH_SIZE = 10


# Get ATTN results
def get_max_attn_token(tokens, cache, model: HookedTransformer, layer_idx, head_idx):
    """For each destination position, return the source position the head attends to most,
    the corresponding token string, and the attention weight (BOS position dropped)."""
    tokens = tokens.to("cpu")
    pattern_name = utils.get_act_name("pattern", layer_idx)
    pattern = cache[pattern_name][0, head_idx].detach().cpu()
    max_idx_pos = pattern.argmax(dim=-1)
    max_idx_token_id = torch.gather(tokens, dim=-1, index=max_idx_pos.unsqueeze(-1).T)
    max_idx_tok = model.to_string(max_idx_token_id.T)
    max_idx_tok_value = pattern.max(dim=1).values
    return max_idx_pos[1:], max_idx_tok[1:], max_idx_tok_value[1:]


def kl_divergence_attention(y_true, y_pred):
    """Elementwise terms of KL(y_true || y_pred); sum over the last dim to get the divergence."""
    # Compute log probabilities for KL divergence
    log_y_true = torch.log(y_true + 1e-10)
    log_y_pred = torch.log(y_pred + 1e-10)
    return y_true * (log_y_true - log_y_pred)


def eval_prompt(
    prompt: List,
    model: HookedTransformer,
    sparse_autoencoder: Optional[SparseAutoencoder] = None,
    head_idx_override: Optional[int] = None,
):
    '''
    Takes a list of strings as input. Returns a per-token dataframe of losses,
    head-ablation effects and (optionally) SAE reconstruction metrics, plus the
    original cache, the cache with the reconstructed query patched in, and the
    SAE feature activations.
    '''
    tokens = model.to_tokens(prompt)
    # tokens = tokens[:, :MAX_PROMPT_LEN]
    token_df = make_token_df(model, tokens[:, 1:], len_suffix=5, len_prefix=10)

    # tokens = t.stack(tokens).to(device)
    layer_idx = sparse_autoencoder.cfg.hook_point_layer
    head_idx = sparse_autoencoder.cfg.hook_point_head_index if head_idx_override is None else head_idx_override
    head_hook_query_name = utils.get_act_name("q", layer_idx)
    head_hook_result_name = utils.get_act_name("z", layer_idx)
    head_hook_resid_name = utils.get_act_name("resid_pre", layer_idx)

    # Basic Forward Pass
    (original_logits, original_loss), original_cache = model.run_with_cache(
        tokens, return_type="both", loss_per_token=True
    )
    token_df["loss"] = original_loss.flatten().tolist()

    ## Collect ATTN Results
    max_idx_pos, max_idx_tok, max_idx_tok_value = get_max_attn_token(
        tokens, original_cache, model, layer_idx, head_idx
    )
    token_df["max_idx_pos"] = max_idx_pos.flatten().tolist()
    token_df["max_idx_tok"] = max_idx_tok
    token_df["max_idx_tok_value"] = max_idx_tok_value.flatten().tolist()

    # Full Head Ablation
    def hook_to_ablate_head(
        head_output: Float[Tensor, "batch seq_len head_idx d_head"],
        hook: HookPoint,
        head=(LAYER_IDX, HEAD_IDX),
    ):
        assert head[0] == hook.layer(), f"{head[0]} != {hook.layer()}"
        assert ("result" in hook.name) or ("q" in hook.name) or ("z" in hook.name)
        head_output[:, :, head[1], :] = 0
        return head_output

    hook_to_ablate_head = partial(hook_to_ablate_head, head=(layer_idx, head_idx))
    ablated_logits, ablated_loss = model.run_with_hooks(
        tokens,
        return_type="both",
        loss_per_token=True,
        fwd_hooks=[(head_hook_result_name, hook_to_ablate_head)],
    )

    logit_diff = original_logits - ablated_logits
    top10_token_suppression_vals, top10_token_suppression_inds = torch.topk(logit_diff, 10, dim=-1, largest=False)
    token_df["top10_token_suppression_diffs"] = top10_token_suppression_vals.flatten(0, 1).tolist()[1:]
    decoded_tokens = [
        [model.tokenizer.decode(tok_id, skip_special_tokens=True) for tok_id in sequence]
        for sequence in top10_token_suppression_inds[0]
    ]
    token_df["top10_token_suppression_inds"] = decoded_tokens[1:]

    top10_token_boosting_vals, top10_token_boosting_inds = torch.topk(logit_diff, 10, dim=-1, largest=True)
    token_df["top10_token_boosting_vals"] = top10_token_boosting_vals.flatten(0, 1).tolist()[1:]
    decoded_tokens = [
        [model.tokenizer.decode(tok_id, skip_special_tokens=True) for tok_id in sequence]
        for sequence in top10_token_boosting_inds[0]
    ]
    token_df["top10_token_boosting_inds"] = decoded_tokens[1:]

    token_df["ablated_loss"] = ablated_loss.flatten().tolist()
    token_df["loss_diff"] = token_df["ablated_loss"] - token_df["loss"]

    if sparse_autoencoder is not None:
        # Reconstruction of Query with SAE
        if "resid_pre" in sparse_autoencoder.cfg.hook_point:
            original_act = original_cache[sparse_autoencoder.cfg.hook_point]
            # token_df["q_norm"] = torch.norm(original_act, dim=-1)[:, 1:].flatten().tolist()
            sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(original_act)
            # token_df["rec_q_norm"] = torch.norm(sae_out, dim=-1)[:, 1:].flatten().tolist()

            # need to generate the query from the reconstructed residual stream
            def replacement_hook(resid_pre, hook, new_resid_pre=sae_out):
                return new_resid_pre

            with model.hooks(fwd_hooks=[(head_hook_resid_name, replacement_hook)]):
                _, resid_pre_cache = model.run_with_cache(tokens, return_type="loss", loss_per_token=True)
                sae_out = resid_pre_cache[head_hook_query_name][:, :, head_idx]

            original_act = original_cache[head_hook_query_name][:, :, head_idx]
            per_tok_mse_loss = (sae_out.float() - original_act.float()).pow(2).sum(-1)
            total_variance = original_act.pow(2).sum(-1)
            explained_variance = per_tok_mse_loss / total_variance

        else:
            original_act = original_cache[sparse_autoencoder.cfg.hook_point][:, :, head_idx]
            token_df["q_norm"] = torch.norm(original_act, dim=-1)[:, 1:].flatten().tolist()
            sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(
                original_cache[sparse_autoencoder.cfg.hook_point][:, :, head_idx]
            )
            token_df["rec_q_norm"] = torch.norm(sae_out, dim=-1)[:, 1:].flatten().tolist()
            # norm_ratio = torch.norm(original_act, dim=-1) / torch.norm(sae_out, dim=-1)

            per_tok_mse_loss = (sae_out.float() - original_act.float()).pow(2).sum(-1)
            total_variance = original_act.pow(2).sum(-1)
            explained_variance = per_tok_mse_loss / total_variance

        num_active_features = (feature_acts > 0).sum(dim=-1)
        top_feature_acts, top_features = torch.topk(feature_acts, k=10, dim=-1)

        # SAE Metrics
        token_df["mse_loss"] = per_tok_mse_loss.flatten()[1:].tolist()
        token_df["explained_variance"] = explained_variance.flatten()[1:].tolist()
        token_df["num_active_features"] = num_active_features.flatten()[1:].tolist()
        token_df["top_k_feature_acts"] = top_feature_acts.flatten(0, 1).tolist()[1:]
        token_df["top_k_features"] = top_features.flatten(0, 1).tolist()[1:]

        # Reconstruct Query
        def hook_to_reconstruct_query(
            head_input: Float[Tensor, "batch seq_len head_idx d_head"],
            hook: HookPoint,
            head,
            reconstructed_query: Float[Tensor, "batch seq_len d_model"] = None,
        ):
            assert head[0] == hook.layer()
            head_input[:, :, head[1], :] = reconstructed_query[:, :]
            return head_input

        hook_fn = partial(hook_to_reconstruct_query, reconstructed_query=sae_out, head=(layer_idx, head_idx))
        with model.hooks(fwd_hooks=[(head_hook_query_name, hook_fn)]):
            _, cache_reconstructed_query = model.run_with_cache(tokens, return_type="loss", loss_per_token=True)
            max_idx_pos, max_idx_tok, max_idx_tok_value = get_max_attn_token(
                tokens, cache_reconstructed_query, model, layer_idx, head_idx
            )

        # Get the KL Divergence of the attention distributions
        patterns_original = original_cache[utils.get_act_name("pattern", layer_idx)][0, head_idx].detach().cpu()
        patterns_reconstructed = cache_reconstructed_query[utils.get_act_name("pattern", layer_idx)][0, head_idx].detach().cpu()
        kl_result = kl_divergence_attention(patterns_original, patterns_reconstructed)
        kl_result = kl_result.sum(dim=-1)[1:].numpy()

        # add results to dataframe
        token_df["rec_q_max_idx_pos"] = max_idx_pos.flatten().tolist()
        token_df["rec_q_max_idx_tok"] = max_idx_tok
        token_df["rec_q_max_idx_tok_value"] = max_idx_tok_value.flatten().tolist()
        token_df["kl_divergence"] = kl_result.flatten().tolist()

        # print(feature_acts.shape)
        # token_df["ids_active_features"] = (feature_acts[0,1:] > 0)
    else:
        cache_reconstructed_query = None
        feature_acts = None  # no SAE supplied, so there are no feature activations to return

    return (
        token_df,
        original_cache,
        cache_reconstructed_query,
        feature_acts.flatten(0, 1)[1:] if feature_acts is not None else None,
    )
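For context, a minimal usage sketch for `eval_prompt` (not part of this commit): the model choice is an example, and the `SparseAutoencoder.load_from_pretrained` call is an assumption about the repo's loading API, which is not shown in this diff.

# Usage sketch (assumptions: GPT-2 small, and a load_from_pretrained helper for the SAE checkpoint).
from transformer_lens import HookedTransformer
from sae_training.sparse_autoencoder import SparseAutoencoder

model = HookedTransformer.from_pretrained("gpt2")
sparse_autoencoder = SparseAutoencoder.load_from_pretrained("path/to/sae_checkpoint.pt")  # assumed loading helper

prompt = ["When John and Mary went to the store, John gave a drink to"]
token_df, original_cache, rec_cache, feature_acts = eval_prompt(
    prompt, model, sparse_autoencoder=sparse_autoencoder
)

# Per-token effect of ablating the head, plus SAE reconstruction quality
print(token_df[["unique_token", "loss", "ablated_loss", "loss_diff"]].head())
print(token_df[["mse_loss", "explained_variance", "num_active_features"]].head())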
@@ -0,0 +1,41 @@
from typing import List

import numpy as np
import torch
from datasets import load_dataset

LENGTH_RANDOM_TOKS = 4
TOKEN_OF_INTEREST = " John"
N_REPEAT_TOKENS = 3


def generate_random_token_prompt(model, n_random_tokens=10, n_repeat_tokens=3, token_of_interest: str = " John"):
    """Build a repeated sequence of random tokens, optionally prepending a token of interest to each repeat."""
    random_tokens = torch.randint(0, model.tokenizer.vocab_size, (n_random_tokens,)).to(model.cfg.device)
    # prepend the token id for the token of interest (e.g. " John")
    if token_of_interest is not None:
        john_token = torch.tensor(model.to_single_token(token_of_interest)).unsqueeze(0).to(model.cfg.device)
        random_tokens = torch.cat([john_token, random_tokens], dim=0)

    tokens_per_group = random_tokens.shape[0]

    # repeat the tokens
    random_tokens = random_tokens.repeat(n_repeat_tokens)

    # generate an index marking which repeat each token belongs to
    random_token_groups = torch.arange(0, n_repeat_tokens).unsqueeze(-1).repeat(1, tokens_per_group).flatten()

    return random_tokens, random_token_groups


def get_webtext(seed: int = 420, dataset="stas/openwebtext-10k") -> List[str]:
    """Get 10,000 sentences from the OpenWebText dataset."""
    # Let's see some WEBTEXT
    raw_dataset = load_dataset(dataset)
    train_dataset = raw_dataset["train"]
    dataset = [train_dataset[i]["text"] for i in range(len(train_dataset))]

    # Shuffle the dataset (I don't want the Hitler thing being first so use a seeded shuffle)
    np.random.seed(seed)
    np.random.shuffle(dataset)

    return dataset
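A short sketch of how these helpers might be combined (not part of this commit; the model name is an assumption, everything else uses the functions defined above).

# Usage sketch: the model name is an assumption.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

# Repeated random-token prompt with " John" injected at the start of each repeat
random_tokens, token_groups = generate_random_token_prompt(
    model, n_random_tokens=LENGTH_RANDOM_TOKS, n_repeat_tokens=N_REPEAT_TOKENS
)
print(model.to_string(random_tokens))

# Natural-language prompts from OpenWebText
webtext = get_webtext()
print(webtext[0][:200])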
@@ -0,0 +1,53 @@
import io
import pickle

import pandas as pd
import torch
from transformer_lens import HookedTransformer


def list_flatten(nested_list):
    return [x for y in nested_list for x in y]


def make_token_df(model: HookedTransformer, tokens, len_prefix=5, len_suffix=1):
    """Build a dataframe with one row per token, including a prefix|token|suffix context string
    and batch/position labels."""
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    batch = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p - len_prefix):p])
            if p == tokens.shape[1] - 1:
                suffix = ""
            else:
                suffix = "".join(str_tokens[b][p + 1:min(tokens.shape[1] - 1, p + 1 + len_suffix)])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            batch.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        batch=batch,
        pos=pos,
        label=label,
    ))


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that maps any torch storages in the pickle onto the CPU."""

    def find_class(self, module, name):
        if module == "torch.storage" and name == "_load_from_bytes":
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        else:
            return super().find_class(module, name)
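A brief sketch of the intended usage (the model name and pickle path are placeholders, not part of this commit).

# Usage sketch: model name and file path are placeholders.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens(["The quick brown fox jumps over the lazy dog"])

# One row per token, with a prefix|token|suffix context string plus batch/pos labels
token_df = make_token_df(model, tokens, len_prefix=5, len_suffix=1)
print(token_df[["unique_token", "context", "batch", "pos"]].head())

# Load a pickle containing CUDA tensors on a CPU-only machine
# with open("results.pkl", "rb") as f:
#     results = CPU_Unpickler(f).load()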