Merge branch 'main' into type-checking
chanind committed Feb 29, 2024 · 2 parents 351995c + 3e78bce · commit 57c4582
Showing 20 changed files with 938 additions and 6,655 deletions.
.flake8 (2 changes: 1 addition & 1 deletion)
@@ -5,4 +5,4 @@ max-complexity = 25
extend-select = E9, F63, F7, F82
show-source = true
statistics = true
exclude = ./sae_training/geom_median/
exclude = ./sae_training/geom_median/, ./wandb/*, ./research/wandb/*
.pre-commit-config.yaml (6 changes: 3 additions & 3 deletions)
@@ -7,19 +7,19 @@ repos:
- id: check-added-large-files
args: [--maxkb=250000]
- repo: /~https://github.com/psf/black
rev: 23.3.0
rev: 24.2.0
hooks:
- id: black
- repo: /~https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
args: ['--config=.flake8']
additional_dependencies: [
'flake8-blind-except',
'flake8-docstrings',
# 'flake8-docstrings',
'flake8-bugbear',
'flake8-comprehensions',
'flake8-docstrings',
'flake8-implicit-str-concat',
'pydocstyle>=5.0.0',
]
pyproject.toml (3 changes: 2 additions & 1 deletion)
@@ -18,6 +18,7 @@ ipykernel = "^6.29.2"
matplotlib = "^3.8.3"
matplotlib-inline = "^0.1.6"
eindex = {git = "/~https://github.com/callummcdougall/eindex.git"}
datasets = "^2.17.1"


[tool.poetry.group.dev.dependencies]
@@ -48,4 +49,4 @@ reportUnknownLambdaType = "none"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
build-backend = "poetry.core.masonry.api"
sae_analysis/dashboard_runner.py (7 changes: 6 additions & 1 deletion)
@@ -149,9 +149,10 @@ def get_dashboard_folder_name(self):
def init_sae_session(self):
(
self.model,
self.sparse_autoencoder,
sae_group,
self.activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
self.sparse_autoencoder = sae_group.autoencoders[0]

def get_tokens(
self, n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6
@@ -179,9 +180,11 @@ def get_index_to_resume_from(self):
if not os.path.exists(f"{self.dashboard_folder}/data_{i:04}.html"):
break

assert self.sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
n_features = self.sparse_autoencoder.cfg.d_sae
n_features_at_a_time = self.n_features_at_a_time
id_of_last_feature_without_dashboard = i
assert self.final_index is not None # keep pyright happy
n_features_remaining = self.final_index - id_of_last_feature_without_dashboard
n_batches_to_do = n_features_remaining // n_features_at_a_time
if self.final_index == n_features:
@@ -214,6 +217,7 @@ def get_feature_property_df(self):
)
d_e_projection = cosine_similarity(W_dec_normalized, W_enc_normalized.T)

assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
temp_df = pd.DataFrame(
{
"log_feature_sparsity": feature_sparsity + 1e-10,
@@ -299,6 +303,7 @@ def run(self):
)
wandb.log({"plots/scatter_matrix": wandb.Html(plotly.io.to_html(fig))})

assert self.sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
self.n_features = self.sparse_autoencoder.cfg.d_sae
id_to_start_from = self.get_index_to_resume_from()
id_to_end_at = self.n_features if self.final_index is None else self.final_index
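Note on this file: the session loader now returns an SAE group rather than a single autoencoder, and `Optional` config fields are narrowed with bare asserts before use. A minimal sketch of the consuming pattern (the import path and checkpoint path are assumptions, not taken from the diff):

```python
from sae_training.utils import LMSparseAutoencoderSessionloader  # path assumed

# The loader now yields (model, sae_group, activation_store); the dashboard
# runner picks the first autoencoder out of the group.
model, sae_group, activation_store = (
    LMSparseAutoencoderSessionloader.load_session_from_pretrained("path/to/sae.pt")
)
sparse_autoencoder = sae_group.autoencoders[0]

# cfg.d_sae is Optional[int]; the bare assert narrows the type for pyright
# before arithmetic, matching the "keep pyright happy" asserts in the diff.
assert sparse_autoencoder.cfg.d_sae is not None
n_features = sparse_autoencoder.cfg.d_sae
```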
sae_analysis/visualizer/data_fns.py (1 change: 1 addition & 0 deletions)
@@ -983,6 +983,7 @@ def hook_fn_resid_post(
# ! If verbose, try to estimate time it will take to generate data for all features, plus storage space

if verbose:
assert encoder.cfg.d_sae is not None # keep pyright happy
n_feats_total = encoder.cfg.d_sae

# Get time
sae_training/activations_store.py (111 changes: 66 additions & 45 deletions)
@@ -24,15 +24,15 @@ def __init__(
self.dataset = load_dataset(cfg.dataset_path, split="train", streaming=True)
self.iterable_dataset = iter(self.dataset)

# check if it's tokenized
if "tokens" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = True
print("Dataset is tokenized! Updating config.")
elif "text" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = False
print("Dataset is not tokenized! Updating config.")
# Check if dataset is tokenized
dataset_sample = next(self.iterable_dataset)
self.cfg.is_dataset_tokenized = "tokens" in dataset_sample.keys()
print(
f"Dataset is {'tokenized' if self.cfg.is_dataset_tokenized else 'not tokenized'}! Updating config."
)
self.iterable_dataset = iter(self.dataset) # Reset iterator after checking

if self.cfg.use_cached_activations:
if self.cfg.use_cached_activations: # EDIT: load from multi-layer acts
assert self.cfg.cached_activations_path is not None # keep pyright happy
# Sanity check: does the cache directory exist?
assert os.path.exists(
@@ -146,39 +146,65 @@ def get_batch_tokens(self):
return batch_tokens[:batch_size]

def get_activations(self, batch_tokens: torch.Tensor, get_loss: bool = False):
act_name = self.cfg.hook_point
hook_point_layer = self.cfg.hook_point_layer
"""
Returns activations of shape (batches, context, num_layers, d_in)
"""
layers = (
self.cfg.hook_point_layer
if isinstance(self.cfg.hook_point_layer, list)
else [self.cfg.hook_point_layer]
)
act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers]
hook_point_max_layer = max(layers)
if self.cfg.hook_point_head_index is not None:
activations = self.model.run_with_cache(
batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
)[1][act_name][:, :, self.cfg.hook_point_head_index]
layerwise_activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
)[1]
activations_list = [
layerwise_activations[act_name][:, :, self.cfg.hook_point_head_index]
for act_name in act_names
]
else:
activations = self.model.run_with_cache(
batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
)[1][act_name]
layerwise_activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
)[1]
activations_list = [
layerwise_activations[act_name] for act_name in act_names
]

# Stack along a new dimension to keep separate layers distinct
stacked_activations = torch.stack(activations_list, dim=2)

return activations
return stacked_activations

def get_buffer(self, n_batches_in_buffer: int):
context_size = self.cfg.context_size
batch_size = self.cfg.store_batch_size
d_in = self.cfg.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = (
len(self.cfg.hook_point_layer)
if isinstance(self.cfg.hook_point_layer, list)
else 1
) # Number of hook points or layers

if self.cfg.use_cached_activations:
# Load the activations from disk
buffer_size = total_size * context_size
# Initialize an empty tensor (flattened along all dims except d_in)
# Initialize an empty tensor with an additional dimension for layers
new_buffer = torch.zeros(
(buffer_size, d_in), dtype=self.cfg.dtype, device=self.cfg.device
(buffer_size, num_layers, d_in),
dtype=self.cfg.dtype,
device=self.cfg.device,
)
n_tokens_filled = 0

# The activations may be split across multiple files,
# Or we might only want a subset of one file (depending on the sizes)
# Assume activations for different layers are stored separately and need to be combined
while n_tokens_filled < buffer_size:
# Load the next file
# Make sure it exists
if not os.path.exists(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
):
@@ -194,55 +220,49 @@ def get_buffer(self, n_batches_in_buffer: int):
)
print(f"Returning a buffer of size {n_tokens_filled} instead.")
print("\n\n")
new_buffer = new_buffer[:n_tokens_filled]
break
new_buffer = new_buffer[:n_tokens_filled, ...]
return new_buffer

activations = torch.load(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
)

# If we only want a subset of the file, take it
taking_subset_of_file = False
if n_tokens_filled + activations.shape[0] > buffer_size:
activations = activations[: buffer_size - n_tokens_filled]
activations = activations[: buffer_size - n_tokens_filled, ...]
taking_subset_of_file = True

# Add it to the buffer
new_buffer[n_tokens_filled : n_tokens_filled + activations.shape[0]] = (
activations
)
new_buffer[
n_tokens_filled : n_tokens_filled + activations.shape[0], ...
] = activations

# Update counters
n_tokens_filled += activations.shape[0]
if taking_subset_of_file:
self.next_idx_within_buffer = activations.shape[0]
else:
self.next_cache_idx += 1
self.next_idx_within_buffer = 0

n_tokens_filled += activations.shape[0]

return new_buffer

refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
# refill_iterator = tqdm(refill_iterator, desc="generate activations")

# Initialize empty tensor buffer of the maximum required size
# Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
new_buffer = torch.zeros(
(total_size, context_size, d_in),
(total_size, context_size, num_layers, d_in),
dtype=self.cfg.dtype,
device=self.cfg.device,
)

# Insert activations directly into pre-allocated buffer
# pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")
for refill_batch_idx_start in refill_iterator:
refill_batch_tokens = self.get_batch_tokens()
refill_activations = self.get_activations(refill_batch_tokens)
new_buffer[refill_batch_idx_start : refill_batch_idx_start + batch_size] = (
refill_activations
)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
] = refill_activations

# pbar.update(1)

new_buffer = new_buffer.reshape(-1, d_in)
new_buffer = new_buffer.reshape(-1, num_layers, d_in)
new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]

return new_buffer
Expand All @@ -262,7 +282,8 @@ def get_data_loader(

# 1. # create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
[self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer]
[self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer],
dim=0,
)

mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
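Note on this file: the core of the change is that activations now carry an explicit layer dimension from capture through to the shuffled buffer. A self-contained sketch of the shape bookkeeping with toy tensors (no model required; the names mirror the diff):

```python
import torch

batch, context, d_in = 4, 128, 512
layers = [0, 3, 5]  # cfg.hook_point_layer may now be a list of layers

# run_with_cache yields one (batch, context, d_in) tensor per hook name
activations_list = [torch.randn(batch, context, d_in) for _ in layers]

# Stack along a new dim=2 so layers stay distinct: (batch, context, num_layers, d_in)
stacked = torch.stack(activations_list, dim=2)
assert stacked.shape == (batch, context, len(layers), d_in)

# get_buffer flattens batch and context but keeps the layer dim, then shuffles rows
buffer = stacked.reshape(-1, len(layers), d_in)
buffer = buffer[torch.randperm(buffer.shape[0])]
assert buffer.shape == (batch * context, len(layers), d_in)
```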
sae_training/cache_activations_runner.py (3 changes: 2 additions & 1 deletion)
@@ -29,7 +29,8 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig):
cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer
)
n_buffers = math.ceil(cfg.total_training_tokens / tokens_per_buffer)
for i in tqdm(range(n_buffers), desc="Caching activations"):
# for i in tqdm(range(n_buffers), desc="Caching activations"):
for i in range(n_buffers):
buffer = activations_store.get_buffer(cfg.n_batches_in_buffer)
torch.save(buffer, f"{activations_store.cfg.cached_activations_path}/{i}.pt")
del buffer
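Note on this file: the number of buffers is simple arithmetic over the config. A worked example with illustrative values (not taken from the diff):

```python
import math

store_batch_size = 32        # prompts per forward pass
context_size = 128           # tokens per prompt
n_batches_in_buffer = 16
total_training_tokens = 10_000_000

tokens_per_buffer = store_batch_size * context_size * n_batches_in_buffer  # 65_536
n_buffers = math.ceil(total_training_tokens / tokens_per_buffer)           # 153

# Buffer i is saved as {cached_activations_path}/{i}.pt, the same naming
# scheme ActivationsStore reads back via next_cache_idx.
```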
sae_training/config.py (7 changes: 5 additions & 2 deletions)
@@ -14,7 +14,7 @@ class RunnerConfig(ABC):

# Data Generating Function (Model + Training Distribution)
model_name: str = "gelu-2l"
hook_point: str = "blocks.0.hook_mlp_out"
hook_point: str = "blocks.{layer}.hook_mlp_out"
hook_point_layer: int = 0
hook_point_head_index: Optional[int] = None
dataset_path: str = "NeelNanda/c4-tokenized-2b"
@@ -56,9 +56,11 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
b_dec_init_method: str = "geometric_median"
expansion_factor: int = 4
from_pretrained_path: Optional[str] = None
d_sae: Optional[int] = None

# Training Parameters
l1_coefficient: float = 1e-3
lp_norm: float = 1
lr: float = 3e-4
lr_scheduler_name: str = (
"constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
@@ -86,7 +88,8 @@ class LanguageModelSAERunnerConfig(RunnerConfig):

def __post_init__(self):
super().__post_init__()
self.d_sae = self.d_in * self.expansion_factor
if not isinstance(self.expansion_factor, list):
self.d_sae = self.d_in * self.expansion_factor
self.tokens_per_buffer = (
self.train_batch_size * self.context_size * self.n_batches_in_buffer
)
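Note on this file: `hook_point` is now a format template rather than a literal hook name, and `__post_init__` only derives `d_sae` when `expansion_factor` is a scalar. A short sketch of both behaviours (values are illustrative):

```python
hook_point = "blocks.{layer}.hook_mlp_out"
hook_point_layer = [0, 3]  # a bare int is still accepted

layers = hook_point_layer if isinstance(hook_point_layer, list) else [hook_point_layer]
act_names = [hook_point.format(layer=layer) for layer in layers]
print(act_names)  # ['blocks.0.hook_mlp_out', 'blocks.3.hook_mlp_out']

# With a list-valued expansion_factor (one SAE per width), d_sae stays None,
# which is why the "keep pyright happy" asserts appear at every use site.
d_in, expansion_factor = 512, 4
if not isinstance(expansion_factor, list):
    d_sae = d_in * expansion_factor  # 2048
```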
sae_training/evals.py (17 changes: 9 additions & 8 deletions)
@@ -18,6 +18,7 @@ def run_evals(
activation_store: ActivationsStore,
model: HookedTransformer,
n_training_steps: int,
suffix: str = "",
):
hook_point = sparse_autoencoder.cfg.hook_point
hook_point_layer = sparse_autoencoder.cfg.hook_point_layer
@@ -72,13 +73,13 @@
wandb.log(
{
# l2 norms
"metrics/l2_norm": l2_norm_out.mean().item(),
"metrics/l2_ratio": l2_norm_ratio.mean().item(),
f"metrics/l2_norm{suffix}": l2_norm_out.mean().item(),
f"metrics/l2_ratio{suffix}": l2_norm_ratio.mean().item(),
# CE Loss
"metrics/CE_loss_score": recons_score,
"metrics/ce_loss_without_sae": ntp_loss,
"metrics/ce_loss_with_sae": recons_loss,
"metrics/ce_loss_with_ablation": zero_abl_loss,
f"metrics/CE_loss_score{suffix}": recons_score,
f"metrics/ce_loss_without_sae{suffix}": ntp_loss,
f"metrics/ce_loss_with_sae{suffix}": recons_loss,
f"metrics/ce_loss_with_ablation{suffix}": zero_abl_loss,
},
step=n_training_steps,
)
@@ -143,8 +144,8 @@ def head_replacement_hook(activations: torch.Tensor, hook: Any):
if wandb.run is not None:
wandb.log(
{
"metrics/kldiv_reconstructed": kl_result_reconstructed.mean().item(),
"metrics/kldiv_ablation": kl_result_ablation.mean().item(),
f"metrics/kldiv_reconstructed{suffix}": kl_result_reconstructed.mean().item(),
f"metrics/kldiv_ablation{suffix}": kl_result_ablation.mean().item(),
},
step=n_training_steps,
)
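Note on this file: the new `suffix` argument namespaces every logged metric key, so a single wandb run can evaluate several SAEs (for example one per layer) without key collisions. A small sketch of the effect; the suffix values are assumptions, not from the diff:

```python
def metric_keys(suffix: str = "") -> dict[str, str]:
    """Build the wandb keys the way run_evals now does, with f-string suffixing."""
    return {
        "l2_ratio": f"metrics/l2_ratio{suffix}",
        "ce_score": f"metrics/CE_loss_score{suffix}",
    }

print(metric_keys())           # {'l2_ratio': 'metrics/l2_ratio', ...}
print(metric_keys("_layer0"))  # {'l2_ratio': 'metrics/l2_ratio_layer0', ...}
```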
(diff truncated: 11 of the 20 changed files are not shown)
