From fd5448c5a016a8134a881732f6c61cd7a37eab67 Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Wed, 3 Apr 2024 00:23:08 +0100
Subject: [PATCH] adding lr_decay_steps and refactoring get_scheduler

---
 sae_training/config.py                       |   3 +-
 sae_training/optim.py                        | 124 ++++++++++--------
 sae_training/train_sae_on_language_model.py  |   1 +
 tests/unit/test_optim.py                     |  97 ++++++++------
 .../unit/test_train_sae_on_language_model.py |   4 +-
 5 files changed, 129 insertions(+), 100 deletions(-)

diff --git a/sae_training/config.py b/sae_training/config.py
index 9d3eb737..a6da855d 100644
--- a/sae_training/config.py
+++ b/sae_training/config.py
@@ -64,9 +64,10 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
     lp_norm: float = 1
     lr: float = 3e-4
     lr_scheduler_name: str = (
-        "constantwithwarmup"  # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
+        "constant"  # constant, cosineannealing, cosineannealingwarmrestarts
    )
     lr_warm_up_steps: int = 500
+    lr_decay_steps: int = 0
     train_batch_size: int = 4096
 
     # Resampling protocol args
diff --git a/sae_training/optim.py b/sae_training/optim.py
index 11b69a4f..4f3e9514 100644
--- a/sae_training/optim.py
+++ b/sae_training/optim.py
@@ -2,9 +2,6 @@
 Took the LR scheduler from my previous work: /~https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """
 
-import math
-from typing import Optional
-
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 
@@ -14,75 +11,88 @@
 # Cosine Annealing with Warmup
 # Cosine Annealing with Warmup / Restarts
 def get_scheduler(
-    scheduler_name: Optional[str],
+    scheduler_name: str,
     optimizer: optim.Optimizer,
+    training_steps: int,
+    lr: float,
     warm_up_steps: int = 0,
-    training_steps: int | None = None,
+    decay_steps: int = 0,
     num_cycles: int = 1,
     lr_end: float = 0.0,
-):
+) -> lr_scheduler.LRScheduler:
     """
     Loosely based on this, seemed simpler write this than import transformers:
     https://huggingface.co/docs/transformers/main_classes/optimizer_schedules
 
     Args:
-        scheduler_name (Optional[str]): Name of the scheduler to use. If None, returns a constant scheduler
+        scheduler_name (str): Name of the scheduler to use, one of "constant", "cosineannealing", "cosineannealingwarmrestarts"
         optimizer (optim.Optimizer): Optimizer to use
-        **kwargs: Additional arguments to pass to the scheduler including warm_up_steps,
-            training_steps, num_cycles, lr_end.
+        training_steps (int): Total number of training steps
+        warm_up_steps (int, optional): Number of linear warm up steps. Defaults to 0.
+        decay_steps (int, optional): Number of linear decay steps to 0. Defaults to 0.
+        num_cycles (int, optional): Number of cycles for cosine annealing with warm restarts. Defaults to 1.
+        lr_end (float, optional): Final learning rate before decay. Defaults to 0.0.
     """
+    base_scheduler_steps = training_steps - warm_up_steps - decay_steps
+    norm_scheduler_name = scheduler_name.lower()
+    main_scheduler = _get_main_scheduler(
+        norm_scheduler_name,
+        optimizer,
+        steps=base_scheduler_steps,
+        lr_end=lr_end,
+        num_cycles=num_cycles,
+    )
+    if norm_scheduler_name == "constant":
+        # constant scheduler ignores lr_end, so decay needs to start at lr
+        lr_end = lr
+    schedulers: list[lr_scheduler.LRScheduler] = []
+    milestones: list[int] = []
+    if warm_up_steps > 0:
+        schedulers.append(
+            lr_scheduler.LinearLR(
+                optimizer,
+                start_factor=1 / warm_up_steps,
+                end_factor=1.0,
+                total_iters=warm_up_steps - 1,
+            ),
+        )
+        milestones.append(warm_up_steps)
+    schedulers.append(main_scheduler)
+    if decay_steps > 0:
+        if lr_end == 0.0:
+            raise ValueError(
+                "Cannot have decay_steps with lr_end=0.0, this would decay from 0 to 0 and be a waste."
+            )
+        schedulers.append(
+            lr_scheduler.LinearLR(
+                optimizer,
+                start_factor=lr_end / lr,
+                end_factor=0.0,
+                total_iters=decay_steps,
+            ),
+        )
+        milestones.append(training_steps - decay_steps)
+    return lr_scheduler.SequentialLR(
+        schedulers=schedulers,
+        optimizer=optimizer,
+        milestones=milestones,
+    )
 
-    def get_warmup_lambda(warm_up_steps: int, training_steps: int):
-
-        def lr_lambda(steps: int):
-            if steps < warm_up_steps:
-                return (steps + 1) / warm_up_steps
-            else:
-                return (training_steps - steps) / (training_steps - warm_up_steps)
-
-        return lr_lambda
-
-    # heavily derived from hugging face although copilot helped.
-    def get_warmup_cosine_lambda(
-        warm_up_steps: int, training_steps: int, lr_end: float
-    ):
-
-        def lr_lambda(steps: int):
-            if steps < warm_up_steps:
-                return (steps + 1) / warm_up_steps
-            else:
-                progress = (steps - warm_up_steps) / (training_steps - warm_up_steps)
-                return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress))
-
-        return lr_lambda
-
-    if scheduler_name is None or scheduler_name.lower() == "constant":
+
+def _get_main_scheduler(
+    scheduler_name: str,
+    optimizer: optim.Optimizer,
+    steps: int,
+    lr_end: float,
+    num_cycles: int,
+) -> lr_scheduler.LRScheduler:
+    if scheduler_name == "constant":
         return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda steps: 1.0)
-    elif scheduler_name.lower() == "constantwithwarmup":
-        return lr_scheduler.LambdaLR(
-            optimizer,
-            lr_lambda=lambda steps: min(1.0, (steps + 1) / warm_up_steps),
-        )
-    elif scheduler_name.lower() == "linearwarmupdecay":
-        assert training_steps is not None, "training_steps must be provided"
-        lr_lambda = get_warmup_lambda(warm_up_steps, training_steps)
-        return lr_scheduler.LambdaLR(optimizer, lr_lambda)
-    elif scheduler_name.lower() == "cosineannealing":
-        assert training_steps is not None, "training_steps must be provided"
-        return lr_scheduler.CosineAnnealingLR(
-            optimizer, T_max=training_steps, eta_min=lr_end
-        )
-    elif scheduler_name.lower() == "cosineannealingwarmup":
-        assert training_steps is not None, "training_steps must be provided"
-        lr_lambda = get_warmup_cosine_lambda(
-            warm_up_steps, training_steps, lr_end=lr_end
-        )
-        return lr_scheduler.LambdaLR(optimizer, lr_lambda)
-    elif scheduler_name.lower() == "cosineannealingwarmrestarts":
-        assert training_steps is not None, "training_steps must be provided"
-        T_0 = training_steps // num_cycles
+    elif scheduler_name == "cosineannealing":
+        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps, eta_min=lr_end)
+    elif scheduler_name == "cosineannealingwarmrestarts":
         return lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer, T_0=T_0, eta_min=lr_end
+            optimizer, T_0=steps // num_cycles, eta_min=lr_end
         )
     else:
         raise ValueError(f"Unsupported scheduler: {scheduler_name}")
diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py
index bb31ce23..bb2766d5 100644
--- a/sae_training/train_sae_on_language_model.py
+++ b/sae_training/train_sae_on_language_model.py
@@ -218,6 +218,7 @@ def _build_train_context(
     optimizer = Adam(sae.parameters(), lr=sae.cfg.lr)
     scheduler = get_scheduler(
         sae.cfg.lr_scheduler_name,
+        lr=sae.cfg.lr,
         optimizer=optimizer,
         warm_up_steps=sae.cfg.lr_warm_up_steps,
         training_steps=total_training_steps,
diff --git a/tests/unit/test_optim.py b/tests/unit/test_optim.py
index d24e6516..d7a44eb4 100644
--- a/tests/unit/test_optim.py
+++ b/tests/unit/test_optim.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 import pytest
 import torch
 from torch.optim import Adam
@@ -9,10 +11,12 @@
 from sae_training.optim import get_scheduler
 
+LR = 0.1
+
 
 @pytest.fixture
 def optimizer():
-    return Adam([torch.tensor(1.0)], lr=0.1)
+    return Adam([torch.tensor(1.0)], lr=LR)
 
 
 def step_times(num: int, optimizer: Adam, scheduler: LRScheduler):
@@ -25,37 +29,23 @@ def step(optimizer: Adam, scheduler: LRScheduler):
     scheduler.step()
 
 
-@pytest.mark.parametrize(
-    "scheduler_name",
-    [
-        "linearwarmupdecay",
-        "cosineannealing",
-        "cosineannealingwarmup",
-        "cosineannealingwarmrestarts",
-    ],
-)
-def test_get_scheduler_requires_training_steps(scheduler_name: str, optimizer: Adam):
-    with pytest.raises(AssertionError, match="training_steps must be provided"):
-        get_scheduler(scheduler_name, optimizer, 10)
-
-
 def test_get_scheduler_errors_on_uknown_scheduler(optimizer: Adam):
     with pytest.raises(ValueError, match="Unsupported scheduler: unknown"):
-        get_scheduler("unknown", optimizer)
+        get_scheduler("unknown", optimizer, lr=LR, training_steps=10)
 
 
 def test_get_scheduler_constant(optimizer: Adam):
-    scheduler = get_scheduler("constant", optimizer)
+    scheduler = get_scheduler("constant", optimizer, lr=LR, training_steps=4)
     assert scheduler.get_last_lr() == [0.1]
     step_times(3, optimizer, scheduler)
     assert scheduler.get_last_lr() == [0.1]
 
 
 def test_get_scheduler_constantwithwarmup(optimizer: Adam):
-    scheduler = get_scheduler("constantwithwarmup", optimizer, warm_up_steps=2)
-    assert scheduler.get_last_lr() == [0.05]
-    step(optimizer, scheduler)
-    assert scheduler.get_last_lr() == [0.1]
+    scheduler = get_scheduler(
+        "constant", optimizer, lr=LR, warm_up_steps=2, training_steps=4
+    )
+    assert scheduler.get_last_lr() == [pytest.approx(0.05)]
     step(optimizer, scheduler)
     assert scheduler.get_last_lr() == [0.1]
     step_times(3, optimizer, scheduler)
@@ -64,7 +54,7 @@ def test_get_scheduler_constantwithwarmup(optimizer: Adam):
 
 def test_get_scheduler_linearwarmupdecay(optimizer: Adam):
     scheduler = get_scheduler(
-        "linearwarmupdecay", optimizer, warm_up_steps=2, training_steps=6
+        "constant", optimizer, lr=LR, warm_up_steps=2, decay_steps=4, training_steps=6
     )
     # first, ramp up for 2 steps
     assert scheduler.get_last_lr() == [0.05]
@@ -81,29 +71,41 @@ def test_get_scheduler_linearwarmupdecay(optimizer: Adam):
     assert scheduler.get_last_lr() == [pytest.approx(0.025)]
     step(optimizer, scheduler)
     assert scheduler.get_last_lr() == [0.0]
-    # NOTE: the LR goes negative if you go beyond the training steps
+
+
+def test_get_scheduler_errors_if_lr_end_is_0_and_decay_is_set(optimizer: Adam):
+    with pytest.raises(ValueError, match="Cannot have decay_steps with lr_end=0.0"):
+        get_scheduler(
+            "cosineannealing",
+            optimizer,
+            lr=LR,
+            lr_end=0.0,
+            decay_steps=2,
+            training_steps=6,
+        )
 
 
 def test_get_scheduler_cosineannealing(optimizer: Adam):
-    scheduler = get_scheduler(
-        "cosineannealing", optimizer, training_steps=4, lr_end=0.05
+    scheduler: Any = get_scheduler(
+        "cosineannealing", optimizer, lr=LR, training_steps=4, lr_end=0.05
     )
-    assert isinstance(scheduler, CosineAnnealingLR)
-    assert scheduler.T_max == 4
-    assert scheduler.eta_min == 0.05
+    assert len(scheduler._schedulers) == 1
+    main_scheduler = scheduler._schedulers[0]
+    assert isinstance(main_scheduler, CosineAnnealingLR)
+    assert main_scheduler.T_max == 4
+    assert main_scheduler.eta_min == 0.05
 
 
-def test_get_scheduler_cosineannealingwarmup():
-    # NOTE: if the lr_end is not 0.0, this test will not pass.
-    # If eta_min = lr_end * lr, then the test will pass.
-    # We should be careful about the difference between our lr_end and eta_min.
-    lr_end = 0.0
-    optimizer = Adam([torch.tensor(1.0)], lr=0.1)
+def test_get_scheduler_cosineannealing_with_warmup_and_decay():
+    lr_end = 0.01
+    optimizer = Adam([torch.tensor(1.0)], lr=LR)
     scheduler = get_scheduler(
-        "cosineannealingwarmup",
+        "cosineannealing",
         optimizer,
+        lr=LR,
         warm_up_steps=2,
-        training_steps=6,
+        training_steps=8,
+        decay_steps=2,
         lr_end=lr_end,
     )
     # first, ramp up for 2 steps
@@ -113,7 +115,7 @@
     step(optimizer, scheduler)
 
     # From here on, it should match CosineAnnealingLR
-    new_optimizer = Adam([torch.tensor(1.0)], lr=0.1)
+    new_optimizer = Adam([torch.tensor(1.0)], lr=LR)
     cos_scheduler = CosineAnnealingLR(new_optimizer, T_max=4, eta_min=lr_end)
 
     step(optimizer, scheduler)
@@ -125,16 +127,29 @@
     step(optimizer, scheduler)
     step(new_optimizer, cos_scheduler)
     assert scheduler.get_last_lr() == pytest.approx(cos_scheduler.get_last_lr())
+    step(optimizer, scheduler)
+    step(new_optimizer, cos_scheduler)
+    assert scheduler.get_last_lr() == pytest.approx(cos_scheduler.get_last_lr())
+    assert scheduler.get_last_lr() == [lr_end]
+
+    # now, decay to 0 in 2 steps
+    step(optimizer, scheduler)
+    assert scheduler.get_last_lr() == [pytest.approx(0.005)]
+    step(optimizer, scheduler)
+    assert scheduler.get_last_lr() == [pytest.approx(0.0)]
 
 
 def test_get_scheduler_cosineannealingwarmrestarts(optimizer: Adam):
-    scheduler = get_scheduler(
+    scheduler: Any = get_scheduler(
         "cosineannealingwarmrestarts",
         optimizer,
+        lr=LR,
         training_steps=8,
         lr_end=0.05,
         num_cycles=2,
     )
-    assert isinstance(scheduler, CosineAnnealingWarmRestarts)
-    assert scheduler.T_0 == 4
-    assert scheduler.eta_min == 0.05
+    assert len(scheduler._schedulers) == 1
+    main_scheduler = scheduler._schedulers[0]
+    assert isinstance(main_scheduler, CosineAnnealingWarmRestarts)
+    assert main_scheduler.T_0 == 4
+    assert main_scheduler.eta_min == 0.05
diff --git a/tests/unit/test_train_sae_on_language_model.py b/tests/unit/test_train_sae_on_language_model.py
index 096185cf..5fe681f9 100644
--- a/tests/unit/test_train_sae_on_language_model.py
+++ b/tests/unit/test_train_sae_on_language_model.py
@@ -46,7 +46,9 @@ def build_train_ctx(
         ),
         n_frac_active_tokens=n_frac_active_tokens,
         optimizer=optimizer,
-        scheduler=get_scheduler(None, optimizer=optimizer),
+        scheduler=get_scheduler(
+            "constant", lr=sae.cfg.lr, optimizer=optimizer, training_steps=1000
+        ),
     )
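
Reviewer note (not part of the patch): the sketch below exercises the new get_scheduler signature end to end: linear warmup, cosine annealing down to lr_end, then a linear decay to zero. It assumes the patched sae_training.optim is importable; the step counts and learning rates are illustrative only, and the loop is a dummy that only steps the optimizer and scheduler.

# Standalone sketch, not part of the patch: warmup -> cosine annealing -> linear decay.
# Assumes the patched sae_training.optim is on the path; the numbers are illustrative.
import torch
from torch.optim import Adam

from sae_training.optim import get_scheduler

lr = 3e-4
param = torch.nn.Parameter(torch.zeros(1))
optimizer = Adam([param], lr=lr)

scheduler = get_scheduler(
    "cosineannealing",
    optimizer,
    lr=lr,
    training_steps=1_000,
    warm_up_steps=100,  # linear ramp from lr/100 up to lr
    decay_steps=100,    # linear decay from lr_end down to 0 over the final steps
    lr_end=3e-5,        # floor of the cosine phase; must be > 0 when decay_steps > 0
)

for step in range(1_000):
    optimizer.step()  # dummy update; a real loop would compute a loss and backprop first
    scheduler.step()
    if step % 100 == 0:
        print(f"step {step:4d}  lr {scheduler.get_last_lr()[0]:.2e}")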