Add train loop support for looped PP schedules
- refactor some per-model logic into helper functions

ghstack-source-id: a7768287ed2d31272b07ac9f3601b6e23e90c710
Pull Request resolved: #358
wconstab committed Jun 13, 2024
1 parent e858ab4 commit f229b99
Showing 9 changed files with 307 additions and 127 deletions.
21 changes: 20 additions & 1 deletion test_runner.py
@@ -32,6 +32,7 @@ class OverrideDefinitions:
test_name: str = "default"
requires_seed_checkpoint: bool = False
ngpu: int = 4
model_flavor: str = "debugmodel"

def __repr__(self):
return self.test_descr
@@ -225,6 +226,22 @@ def build_test_list():
requires_seed_checkpoint=True,
ngpu=8,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule interleaved_1f1b",
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
],
],
"PP looped 1f1b test",
"pp_looped_1f1b",
requires_seed_checkpoint=True,
ngpu=4,
model_flavor="debugmodel_8_layers",
),
OverrideDefinitions(
[
[
@@ -252,10 +269,11 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
# run_test supports sequence of tests.
test_name = test_flavor.test_name
dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

if test_flavor.requires_seed_checkpoint:
cmd = f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
cmd = f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg} {model_flavor_arg}"
logger.info(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)
@@ -265,6 +283,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
for override_arg in test_flavor.override_args:
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
cmd += " " + dump_folder_arg
cmd += " " + model_flavor_arg
if override_arg:
cmd += " " + " ".join(override_arg)
logger.info(
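For a sense of how the new model_flavor_arg threads through, here is a rough, hypothetical sketch of the command string run_test assembles for the looped-schedule test added above (paths and config file are placeholders, and the remaining override args are elided):

# Illustrative only: how dump_folder_arg, model_flavor_arg, and the override args
# from the "pp_looped_1f1b" entry combine into the launch command.
output_dir = "/tmp/torchtitan_integration_tests"  # placeholder
full_path = "./train_configs/debug_model.toml"  # placeholder config file
test_name = "pp_looped_1f1b"
ngpu = 4

dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
model_flavor_arg = "--model.flavor debugmodel_8_layers"
all_ranks = ",".join(map(str, range(ngpu)))  # "0,1,2,3"

cmd = (
    f"CONFIG_FILE={full_path} NGPU={ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
    f" {dump_folder_arg} {model_flavor_arg}"
    " --checkpoint.enable_checkpoint"
    " --experimental.pipeline_parallel_degree 4"
    " --experimental.pipeline_parallel_schedule interleaved_1f1b"
    # ...plus the remaining override args from the test entry above
)
print(cmd)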
63 changes: 57 additions & 6 deletions torchtitan/checkpoint.py
@@ -124,9 +124,9 @@ def checkpoint_mp(recv, send):
class CheckpointManager:
def __init__(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
model_parts: List[nn.Module],
optimizers: List[torch.optim.Optimizer],
lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler],
dataloader: DataLoader,
states: Dict[str, Any],
job_config: JobConfig,
@@ -137,13 +137,61 @@ def __init__(

if not self.enable_checkpoint:
return
"""
Tricky changes for Pipeline Parallelism
1. even for simple PP schedules, we introduce a problem where there is a separate optimizer on stage1's rank
vs stage0's rank. When saving, these collide and one of them is lost. Then when reloading, only one stage can
restore its optimizer states, others will error.
--> no fix yet
2. with looped schedules, we have multiple logical models per rank. This complicates both model state and
optimizer state handling in dcp.
a) Model states
- ideally, we support resharding the model, so we want to collapse the states back into one logical state-dict
- this means we merge the state-dicts from each model_part into one when saving and loading
b) Optimizer states
- if we create one optimizer object per model_part, we add a similar but orthogonal problem to (1), namely,
we have two optimizers on this local rank, and if we were to save them to the same "optim" key they would collide.
However, unlike (1), we have control over them both in one place, so we could save them under separate keys to avoid
the collision. Since this doesn't solve (1), it is only a temporary workaround. Also, if we go this route,
it would not be possible to load the optimizer states into a different parallelism configuration (resharding).
- if we enable one optimizer object to handle multiple model_parts, e.g. by wrapping model_parts in a ModuleList,
then we could correctly save the optimizer states for this PP rank. But we still have the bug in (1) preventing us from
reloading the optimizer states.
In any case, we won't be able to reload a PP checkpoint with optimizer states even without reshard, until (1) is fixed.
And the fix for (1) might change the option space for handling (2) as well.
So, for now I will only save the model weights for PP in this PR, while we figure out the full story for (1).
Note: haven't thought much about lr_scheduler's states.
"""
assert len(model_parts) == len(
optimizers
), "Must pass one optimizer per model part"
assert len(model_parts) == len(
lr_schedulers
), "Must pass one lr_scheduler per model part"

self.states = states

"""plan
for save-
model: merge the state-dicts into one in __init__, then save/load model will 'just work',
and model portion would be 'reshardable'
optim: store each one in a separate key for now,
make a note/post explaining the issues and possible long term plan
"""
self.states.update(
{
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
"lr_scheduler": lr_scheduler,
"model": ModelWrapper(model_parts),
"optimizer": OptimizerWrapper(model_parts, optimizers),
# TODO(whc) flatten lr_schedulers using a wrapper and somehow handle resharding?
# or store one per key and explicitly don't support resharding?
# "lr_scheduler": lr_scheduler,
"dataloader": dataloader,
}
)
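The ModelWrapper and OptimizerWrapper changes themselves are not shown in this excerpt. As a minimal sketch of the 'merge the state-dicts into one' idea from the plan above (assuming pipeline splitting keeps the original fully-qualified parameter names, which is what would make the merged dict collision-free and reshardable), the model side could look roughly like this; MergedModelState is a hypothetical stand-in, not the repo's actual wrapper:

from typing import Any, Dict, List

import torch.nn as nn


class MergedModelState:
    """Hypothetical stand-in for the merging behavior described in the plan."""

    def __init__(self, model_parts: List[nn.Module]) -> None:
        self.model_parts = model_parts

    def state_dict(self) -> Dict[str, Any]:
        # Per-part dicts keep their original FQNs (e.g. 'layers.3.attention.wq.weight'),
        # so they are disjoint and can be flattened into one dict without collisions.
        merged: Dict[str, Any] = {}
        for part in self.model_parts:
            merged.update(part.state_dict())
        return merged

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Each part loads only the keys it owns; strict=False tolerates the keys
        # that belong to the other stages on this rank.
        for part in self.model_parts:
            part.load_state_dict(state_dict, strict=False)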
@@ -218,6 +266,9 @@ def _save_last_step(self, curr_step: int) -> None:
# 'tok_embeddings.weight':...,
# 'layers.0.attention.wq.weight': ...
# }.

# TODO(whc) if we have multiple model parts on this rank, should we merge all their keys into a flat state dict
# or keep them as separate state dicts under named keys like _models_0 and _models_1?
self.states = self.states["model"].state_dict()

# For now, we will manually pop the freqs_cis buffer, as we made this permanent
5 changes: 3 additions & 2 deletions torchtitan/config_manager.py
@@ -257,14 +257,15 @@ def __init__(self):
self.parser.add_argument(
"--experimental.pipeline_parallel_schedule",
type=str,
choices=["1f1b", "gpipe"],
choices=["1f1b", "gpipe", "interleaved_1f1b"],
default="1f1b",
help="""
Specify the Pipeline Parallel schedule to use.
The schedule must be compatible with the split points and stages_per_rank.
Looped schedules are not yet supported in torchtitan.""",
Looped schedules (e.g. interleaved_1f1b) require specifying pipeline_parallel_degree = number of ranks,
and split_points = number of stages - 1""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_split_mode",
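To make the constraint in the new help text concrete, here is a small arithmetic sketch using the values from the pp_looped_1f1b test above (the equal-stages-per-rank layout is an assumption about how looped schedules divide stages):

# Values from the pp_looped_1f1b integration test.
pipeline_parallel_degree = 4  # number of PP ranks
split_points = [
    "layers.1", "layers.2", "layers.3", "layers.4",
    "layers.5", "layers.6", "layers.7",
]

num_stages = len(split_points) + 1  # 7 split points -> 8 stages
assert num_stages % pipeline_parallel_degree == 0, (
    "assumes an equal number of stages per rank"
)
stages_per_rank = num_stages // pipeline_parallel_degree  # 8 / 4 = 2
print(f"{num_stages=} {stages_per_rank=}")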
29 changes: 21 additions & 8 deletions torchtitan/lr_scheduling.py
@@ -33,11 +33,24 @@ def linear_warmup_linear_decay(current_step: int) -> float:
return curr_adjustment


def get_lr_scheduler(optimizer, job_config: JobConfig):
"""Build a linear warmup and linear decay scheduler"""
global _warmup_steps, _decay_steps
_warmup_steps = int(job_config.training.warmup_steps)
_decay_steps = float(max(1, job_config.training.steps - _warmup_steps))

warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
return warmup_scheduler
def get_lr_schedulers(optimizers, job_config: JobConfig):
def _get_lr_scheduler(optimizer):
"""Build a linear warmup and linear decay scheduler"""
global _warmup_steps, _decay_steps
_warmup_steps = int(job_config.training.warmup_steps)
_decay_steps = float(max(1, job_config.training.steps - _warmup_steps))

warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
return warmup_scheduler

class SchedulersContainer:
def __init__(self, schedulers):
self.schedulers = schedulers

def step(self):
for scheduler in self.schedulers:
scheduler.step()

return SchedulersContainer(
[_get_lr_scheduler(optimizer) for optimizer in optimizers]
)
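A hypothetical usage sketch of the new per-optimizer API follows; the toy Linear modules stand in for pipeline model parts, and building a default JobConfig via parse_args([]) is an assumption about config_manager rather than something shown in this diff:

import torch
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.lr_scheduling import get_lr_schedulers

job_config = JobConfig()
job_config.parse_args([])  # assumed: defaults are enough for warmup_steps/steps

# Stand-ins for the model parts a PP rank would own under a looped schedule.
model_parts = [nn.Linear(8, 8), nn.Linear(8, 8)]
optimizers = [torch.optim.AdamW(m.parameters(), lr=8e-4) for m in model_parts]
lr_schedulers = get_lr_schedulers(optimizers, job_config)

for _ in range(3):  # one iteration per training step
    for opt in optimizers:
        opt.step()
    lr_schedulers.step()  # steps every underlying LambdaLR once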
3 changes: 3 additions & 0 deletions torchtitan/models/llama/__init__.py
@@ -30,6 +30,9 @@

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16, rope_theta=500000),
"debugmodel_8_layers": ModelArgs(
dim=256, n_layers=8, n_heads=16, rope_theta=500000
),
"8B": ModelArgs(
dim=4096,
n_layers=32,
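The new 8-layer debug flavor gives the 8-stage looped test one TransformerBlock per stage; a hypothetical lookup by flavor name:

from torchtitan.models.llama import llama3_configs

# --model.flavor debugmodel_8_layers resolves to this entry (assumed lookup path).
model_args = llama3_configs["debugmodel_8_layers"]
print(model_args.n_layers)  # 8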