Skip to content

Commit

Permalink
Merge pull request #54 from jbloomAus/hook_z_suppourt
Browse files Browse the repository at this point in the history
notional support, needs more thorough testing
  • Loading branch information
jbloomAus authored Mar 27, 2024
2 parents 8ac8f05 + 9585022 commit 277f35b
Show file tree
Hide file tree
Showing 5 changed files with 366 additions and 247 deletions.
7 changes: 7 additions & 0 deletions sae_training/activations_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def get_batch_tokens(self):
def get_activations(self, batch_tokens: torch.Tensor):
"""
Returns activations of shape (batches, context, num_layers, d_in)
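When the hooked activation has a separate head dimension and no single head is selected, the head dimension is flattened into d_in (all dims after (batches, context) are concatenated).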
"""
layers = (
self.cfg.hook_point_layer
Expand All @@ -174,6 +176,11 @@ def get_activations(self, batch_tokens: torch.Tensor):
activations_list = [
act[:, :, self.cfg.hook_point_head_index] for act in activations_list
]
elif activations_list[0].ndim > 3: # if we have a head dimension
# flatten the head dimension
activations_list = [
act.view(act.shape[0], act.shape[1], -1) for act in activations_list
]

# Stack along a new dimension to keep separate layers distinct
stacked_activations = torch.stack(activations_list, dim=2)
Expand Down
194 changes: 110 additions & 84 deletions sae_training/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ def run_evals(
)

# get act
if sparse_autoencoder.cfg.hook_point_head_index is not None:
original_act = cache[sparse_autoencoder.cfg.hook_point][
:, :, sparse_autoencoder.cfg.hook_point_head_index
]
if hook_point_head_index is not None:
original_act = cache[hook_point][:, :, hook_point_head_index]
elif "attn" in hook_point:
original_act = cache[hook_point].flatten(-2, -1)
else:
original_act = cache[sparse_autoencoder.cfg.hook_point]

sae_out, _feature_acts, _, _, _, _ = sparse_autoencoder(original_act)
patterns_original = (
cache[get_act_name("pattern", hook_point_layer)][:, hook_point_head_index]
.detach()
.cpu()
)
original_act = cache[hook_point]

sae_out, _, _, _, _, _ = sparse_autoencoder(original_act)
# patterns_original = (
# cache[get_act_name("pattern", hook_point_layer)][:, hook_point_head_index]
# .detach()
# .cpu()
# )
del cache

if "cuda" in str(model.cfg.device):
Expand All @@ -84,71 +84,83 @@ def run_evals(
step=n_training_steps,
)

head_index = sparse_autoencoder.cfg.hook_point_head_index

def standard_replacement_hook(activations: torch.Tensor, hook: Any):
activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
return activations

def head_replacement_hook(activations: torch.Tensor, hook: Any):
new_actions = sparse_autoencoder.forward(activations[:, :, head_index])[0].to(
activations.dtype
)
activations[:, :, head_index] = new_actions
return activations

head_index = sparse_autoencoder.cfg.hook_point_head_index
replacement_hook = (
standard_replacement_hook if head_index is None else head_replacement_hook
)
# head_index = sparse_autoencoder.cfg.hook_point_head_index

# def standard_replacement_hook(activations: torch.Tensor, hook: Any):
# activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
# return activations

# def all_head_replacement_hook(activations: torch.Tensor, hook: Any):
# new_activations = sparse_autoencoder.forward(activations)[0].to(
# activations.dtype
# )
# activations = new_activations.reshape(
# activations.shape
# ) # reshape to match original shape
# return activations

# def single_head_replacement_hook(activations: torch.Tensor, hook: Any):
# new_activations = sparse_autoencoder.forward(activations[:, :, head_index])[
# 0
# ].to(activations.dtype)
# activations[:, :, head_index] = new_activations
# return activations

# if "attn" in hook_point:
# if hook_point_head_index is None:
# replacement_hook = all_head_replacement_hook
# else:
# replacement_hook = single_head_replacement_hook
# else:
# replacement_hook = standard_replacement_hook

# get attn when using reconstructed activations
with model.hooks(fwd_hooks=[(hook_point, partial(replacement_hook))]):
_, new_cache = model.run_with_cache(
eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)]
)
patterns_reconstructed = (
new_cache[get_act_name("pattern", hook_point_layer)][
:, hook_point_head_index
]
.detach()
.cpu()
)
del new_cache

# get attn when using zero-ablated activations
with model.hooks(fwd_hooks=[(hook_point, partial(zero_ablate_hook))]):
_, zero_ablation_cache = model.run_with_cache(
eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)]
)
patterns_ablation = (
zero_ablation_cache[get_act_name("pattern", hook_point_layer)][
:, hook_point_head_index
]
.detach()
.cpu()
)
del zero_ablation_cache

if sparse_autoencoder.cfg.hook_point_head_index:
kl_result_reconstructed = kl_divergence_attention(
patterns_original, patterns_reconstructed
)
kl_result_reconstructed = kl_result_reconstructed.sum(dim=-1).numpy()

kl_result_ablation = kl_divergence_attention(
patterns_original, patterns_ablation
)
kl_result_ablation = kl_result_ablation.sum(dim=-1).numpy()

if wandb.run is not None:
wandb.log(
{
f"metrics/kldiv_reconstructed{suffix}": kl_result_reconstructed.mean().item(),
f"metrics/kldiv_ablation{suffix}": kl_result_ablation.mean().item(),
},
step=n_training_steps,
)
# with model.hooks(fwd_hooks=[(hook_point, partial(replacement_hook))]):
# _, new_cache = model.run_with_cache(
# eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)]
# )
# patterns_reconstructed = (
# new_cache[get_act_name("pattern", hook_point_layer)][
# :, hook_point_head_index
# ]
# .detach()
# .cpu()
# )
# del new_cache

# # get attn when using zero-ablated activations
# with model.hooks(fwd_hooks=[(hook_point, partial(zero_ablate_hook))]):
# _, zero_ablation_cache = model.run_with_cache(
# eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)]
# )
# patterns_ablation = (
# zero_ablation_cache[get_act_name("pattern", hook_point_layer)][
# :, hook_point_head_index
# ]
# .detach()
# .cpu()
# )
# del zero_ablation_cache

# if sparse_autoencoder.cfg.hook_point_head_index:
# kl_result_reconstructed = kl_divergence_attention(
# patterns_original, patterns_reconstructed
# )
# kl_result_reconstructed = kl_result_reconstructed.sum(dim=-1).numpy()

# kl_result_ablation = kl_divergence_attention(
# patterns_original, patterns_ablation
# )
# kl_result_ablation = kl_result_ablation.sum(dim=-1).numpy()

# if wandb.run is not None:
# wandb.log(
# {
# f"metrics/kldiv_reconstructed{suffix}": kl_result_reconstructed.mean().item(),
# f"metrics/kldiv_ablation{suffix}": kl_result_ablation.mean().item(),
# },
# step=n_training_steps,
# )


def recons_loss_batched(
Expand Down Expand Up @@ -187,22 +199,36 @@ def get_recons_loss(
):
hook_point = sparse_autoencoder.cfg.hook_point
loss = model(batch_tokens, return_type="loss")
head_index = sparse_autoencoder.cfg.hook_point_head_index
hook_point_head_index = sparse_autoencoder.cfg.hook_point_head_index

def standard_replacement_hook(activations: torch.Tensor, hook: Any):
activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
return activations

def head_replacement_hook(activations: torch.Tensor, hook: Any):
new_activations = sparse_autoencoder.forward(activations[:, :, head_index])[
0
].to(activations.dtype)
activations[:, :, head_index] = new_activations
def all_head_replacement_hook(activations: torch.Tensor, hook: Any):
new_activations = sparse_autoencoder.forward(activations.flatten(-2, -1))[0].to(
activations.dtype
)
new_activations = new_activations.reshape(
activations.shape
) # reshape to match original shape
return new_activations

def single_head_replacement_hook(activations: torch.Tensor, hook: Any):
new_activations = sparse_autoencoder.forward(
activations[:, :, hook_point_head_index]
)[0].to(activations.dtype)
activations[:, :, hook_point_head_index] = new_activations
return activations

replacement_hook = (
standard_replacement_hook if head_index is None else head_replacement_hook
)
if "attn" in hook_point:
if hook_point_head_index is None:
replacement_hook = all_head_replacement_hook
else:
replacement_hook = single_head_replacement_hook
else:
replacement_hook = standard_replacement_hook

recons_loss = model.run_with_hooks(
batch_tokens,
return_type="loss",
Expand Down
Loading

0 comments on commit 277f35b

Please sign in to comment.