Skip to content

Commit

Permalink
adding lr_decay_steps and refactoring get_scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Apr 3, 2024
1 parent 1ce44d7 commit fd5448c
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 100 deletions.
3 changes: 2 additions & 1 deletion sae_training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
lp_norm: float = 1
lr: float = 3e-4
lr_scheduler_name: str = (
"constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
"constant" # constant, cosineannealing, cosineannealingwarmrestarts
)
lr_warm_up_steps: int = 500
lr_decay_steps: int = 0
train_batch_size: int = 4096

# Resampling protocol args
Expand Down
124 changes: 67 additions & 57 deletions sae_training/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
"""

import math
from typing import Optional

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

Expand All @@ -14,75 +11,88 @@
# Cosine Annealing with Warmup
# Cosine Annealing with Warmup / Restarts
def get_scheduler(
scheduler_name: Optional[str],
scheduler_name: str,
optimizer: optim.Optimizer,
training_steps: int,
lr: float,
warm_up_steps: int = 0,
training_steps: int | None = None,
decay_steps: int = 0,
num_cycles: int = 1,
lr_end: float = 0.0,
):
) -> lr_scheduler.LRScheduler:
"""
Loosely based on the Hugging Face schedulers; it seemed simpler to write this than to import
transformers: https://huggingface.co/docs/transformers/main_classes/optimizer_schedules
Args:
scheduler_name (Optional[str]): Name of the scheduler to use. If None, returns a constant scheduler
scheduler_name (str): Name of the scheduler to use, one of "constant", "cosineannealing", "cosineannealingwarmrestarts"
optimizer (optim.Optimizer): Optimizer to use
**kwargs: Additional arguments to pass to the scheduler including warm_up_steps,
training_steps, num_cycles, lr_end.
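training_steps (int): Total number of training steps, including warm up and decay steps
lr (float): Base learning rate of the optimizer; used to compute the starting factor of the final linear decay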
warm_up_steps (int, optional): Number of linear warm up steps. Defaults to 0.
decay_steps (int, optional): Number of linear decay steps to 0. Defaults to 0.
num_cycles (int, optional): Number of cycles for cosine annealing with warm restarts. Defaults to 1.
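lr_end (float, optional): Final learning rate reached before the linear decay phase. This is an absolute learning rate (it is passed as eta_min to the cosine schedulers), not a multiplier; the constant scheduler ignores it. Defaults to 0.0.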
"""
base_scheduler_steps = training_steps - warm_up_steps - decay_steps
norm_scheduler_name = scheduler_name.lower()
main_scheduler = _get_main_scheduler(
norm_scheduler_name,
optimizer,
steps=base_scheduler_steps,
lr_end=lr_end,
num_cycles=num_cycles,
)
if norm_scheduler_name == "constant":
# constant scheduler ignores lr_end, so decay needs to start at lr
lr_end = lr
schedulers: list[lr_scheduler.LRScheduler] = []
milestones: list[int] = []
if warm_up_steps > 0:
schedulers.append(
lr_scheduler.LinearLR(
optimizer,
start_factor=1 / warm_up_steps,
end_factor=1.0,
total_iters=warm_up_steps - 1,
),
)
milestones.append(warm_up_steps)
schedulers.append(main_scheduler)
if decay_steps > 0:
if lr_end == 0.0:
raise ValueError(
"Cannot have decay_steps with lr_end=0.0, this would decay from 0 to 0 and be a waste."
)
schedulers.append(
lr_scheduler.LinearLR(
optimizer,
start_factor=lr_end / lr,
end_factor=0.0,
total_iters=decay_steps,
),
)
milestones.append(training_steps - decay_steps)
return lr_scheduler.SequentialLR(
schedulers=schedulers,
optimizer=optimizer,
milestones=milestones,
)

def get_warmup_lambda(warm_up_steps: int, training_steps: int):

def lr_lambda(steps: int):
if steps < warm_up_steps:
return (steps + 1) / warm_up_steps
else:
return (training_steps - steps) / (training_steps - warm_up_steps)

return lr_lambda

# heavily derived from hugging face although copilot helped.
def get_warmup_cosine_lambda(
warm_up_steps: int, training_steps: int, lr_end: float
):

def lr_lambda(steps: int):
if steps < warm_up_steps:
return (steps + 1) / warm_up_steps
else:
progress = (steps - warm_up_steps) / (training_steps - warm_up_steps)
return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress))

return lr_lambda

if scheduler_name is None or scheduler_name.lower() == "constant":
def _get_main_scheduler(
scheduler_name: str,
optimizer: optim.Optimizer,
steps: int,
lr_end: float,
num_cycles: int,
) -> lr_scheduler.LRScheduler:
if scheduler_name == "constant":
return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda steps: 1.0)
elif scheduler_name.lower() == "constantwithwarmup":
return lr_scheduler.LambdaLR(
optimizer,
lr_lambda=lambda steps: min(1.0, (steps + 1) / warm_up_steps),
)
elif scheduler_name.lower() == "linearwarmupdecay":
assert training_steps is not None, "training_steps must be provided"
lr_lambda = get_warmup_lambda(warm_up_steps, training_steps)
return lr_scheduler.LambdaLR(optimizer, lr_lambda)
elif scheduler_name.lower() == "cosineannealing":
assert training_steps is not None, "training_steps must be provided"
return lr_scheduler.CosineAnnealingLR(
optimizer, T_max=training_steps, eta_min=lr_end
)
elif scheduler_name.lower() == "cosineannealingwarmup":
assert training_steps is not None, "training_steps must be provided"
lr_lambda = get_warmup_cosine_lambda(
warm_up_steps, training_steps, lr_end=lr_end
)
return lr_scheduler.LambdaLR(optimizer, lr_lambda)
elif scheduler_name.lower() == "cosineannealingwarmrestarts":
assert training_steps is not None, "training_steps must be provided"
T_0 = training_steps // num_cycles
elif scheduler_name == "cosineannealing":
return lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps, eta_min=lr_end)
elif scheduler_name == "cosineannealingwarmrestarts":
return lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=T_0, eta_min=lr_end
optimizer, T_0=steps // num_cycles, eta_min=lr_end
)
else:
raise ValueError(f"Unsupported scheduler: {scheduler_name}")
1 change: 1 addition & 0 deletions sae_training/train_sae_on_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def _build_train_context(
optimizer = Adam(sae.parameters(), lr=sae.cfg.lr)
scheduler = get_scheduler(
sae.cfg.lr_scheduler_name,
lr=sae.cfg.lr,
optimizer=optimizer,
warm_up_steps=sae.cfg.lr_warm_up_steps,
training_steps=total_training_steps,
Expand Down
97 changes: 56 additions & 41 deletions tests/unit/test_optim.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import pytest
import torch
from torch.optim import Adam
Expand All @@ -9,10 +11,12 @@

from sae_training.optim import get_scheduler

LR = 0.1


@pytest.fixture
def optimizer():
return Adam([torch.tensor(1.0)], lr=0.1)
return Adam([torch.tensor(1.0)], lr=LR)


def step_times(num: int, optimizer: Adam, scheduler: LRScheduler):
Expand All @@ -25,37 +29,23 @@ def step(optimizer: Adam, scheduler: LRScheduler):
scheduler.step()


@pytest.mark.parametrize(
"scheduler_name",
[
"linearwarmupdecay",
"cosineannealing",
"cosineannealingwarmup",
"cosineannealingwarmrestarts",
],
)
def test_get_scheduler_requires_training_steps(scheduler_name: str, optimizer: Adam):
with pytest.raises(AssertionError, match="training_steps must be provided"):
get_scheduler(scheduler_name, optimizer, 10)


def test_get_scheduler_errors_on_uknown_scheduler(optimizer: Adam):
with pytest.raises(ValueError, match="Unsupported scheduler: unknown"):
get_scheduler("unknown", optimizer)
get_scheduler("unknown", optimizer, lr=LR, training_steps=10)


def test_get_scheduler_constant(optimizer: Adam):
scheduler = get_scheduler("constant", optimizer)
scheduler = get_scheduler("constant", optimizer, lr=LR, training_steps=4)
assert scheduler.get_last_lr() == [0.1]
step_times(3, optimizer, scheduler)
assert scheduler.get_last_lr() == [0.1]


def test_get_scheduler_constantwithwarmup(optimizer: Adam):
scheduler = get_scheduler("constantwithwarmup", optimizer, warm_up_steps=2)
assert scheduler.get_last_lr() == [0.05]
step(optimizer, scheduler)
assert scheduler.get_last_lr() == [0.1]
scheduler = get_scheduler(
"constant", optimizer, lr=LR, warm_up_steps=2, training_steps=4
)
assert scheduler.get_last_lr() == [pytest.approx(0.05)]
step(optimizer, scheduler)
assert scheduler.get_last_lr() == [0.1]
step_times(3, optimizer, scheduler)
Expand All @@ -64,7 +54,7 @@ def test_get_scheduler_constantwithwarmup(optimizer: Adam):

def test_get_scheduler_linearwarmupdecay(optimizer: Adam):
scheduler = get_scheduler(
"linearwarmupdecay", optimizer, warm_up_steps=2, training_steps=6
"constant", optimizer, lr=LR, warm_up_steps=2, decay_steps=4, training_steps=6
)
# first, ramp up for 2 steps
assert scheduler.get_last_lr() == [0.05]
Expand All @@ -81,29 +71,41 @@ def test_get_scheduler_linearwarmupdecay(optimizer: Adam):
assert scheduler.get_last_lr() == [pytest.approx(0.025)]
step(optimizer, scheduler)
assert scheduler.get_last_lr() == [0.0]
# NOTE: the LR goes negative if you go beyond the training steps


def test_get_scheduler_errors_if_lr_end_is_0_and_decay_is_set(optimizer: Adam):
with pytest.raises(ValueError, match="Cannot have decay_steps with lr_end=0.0"):
get_scheduler(
"cosineannealing",
optimizer,
lr=LR,
lr_end=0.0,
decay_steps=2,
training_steps=6,
)


def test_get_scheduler_cosineannealing(optimizer: Adam):
scheduler = get_scheduler(
"cosineannealing", optimizer, training_steps=4, lr_end=0.05
scheduler: Any = get_scheduler(
"cosineannealing", optimizer, lr=LR, training_steps=4, lr_end=0.05
)
assert isinstance(scheduler, CosineAnnealingLR)
assert scheduler.T_max == 4
assert scheduler.eta_min == 0.05
assert len(scheduler._schedulers) == 1
main_scheduler = scheduler._schedulers[0]
assert isinstance(main_scheduler, CosineAnnealingLR)
assert main_scheduler.T_max == 4
assert main_scheduler.eta_min == 0.05


def test_get_scheduler_cosineannealingwarmup():
# NOTE: if the lr_end is not 0.0, this test will not pass.
# If eta_min = lr_end * lr, then the test will pass.
# We should be careful about the difference between our lr_end and eta_min.
lr_end = 0.0
optimizer = Adam([torch.tensor(1.0)], lr=0.1)
def test_get_scheduler_cosineannealing_with_warmup_and_decay():
lr_end = 0.01
optimizer = Adam([torch.tensor(1.0)], lr=LR)
scheduler = get_scheduler(
"cosineannealingwarmup",
"cosineannealing",
optimizer,
lr=LR,
warm_up_steps=2,
training_steps=6,
training_steps=8,
decay_steps=2,
lr_end=lr_end,
)
# first, ramp up for 2 steps
Expand All @@ -113,7 +115,7 @@ def test_get_scheduler_cosineannealingwarmup():
step(optimizer, scheduler)

# From here on, it should match CosineAnnealingLR
new_optimizer = Adam([torch.tensor(1.0)], lr=0.1)
new_optimizer = Adam([torch.tensor(1.0)], lr=LR)
cos_scheduler = CosineAnnealingLR(new_optimizer, T_max=4, eta_min=lr_end)

step(optimizer, scheduler)
Expand All @@ -125,16 +127,29 @@ def test_get_scheduler_cosineannealingwarmup():
step(optimizer, scheduler)
step(new_optimizer, cos_scheduler)
assert scheduler.get_last_lr() == pytest.approx(cos_scheduler.get_last_lr())
step(optimizer, scheduler)
step(new_optimizer, cos_scheduler)
assert scheduler.get_last_lr() == pytest.approx(cos_scheduler.get_last_lr())
assert scheduler.get_last_lr() == [lr_end]

# now, decay to 0 in 2 steps
step(optimizer, scheduler)
assert scheduler.get_last_lr() == [pytest.approx(0.005)]
step(optimizer, scheduler)
assert scheduler.get_last_lr() == [pytest.approx(0.0)]


def test_get_scheduler_cosineannealingwarmrestarts(optimizer: Adam):
scheduler = get_scheduler(
scheduler: Any = get_scheduler(
"cosineannealingwarmrestarts",
optimizer,
lr=LR,
training_steps=8,
lr_end=0.05,
num_cycles=2,
)
assert isinstance(scheduler, CosineAnnealingWarmRestarts)
assert scheduler.T_0 == 4
assert scheduler.eta_min == 0.05
assert len(scheduler._schedulers) == 1
main_scheduler = scheduler._schedulers[0]
assert isinstance(main_scheduler, CosineAnnealingWarmRestarts)
assert main_scheduler.T_0 == 4
assert main_scheduler.eta_min == 0.05
4 changes: 3 additions & 1 deletion tests/unit/test_train_sae_on_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def build_train_ctx(
),
n_frac_active_tokens=n_frac_active_tokens,
optimizer=optimizer,
scheduler=get_scheduler(None, optimizer=optimizer),
scheduler=get_scheduler(
"constant", lr=sae.cfg.lr, optimizer=optimizer, training_steps=1000
),
)


Expand Down

0 comments on commit fd5448c

Please sign in to comment.