Fix checkpointing of training state that includes a compiled SAE (#143)
* Adds state_dict to L1Scheduler

* investigating test failure

* fix: Fix issues with resumption testing (#144)

* fix always-true comparison in train context testing

* set default warmup steps to zero

* remove unused type attribute from L1Scheduler

* update training tests to use real context builder

* add docstring for build_train_ctx

* 2.1.2

Automatically generated by python-semantic-release

* Adds state_dict to L1Scheduler

* investigating test failure

---------

Co-authored-by: github-actions <github-actions@github.com>
tomMcGrath and github-actions authored May 15, 2024
1 parent 448d911 commit 2f8c4e1
Showing 1 changed file with 16 additions and 0 deletions.
sae_lens/training/optim.py (16 additions, 0 deletions)
@@ -2,6 +2,8 @@
Took the LR scheduler from my previous work: /~https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
"""

from typing import Any

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

@@ -140,3 +142,17 @@ def step(self):
self.sparse_autoencoder.l1_coefficient = self.final_l1_value # type: ignore

self.current_step += 1

def state_dict(self):
"""State dict for serializing as part of an SAETrainContext."""
return {
"l1_warmup_steps": self.l1_warmup_steps,
"total_steps": self.total_steps,
"final_l1_value": self.final_l1_value,
"current_step": self.current_step,
}

def load_state_dict(self, state_dict: dict[str, Any]):
"""Loads all state apart from attached SAE."""
for k in state_dict:
setattr(self, k, state_dict[k])
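
The new state_dict()/load_state_dict() pair makes the scheduler serializable like any other piece of training state, while deliberately leaving out the attached SAE, so the same checkpoint works whether or not the model was wrapped by torch.compile. A minimal usage sketch, assuming sae is an already-constructed SAE; the constructor arguments and checkpoint layout below are illustrative, not the exact SAETrainContext serialization:

import torch

from sae_lens.training.optim import L1Scheduler

# Illustrative construction: the argument names mirror the serialized fields,
# but the real signature may differ.
scheduler = L1Scheduler(
    l1_warmup_steps=1_000,
    total_steps=10_000,
    final_l1_value=5.0,
    sparse_autoencoder=sae,  # may be a torch.compile()'d SAE
)

# Save: only plain Python values go into the checkpoint, never the SAE itself.
torch.save({"l1_scheduler": scheduler.state_dict()}, "train_ctx.pt")

# Resume: rebuild the scheduler (re-attaching the SAE), then restore its state.
state = torch.load("train_ctx.pt")
scheduler.load_state_dict(state["l1_scheduler"])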
