Skip to content

Commit

Permalink
minor speed improvement
Browse files (browse the repository at this point in the history)
  • Loading branch information
jbloomAus committed Dec 12, 2023
1 parent c0eac0a commit f7ea316
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
11 changes: 6 additions & 5 deletions sae_training/activations_store.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_batch_tokens(self):
context_size = self.cfg.context_size
device = self.cfg.device

batch_tokens = torch.LongTensor(size=(0, context_size)).to(device)
batch_tokens = torch.zeros(size=(0, context_size), device=device, dtype=torch.long, requires_grad=False)

current_batch = []
current_length = 0
Expand All @@ -59,10 +59,13 @@ def get_batch_tokens(self):
next(self.iterable_dataset)["tokens"],
dtype=torch.long,
device=device,
requires_grad=False,
)
token_len = tokens.shape[0]

# TODO: Fix this so that we are limiting how many tokens we get from the same context.

bos_token_id_tensor = torch.tensor([self.model.tokenizer.bos_token_id], device=tokens.device, dtype=torch.long)
while token_len > 0 and batch_tokens.shape[0] < batch_size:
# Space left in the current batch
space_left = context_size - current_length
Expand All @@ -81,9 +84,7 @@ def get_batch_tokens(self):
tokens = tokens[space_left:]
tokens = torch.cat(
(
torch.LongTensor([self.model.tokenizer.bos_token_id]).to(
tokens.device
),
bos_token_id_tensor,
tokens,
),
dim=0,
Expand Down Expand Up @@ -150,7 +151,7 @@ def get_buffer(self, n_batches_in_buffer):
# pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")
for refill_batch_idx_start in refill_iterator:
refill_batch_tokens = self.get_batch_tokens()
refill_activations = self.get_activations(refill_batch_tokens).to(self.cfg.device)
refill_activations = self.get_activations(refill_batch_tokens)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size
] = refill_activations
Expand Down
5 changes: 4 additions & 1 deletion sae_training/config.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -65,6 +65,8 @@ def __post_init__(self):
if self.feature_sampling_method not in [None, "l2"]:
raise ValueError(f"feature_sampling_method must be None, l2, or anthropic. Got {self.feature_sampling_method}")

self.device = torch.device(self.device)

unique_id = wandb.util.generate_id()
self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

Expand All @@ -79,4 +81,5 @@ def __post_init__(self):

# how many times will we sample dead neurons?
n_dead_feature_samples = total_training_steps // self.dead_feature_window - 1
print(f"n_dead_feature_samples: {n_dead_feature_samples}")
print(f"n_dead_feature_samples: {n_dead_feature_samples}")

0 comments on commit f7ea316

Please sign in to comment.