reset feature sparsity calculation
jbloom-md committed Dec 31, 2023
1 parent 048d267 commit 4c7f6f2
Showing 44 changed files with 3,206,120 additions and 1,394 deletions.
Binary file added research/.DS_Store
809 changes: 0 additions & 809 deletions research/copy_suppression_sae_investigation.ipynb

31,095 changes: 31,095 additions & 0 deletions research/gpt2-small-features/data_12477.html

28,367 changes: 28,367 additions & 0 deletions research/gpt2-small-features/data_17904.html

30,413 changes: 30,413 additions & 0 deletions research/gpt2-small-features/data_24529.html

31,095 changes: 31,095 additions & 0 deletions research/gpt2-small-features/data_25633.html

29,390 changes: 29,390 additions & 0 deletions research/gpt2-small-features/data_29561.html

31,095 changes: 31,095 additions & 0 deletions research/gpt2-small-features/data_31228.html

28,026 changes: 28,026 additions & 0 deletions research/gpt2-small-features/data_5801.html

30,072 changes: 30,072 additions & 0 deletions research/gpt2-small-features/data_7304.html

30,754 changes: 30,754 additions & 0 deletions research/gpt2-small-features/data_8556.html

30,754 changes: 30,754 additions & 0 deletions research/gpt2-small-features/data_9035.html

Empty file added research/joseph/__init__.py
190 changes: 190 additions & 0 deletions research/joseph/analysis.py
@@ -0,0 +1,190 @@
import os
import warnings
from functools import partial
from typing import List, Optional, Tuple

import pandas as pd
import torch
from jaxtyping import Float
from torch import Tensor
from transformer_lens import HookedTransformer, utils
from transformer_lens.components import HookPoint

from research.joseph.utils import make_token_df
from sae_training.sparse_autoencoder import SparseAutoencoder

os.environ["TOKENIZERS_PARALLELISM"] = "false"


warnings.filterwarnings("ignore")



# HOOKS
# LAYER_IDX, HEAD_IDX = (10, 7)
LAYER_IDX, HEAD_IDX = (1, 11)
# W_U = model.W_U.clone()
# HEAD_HOOK_RESULT_NAME = utils.get_act_name("result", LAYER_IDX)
HEAD_HOOK_RESULT_NAME = utils.get_act_name("z", LAYER_IDX)
HEAD_HOOK_QUERY_NAME = utils.get_act_name("q", LAYER_IDX)
HEAD_HOOK_RESID_NAME = utils.get_act_name("resid_pre", LAYER_IDX)
# BATCH_SIZE = 10




# Get ATTN results

def get_max_attn_token(tokens, cache, model: HookedTransformer, LAYER_IDX, HEAD_IDX):
    """For each position, return the position, string, and attention weight of the most-attended-to source token."""
    tokens = tokens.to("cpu")
    pattern_name = utils.get_act_name("pattern", LAYER_IDX)
    pattern = cache[pattern_name][0, HEAD_IDX].detach().cpu()
    max_idx_pos = pattern.argmax(dim=-1)
    max_idx_token_id = torch.gather(tokens, dim=-1, index=max_idx_pos.unsqueeze(-1).T)
    max_idx_tok = model.to_string(max_idx_token_id.T)
    max_idx_tok_value = pattern.max(dim=-1).values
    return max_idx_pos[1:], max_idx_tok[1:], max_idx_tok_value[1:]


def kl_divergence_attention(y_true, y_pred):
    """Element-wise KL divergence terms between two attention patterns; callers sum over the key dimension."""
    # Compute log probabilities for the KL divergence; the small epsilon avoids log(0).
    log_y_true = torch.log(y_true + 1e-10)
    log_y_pred = torch.log(y_pred + 1e-10)

    return y_true * (log_y_true - log_y_pred)


def eval_prompt(
    prompt: List,
    model: HookedTransformer,
    sparse_autoencoder: Optional[SparseAutoencoder] = None,
    head_idx_override: Optional[int] = None,
):
    '''
    Takes a list of prompt strings. Returns a per-token dataframe with original vs head-ablated
    loss, attention statistics, and (if an SAE is provided) reconstruction and feature-sparsity metrics.
    '''
    tokens = model.to_tokens(prompt)
    token_df = make_token_df(model, tokens[:, 1:], len_suffix=5, len_prefix=10)

    # Fall back to the module-level defaults if no SAE is provided.
    if sparse_autoencoder is not None:
        layer_idx = sparse_autoencoder.cfg.hook_point_layer
        head_idx = sparse_autoencoder.cfg.hook_point_head_index if head_idx_override is None else head_idx_override
    else:
        layer_idx = LAYER_IDX
        head_idx = HEAD_IDX if head_idx_override is None else head_idx_override
    head_hook_query_name = utils.get_act_name("q", layer_idx)
    head_hook_result_name = utils.get_act_name("z", layer_idx)
    head_hook_resid_name = utils.get_act_name("resid_pre", layer_idx)

    # Basic forward pass
    (original_logits, original_loss), original_cache = model.run_with_cache(tokens, return_type="both", loss_per_token=True)
    token_df['loss'] = original_loss.flatten().tolist()

    # Collect ATTN results
    max_idx_pos, max_idx_tok, max_idx_tok_value = get_max_attn_token(tokens, original_cache, model, layer_idx, head_idx)
    token_df['max_idx_pos'] = max_idx_pos.flatten().tolist()
    token_df['max_idx_tok'] = max_idx_tok
    token_df['max_idx_tok_value'] = max_idx_tok_value.flatten().tolist()

    # Full head ablation: zero out this head's output and measure the effect on logits and loss.
    def hook_to_ablate_head(head_output: Float[Tensor, "batch seq_len head_idx d_head"], hook: HookPoint, head=(LAYER_IDX, HEAD_IDX)):
        assert head[0] == hook.layer(), f"{head[0]} != {hook.layer()}"
        assert ("result" in hook.name) or ("q" in hook.name) or ("z" in hook.name)
        head_output[:, :, head[1], :] = 0
        return head_output

    hook_to_ablate_head = partial(hook_to_ablate_head, head=(layer_idx, head_idx))
    ablated_logits, ablated_loss = model.run_with_hooks(tokens, return_type="both", loss_per_token=True, fwd_hooks=[(head_hook_result_name, hook_to_ablate_head)])

    # Tokens the head suppresses / boosts (logit difference: original - ablated)
    logit_diff = original_logits - ablated_logits
    top10_token_suppression_vals, top10_token_suppression_inds = torch.topk(logit_diff, 10, dim=-1, largest=False)
    token_df['top10_token_suppression_diffs'] = top10_token_suppression_vals.flatten(0, 1).tolist()[1:]
    decoded_tokens = [[model.tokenizer.decode(tok_id, skip_special_tokens=True) for tok_id in sequence] for sequence in top10_token_suppression_inds[0]]
    token_df['top10_token_suppression_inds'] = decoded_tokens[1:]

    top10_token_boosting_vals, top10_token_boosting_inds = torch.topk(logit_diff, 10, dim=-1, largest=True)
    token_df['top10_token_boosting_vals'] = top10_token_boosting_vals.flatten(0, 1).tolist()[1:]
    decoded_tokens = [[model.tokenizer.decode(tok_id, skip_special_tokens=True) for tok_id in sequence] for sequence in top10_token_boosting_inds[0]]
    token_df['top10_token_boosting_inds'] = decoded_tokens[1:]

    token_df['ablated_loss'] = ablated_loss.flatten().tolist()
    token_df["loss_diff"] = token_df["ablated_loss"] - token_df["loss"]

    if sparse_autoencoder is not None:
        # Reconstruction of the query with the SAE
        if "resid_pre" in sparse_autoencoder.cfg.hook_point:
            original_act = original_cache[sparse_autoencoder.cfg.hook_point]
            sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(original_act)

            # The SAE reconstructs resid_pre, so patch the reconstruction in and read off the query it produces.
            def replacement_hook(resid_pre, hook, new_resid_pre=sae_out):
                return new_resid_pre

            with model.hooks(fwd_hooks=[(head_hook_resid_name, replacement_hook)]):
                _, resid_pre_cache = model.run_with_cache(tokens, return_type="loss", loss_per_token=True)
                sae_out = resid_pre_cache[head_hook_query_name][:, :, head_idx]

            original_act = original_cache[head_hook_query_name][:, :, head_idx]
            per_tok_mse_loss = (sae_out.float() - original_act.float()).pow(2).sum(-1)
            total_variance = original_act.pow(2).sum(-1)
            explained_variance = per_tok_mse_loss / total_variance

        else:
            original_act = original_cache[sparse_autoencoder.cfg.hook_point][:, :, head_idx]
            token_df["q_norm"] = torch.norm(original_act, dim=-1)[:, 1:].flatten().tolist()
            sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(original_act)
            token_df["rec_q_norm"] = torch.norm(sae_out, dim=-1)[:, 1:].flatten().tolist()

            per_tok_mse_loss = (sae_out.float() - original_act.float()).pow(2).sum(-1)
            total_variance = original_act.pow(2).sum(-1)
            explained_variance = per_tok_mse_loss / total_variance

        # Feature sparsity: number of SAE features with a non-zero activation on each token (per-token L0).
        num_active_features = (feature_acts > 0).sum(dim=-1)
        top_feature_acts, top_features = torch.topk(feature_acts, k=10, dim=-1)

        # SAE metrics. Note that "explained_variance" as computed here is MSE / total variance,
        # i.e. the fraction of variance *not* explained by the reconstruction.
        token_df['mse_loss'] = per_tok_mse_loss.flatten()[1:].tolist()
        token_df['explained_variance'] = explained_variance.flatten()[1:].tolist()
        token_df['num_active_features'] = num_active_features.flatten()[1:].tolist()
        token_df['top_k_feature_acts'] = top_feature_acts.flatten(0, 1).tolist()[1:]
        token_df['top_k_features'] = top_features.flatten(0, 1).tolist()[1:]

        # Reconstruct the query from the SAE output and re-run the model.
        def hook_to_reconstruct_query(
            head_input: Float[Tensor, "batch seq_len head_idx d_head"],
            hook: HookPoint,
            head,
            reconstructed_query: Float[Tensor, "batch seq_len d_head"] = None,
        ):
            assert head[0] == hook.layer()
            head_input[:, :, head[1], :] = reconstructed_query[:, :]
            return head_input

        hook_fn = partial(hook_to_reconstruct_query, reconstructed_query=sae_out, head=(layer_idx, head_idx))
        with model.hooks(fwd_hooks=[(head_hook_query_name, hook_fn)]):
            _, cache_reconstructed_query = model.run_with_cache(tokens, return_type="loss", loss_per_token=True)
        max_idx_pos, max_idx_tok, max_idx_tok_value = get_max_attn_token(tokens, cache_reconstructed_query, model, layer_idx, head_idx)

        # KL divergence between the original and reconstructed-query attention distributions
        patterns_original = original_cache[utils.get_act_name("pattern", layer_idx)][0, head_idx].detach().cpu()
        patterns_reconstructed = cache_reconstructed_query[utils.get_act_name("pattern", layer_idx)][0, head_idx].detach().cpu()
        kl_result = kl_divergence_attention(patterns_original, patterns_reconstructed)
        kl_result = kl_result.sum(dim=-1)[1:].numpy()

        token_df['rec_q_max_idx_pos'] = max_idx_pos.flatten().tolist()
        token_df['rec_q_max_idx_tok'] = max_idx_tok
        token_df['rec_q_max_idx_tok_value'] = max_idx_tok_value.flatten().tolist()
        token_df['kl_divergence'] = kl_result.flatten().tolist()

    else:
        cache_reconstructed_query = None
        feature_acts = None

    flat_feature_acts = feature_acts.flatten(0, 1)[1:] if feature_acts is not None else None
    return token_df, original_cache, cache_reconstructed_query, flat_feature_acts
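For orientation, a minimal sketch of how `eval_prompt` might be driven from a notebook. The wrapper below, the GPT-2 model name, and the example prompt are illustrative assumptions rather than part of this commit; the SAE object must come from the user's own `sae_training` checkpoint.

# Sketch only: assumes a trained SparseAutoencoder has already been loaded elsewhere.
from typing import List

import pandas as pd
from transformer_lens import HookedTransformer

from research.joseph.analysis import eval_prompt
from sae_training.sparse_autoencoder import SparseAutoencoder


def summarise_prompt(prompt: List[str], model: HookedTransformer, sae: SparseAutoencoder) -> pd.DataFrame:
    """Run eval_prompt and keep the per-token loss, feature-sparsity, and reconstruction columns."""
    token_df, _, _, _ = eval_prompt(prompt, model, sae)
    return token_df[["unique_token", "loss", "ablated_loss", "num_active_features", "explained_variance", "kl_divergence"]]


# Example usage (model name and prompt are placeholders):
# model = HookedTransformer.from_pretrained("gpt2")
# print(summarise_prompt(["When John and Mary went to the store, John gave a drink to"], model, sae))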
41 changes: 41 additions & 0 deletions research/joseph/data.py
@@ -0,0 +1,41 @@
from typing import List

import numpy as np
import torch
from datasets import load_dataset

LENGTH_RANDOM_TOKS = 4
TOKEN_OF_INTEREST = " John"
N_REPEAT_TOKENS = 3


def generate_random_token_prompt(model, n_random_tokens=10, n_repeat_tokens=3, token_of_interest: str = " John"):
    """Build a prompt of random tokens, optionally prefixed with a token of interest, repeated n_repeat_tokens times."""
    random_tokens = torch.randint(0, model.tokenizer.vocab_size, (n_random_tokens,)).to(model.cfg.device)

    # prepend the token id for the token of interest (e.g. " John")
    if token_of_interest is not None:
        john_token = torch.tensor(model.to_single_token(token_of_interest)).unsqueeze(0).to(model.cfg.device)
        random_tokens = torch.cat([john_token, random_tokens], dim=0)

    # length of one block before repetition (n_random_tokens, plus one if a token of interest was prepended)
    block_length = random_tokens.shape[0]

    # repeat the tokens
    random_tokens = random_tokens.repeat(n_repeat_tokens)

    # generate an index for each group (repeat) of tokens
    random_token_groups = torch.arange(0, n_repeat_tokens).unsqueeze(-1).repeat(1, block_length).flatten()

    return random_tokens, random_token_groups


def get_webtext(seed: int = 420, dataset="stas/openwebtext-10k") -> List[str]:
    """Get 10,000 documents from the OpenWebText dataset, shuffled with a fixed seed."""
    raw_dataset = load_dataset(dataset)
    train_dataset = raw_dataset["train"]
    dataset = [train_dataset[i]["text"] for i in range(len(train_dataset))]

    # Shuffle the dataset (seeded, so the ordering is reproducible and the first document isn't always the same)
    np.random.seed(seed)
    np.random.shuffle(dataset)

    return dataset
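A short sketch of how these helpers might be combined; using "gpt2" as the model is an assumption, not something this commit specifies.

# Sketch only: pair the synthetic prompt generator with real web text.
from transformer_lens import HookedTransformer

from research.joseph.data import generate_random_token_prompt, get_webtext

model = HookedTransformer.from_pretrained("gpt2")

# Synthetic prompt: the same block of random tokens, prefixed with " John", repeated three times.
random_tokens, token_groups = generate_random_token_prompt(model, n_random_tokens=4, n_repeat_tokens=3)
print(model.to_string(random_tokens), token_groups.tolist())

# Natural text: a seeded shuffle of ~10k OpenWebText documents; take a handful as evaluation prompts.
webtext = get_webtext(seed=420)
prompts = webtext[:5]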

53 changes: 53 additions & 0 deletions research/joseph/utils.py
@@ -0,0 +1,53 @@
import io
import pickle

import pandas as pd
import torch
from transformer_lens import HookedTransformer


def list_flatten(nested_list):
    return [x for y in nested_list for x in y]


def make_token_df(model: HookedTransformer, tokens, len_prefix=5, len_suffix=1):
    """Build a per-token dataframe with a readable context window around each token."""
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    batch = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p - len_prefix):p])
            if p == tokens.shape[1] - 1:
                suffix = ""
            else:
                suffix = "".join(str_tokens[b][p + 1:min(tokens.shape[1] - 1, p + 1 + len_suffix)])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            batch.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        batch=batch,
        pos=pos,
        label=label,
    ))


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that maps torch storages to CPU, so objects pickled on a GPU machine load on CPU-only machines."""

    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else:
            return super().find_class(module, name)
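For reference, a minimal sketch of both utilities in use; the prompt text, the "gpt2" model name, and the pickle path are placeholders rather than files in this repository.

# Sketch only: prompt text and pickle path are placeholders.
from transformer_lens import HookedTransformer

from research.joseph.utils import CPU_Unpickler, make_token_df

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens(["The quick brown fox jumps over the lazy dog"])
token_df = make_token_df(model, tokens, len_prefix=5, len_suffix=1)
print(token_df[["unique_token", "context"]].head())

# CPU_Unpickler remaps torch storages to CPU while unpickling, so objects pickled on a GPU box load anywhere.
# with open("cached_analysis.pkl", "rb") as f:
#     result = CPU_Unpickler(f).load()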