diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py
index 6964ccfe..662f0c5e 100644
--- a/sae_training/activations_store.py
+++ b/sae_training/activations_store.py
@@ -39,7 +39,7 @@ def get_batch_tokens(self):
         context_size = self.cfg.context_size
         device = self.cfg.device
 
-        batch_tokens = torch.LongTensor(size=(0, context_size)).to(device)
+        batch_tokens = torch.zeros(size=(0, context_size), device=device, dtype=torch.long, requires_grad=False)
         current_batch = []
         current_length = 0
 
@@ -59,10 +59,13 @@ def get_batch_tokens(self):
                     next(self.iterable_dataset)["tokens"],
                     dtype=torch.long,
                     device=device,
+                    requires_grad=False,
                 )
             token_len = tokens.shape[0]
 
             # TODO: Fix this so that we are limiting how many tokens we get from the same context.
+
+            bos_token_id_tensor = torch.tensor([self.model.tokenizer.bos_token_id], device=tokens.device, dtype=torch.long)
             while token_len > 0 and batch_tokens.shape[0] < batch_size:
                 # Space left in the current batch
                 space_left = context_size - current_length
@@ -81,9 +84,7 @@ def get_batch_tokens(self):
                     tokens = tokens[space_left:]
                     tokens = torch.cat(
                         (
-                            torch.LongTensor([self.model.tokenizer.bos_token_id]).to(
-                                tokens.device
-                            ),
+                            bos_token_id_tensor,
                             tokens,
                         ),
                         dim=0,
@@ -150,7 +151,7 @@ def get_buffer(self, n_batches_in_buffer):
         # pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")
         for refill_batch_idx_start in refill_iterator:
             refill_batch_tokens = self.get_batch_tokens()
-            refill_activations = self.get_activations(refill_batch_tokens).to(self.cfg.device)
+            refill_activations = self.get_activations(refill_batch_tokens)
             new_buffer[
                 refill_batch_idx_start : refill_batch_idx_start + batch_size
             ] = refill_activations
diff --git a/sae_training/config.py b/sae_training/config.py
index 8d95d0dd..e9919c5d 100644
--- a/sae_training/config.py
+++ b/sae_training/config.py
@@ -65,6 +65,8 @@ def __post_init__(self):
         if self.feature_sampling_method not in [None, "l2"]:
             raise ValueError(f"feature_sampling_method must be None, l2, or anthropic. Got {self.feature_sampling_method}")
 
+        self.device = torch.device(self.device)
+
         unique_id = wandb.util.generate_id()
         self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"
 
@@ -79,4 +81,5 @@ def __post_init__(self):
         # how many times will we sample dead neurons?
         n_dead_feature_samples = total_training_steps // self.dead_feature_window - 1
 
-        print(f"n_dead_feature_samples: {n_dead_feature_samples}")
\ No newline at end of file
+        print(f"n_dead_feature_samples: {n_dead_feature_samples}")
+
\ No newline at end of file
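
Note on the pattern above: the diff replaces the legacy `torch.LongTensor(size=...).to(device)` idiom with a single `torch.zeros(..., dtype=torch.long, device=device)` call (one allocation instead of an allocate-then-copy), and hoists the BOS tensor out of the packing loop so it is built once per dataset entry rather than on every context split. Below is a minimal, self-contained sketch of the resulting packing logic; `build_batch_tokens`, its signature, and the plain-Python `token_streams` input are hypothetical stand-ins for the store's `cfg` and `iterable_dataset`, not the repo's actual API.

```python
import torch

def build_batch_tokens(token_streams, batch_size, context_size, bos_token_id, device="cpu"):
    """Pack variable-length token streams into (batch_size, context_size) rows."""
    # Pre-allocate an empty long tensor in one call; this mirrors the diff's
    # replacement for torch.LongTensor(size=...).to(device).
    batch_tokens = torch.zeros(
        size=(0, context_size), device=device, dtype=torch.long, requires_grad=False
    )

    current_batch, current_length = [], 0
    for stream in token_streams:
        tokens = torch.as_tensor(stream, dtype=torch.long, device=device)
        # Build the BOS tensor once per stream, outside the packing loop,
        # mirroring the hoisted bos_token_id_tensor in the diff.
        bos = torch.tensor([bos_token_id], device=tokens.device, dtype=torch.long)
        while tokens.shape[0] > 0 and batch_tokens.shape[0] < batch_size:
            space_left = context_size - current_length
            if tokens.shape[0] <= space_left:
                # The whole remainder fits in the current row.
                current_batch.append(tokens)
                current_length += tokens.shape[0]
                tokens = tokens[:0]  # exhausted; exit the inner loop
            else:
                # Fill the row, then carry the rest over with a fresh BOS prefix.
                current_batch.append(tokens[:space_left])
                tokens = torch.cat((bos, tokens[space_left:]), dim=0)
                current_length = context_size
            if current_length == context_size:
                row = torch.cat(current_batch, dim=0).unsqueeze(0)
                batch_tokens = torch.cat((batch_tokens, row), dim=0)
                current_batch, current_length = [], 0
    return batch_tokens
```

For example, `build_batch_tokens([[1, 5, 6, 7, 8, 9], [1, 2, 3]], batch_size=2, context_size=4, bos_token_id=1)` yields a (2, 4) tensor, with the second row beginning with the carried-over BOS token. Token ids never need gradients, so the explicit `requires_grad=False` is documentation as much as behavior.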