Add train loop support for looped PP schedules
- refactor some per-model logic into helper functions

ghstack-source-id: 4fcd38adafe9926799366c4c868219d47f7bc03c
Pull Request resolved: #358
wconstab committed Jun 12, 2024
1 parent e858ab4 commit fc183db
Showing 6 changed files with 171 additions and 64 deletions.
63 changes: 57 additions & 6 deletions torchtitan/checkpoint.py
@@ -124,9 +124,9 @@ def checkpoint_mp(recv, send):
class CheckpointManager:
def __init__(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
model_parts: List[nn.Module],
optimizers: List[torch.optim.Optimizer],
lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler],
dataloader: DataLoader,
states: Dict[str, Any],
job_config: JobConfig,
@@ -137,13 +137,61 @@ def __init__(

if not self.enable_checkpoint:
return
"""
Tricky changes for Pipeline Parallelism
1. even for simple PP schedules, we introduce a problem where there is a separate optimizer on stage1's rank
vs stage0's rank. When saving, these collide and one of them is lost. Then when reloading, only one stage can
restore its optimizer states, others will error.
--> no fix yet
2. with looped schedules, we have multiple logical models per rank. This complicates both model state and
optimizer state handling in dcp.
a) Model states
- ideally, we support resharding the model, so we want to collapse the states back into one logical state-dict
- this means we merge the state-dicts from each model_part into one when saving and loading
b) Optimizer states
- if we create one optimizer object per model_part, we add a similar but orthogonal problem to (1), namely,
we have two optimizers on this local rank, and if we were to save them to the same "optim" key they would collide.
However, unlike (1), we have control over them both in one place, so we could save them under separate keys to avoid
the collision. Since this doesn't solve (1), it is only a temporary workaround. Also, if we go this route,
it would not be possible to load the optimizer states into a different parallelism configuration (resharding).
- if we enable one optimizer object to handle multiple model_parts, e.g. by wrapping model_parts in a ModuleList,
then we could correctly save the optimizer states for this PP rank. But we still have the bug in (1) preventing us from
reloading the optimizer states.
In any case, we won't be able to reload a PP checkpoint with optimizer states even without reshard, until (1) is fixed.
And the fix for (1) might change the option space for handling (2) as well.
So, for now I will only save the model weights for PP in this PR, while we figure out the full story for (1).
Note: haven't thought much about lr_scheduler's states.
"""
assert len(model_parts) == len(
optimizers
), "Must pass one optimizer per model part"
assert len(model_parts) == len(
lr_schedulers
), "Must pass one lr_scheduler per model part"

self.states = states

"""plan
for save-
model: merge the state-dicts into one in __init__, then save/load model will 'just work',
and model portion would be 'reshardable'
optim: store each one in a separate key for now,
make a note/post explaining the issues and possible long term plan
"""
self.states.update(
{
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
"lr_scheduler": lr_scheduler,
"model": ModelWrapper(model_parts),
"optimizer": OptimizerWrapper(model_parts, optimizers),
# TODO(whc) flatten lr_schedulers using a wrapper and somehow handle resharding?
# or store one per key and explicitly don't support resharding?
# "lr_scheduler": lr_scheduler,
"dataloader": dataloader,
}
)
@@ -218,6 +266,9 @@ def _save_last_step(self, curr_step: int) -> None:
# 'tok_embeddings.weight':...,
# 'layers.0.attention.wq.weight': ...
# }.

# TODO(whc) if we have multiple model parts on this rank, should we merge all their keys into a flat state dict
# or keep them as separate state dicts under named keys like _models_0 and _models_1?
self.states = self.states["model"].state_dict()

# For now, we will manually pop the freqs_cis buffer, as we made this permanent
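The __init__ docstring above sketches a plan: merge the per-part model state-dicts into one reshardable "model" entry, and keep per-part optimizer state under separate keys as a stopgap that does not fix issue (1). A minimal illustration of that idea follows; FlatModelWrapper and the "optim_{i}" key scheme are hypothetical names for this sketch, not torchtitan's actual ModelWrapper/OptimizerWrapper implementations.

from typing import Any, Dict, List

import torch
import torch.nn as nn


class FlatModelWrapper:
    """Hypothetical wrapper that flattens several pipeline model parts
    into one logical state-dict, so the model portion stays reshardable."""

    def __init__(self, model_parts: List[nn.Module]) -> None:
        self.model_parts = model_parts

    def state_dict(self) -> Dict[str, Any]:
        # Parameter FQNs are assumed disjoint across parts, since each part
        # owns a different slice of the transformer layers.
        merged: Dict[str, Any] = {}
        for part in self.model_parts:
            merged.update(part.state_dict())
        return merged

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # strict=False: the merged dict also carries keys owned by other parts.
        for part in self.model_parts:
            part.load_state_dict(state_dict, strict=False)


def per_part_optimizer_states(
    optimizers: List[torch.optim.Optimizer],
) -> Dict[str, Any]:
    # The "store each optimizer under its own key" stopgap from the docstring;
    # this avoids the local collision but is not reshardable and does not fix
    # the cross-rank collision described in issue (1).
    return {f"optim_{i}": opt.state_dict() for i, opt in enumerate(optimizers)}
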
29 changes: 21 additions & 8 deletions torchtitan/lr_scheduling.py
@@ -33,11 +33,24 @@ def linear_warmup_linear_decay(current_step: int) -> float:
return curr_adjustment


def get_lr_scheduler(optimizer, job_config: JobConfig):
"""Build a linear warmup and linear decay scheduler"""
global _warmup_steps, _decay_steps
_warmup_steps = int(job_config.training.warmup_steps)
_decay_steps = float(max(1, job_config.training.steps - _warmup_steps))

warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
return warmup_scheduler
def get_lr_schedulers(optimizers, job_config: JobConfig):
def _get_lr_scheduler(optimizer):
"""Build a linear warmup and linear decay scheduler"""
global _warmup_steps, _decay_steps
_warmup_steps = int(job_config.training.warmup_steps)
_decay_steps = float(max(1, job_config.training.steps - _warmup_steps))

warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
return warmup_scheduler

class SchedulersContainer:
def __init__(self, schedulers):
self.schedulers = schedulers

def step(self):
for scheduler in self.schedulers:
scheduler.step()

return SchedulersContainer(
[_get_lr_scheduler(optimizer) for optimizer in optimizers]
)
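
The TODO in checkpoint.py above leaves lr_scheduler state out of the checkpoint for now. A minimal sketch of the "store one per key and explicitly don't support resharding" option, written as a hypothetical variant of the SchedulersContainer defined here (not code from this commit):

from typing import Any, Dict, List

from torch.optim.lr_scheduler import LRScheduler


class StatefulSchedulersContainer:
    """Hypothetical variant of SchedulersContainer that exposes per-part
    scheduler state under explicit keys (no resharding support)."""

    def __init__(self, schedulers: List[LRScheduler]) -> None:
        self.schedulers = schedulers

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> Dict[str, Any]:
        # One entry per model part; loading assumes the same PP layout.
        return {
            f"lr_scheduler_{i}": s.state_dict()
            for i, s in enumerate(self.schedulers)
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        for i, s in enumerate(self.schedulers):
            s.load_state_dict(state_dict[f"lr_scheduler_{i}"])
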
14 changes: 11 additions & 3 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -240,7 +240,7 @@ def pipeline_llama_manual(
output_args=output.chunk(microbatches)[0],
group=pp_mesh.get_group("pp"),
)
return (stage, model)
return [stage], [model]


def pipeline_llama_tracer(
@@ -281,10 +281,18 @@ def pipeline_llama_tracer(
device=device,
group=pp_mesh.get_group(),
)
return (stage, model)
return [stage], [model]


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
def parallelize_llama(model_parts, world_mesh, parallel_dims, job_config: JobConfig):
"""Apply SPMD parallelisms and activation checkpointing to each model in model_parts"""
return [
_parallelize_llama(m, world_mesh, parallel_dims, job_config)
for m in model_parts
]


def _parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply SPMD parallelisms and activation checkpointing to the model.
4 changes: 3 additions & 1 deletion torchtitan/parallelisms/pipelining_utils.py
@@ -7,7 +7,9 @@
from torchtitan.logging_utils import logger


def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn):
def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
assert len(stages) == 1, "Only simple schedules are supported currently"
stage = stages[0]
if job_config.experimental.pipeline_parallel_schedule == "1f1b":
schedule_class = Schedule1F1B
elif job_config.experimental.pipeline_parallel_schedule == "gpipe":
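For context on the assert above: a looped schedule would consume the entire stages list rather than stages[0]. A minimal sketch, assuming the multi-stage schedule classes available in torch.distributed.pipelining (e.g. ScheduleInterleaved1F1B, whose (stages, n_microbatches, loss_fn) signature is an assumption here) rather than anything added by this commit:

# Sketch only, not part of this commit: the schedule class, its import path,
# and its signature are assumed from torch.distributed.pipelining.
from torch.distributed.pipelining import ScheduleInterleaved1F1B


def build_looped_pipeline_schedule(stages, n_microbatches, loss_fn):
    # Looped/interleaved schedules take every stage owned by this rank,
    # rather than the single stage used by the simple 1F1B/GPipe schedules.
    return ScheduleInterleaved1F1B(
        stages,
        n_microbatches=n_microbatches,
        loss_fn=loss_fn,
    )
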
7 changes: 6 additions & 1 deletion torchtitan/utils.py
@@ -7,7 +7,7 @@
import os
from dataclasses import dataclass
from datetime import timedelta
from typing import Union
from typing import List, Union

import torch
import torch.distributed._functional_collectives as funcol
@@ -17,6 +17,11 @@
from torchtitan.parallelisms import ParallelDims


def move_to_empty(model_parts: List[torch.nn.Module], device: torch.device):
"""Materialize each model part's parameters on the given device (e.g. after meta-device init)."""
for model in model_parts:
model.to_empty(device=device)


def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
tensor = torch.tensor(x).cuda()
return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)
118 changes: 73 additions & 45 deletions train.py
@@ -29,7 +29,7 @@
from torchtitan.datasets import build_hf_data_loader, create_tokenizer
from torchtitan.float8_linear import build_fp8_linear
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_scheduler
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.parallelisms import (
@@ -48,6 +48,7 @@
get_num_params,
get_peak_flops,
init_distributed,
move_to_empty,
NoColor,
set_pg_timeouts,
)
@@ -90,29 +91,47 @@ def load_state_dict(self, state_dict) -> None:
self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)


def build_optimizer(model, job_config: JobConfig):
# build optimizer
name = job_config.optimizer.name
lr = job_config.optimizer.lr
fused = job_config.optimizer.fused

# Common parameters for both optimizers
optimizer_kwargs = {
"lr": lr,
"betas": (0.9, 0.95),
"weight_decay": 0.1,
"fused": fused,
"foreach": not fused,
}
if name == "Adam":
# TODO: make the optimizer options configurable by toml/cmd args
optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
elif name == "AdamW":
optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
else:
raise NotImplementedError(f"Optimizer {name} not added.")
def build_optimizers(model_parts, job_config: JobConfig):
"""Wrap one optimizer per model part in an OptimizersContainer which provides a single
step() and zero_grad() method for all the child optimizers.
"""

def _build_optimizer(model):
name = job_config.optimizer.name
lr = job_config.optimizer.lr
fused = job_config.optimizer.fused

# Common parameters for both optimizers
optimizer_kwargs = {
"lr": lr,
"betas": (0.9, 0.95),
"weight_decay": 0.1,
"fused": fused,
"foreach": not fused,
}
if name == "Adam":
# TODO: make the optimizer options configurable by toml/cmd args
optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
elif name == "AdamW":
optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
else:
raise NotImplementedError(f"Optimizer {name} not added.")

return optimizer

class OptimizersContainer:
def __init__(self, optimizers):
self.optimizers = optimizers

return optimizer
def step(self):
for optimizer in self.optimizers:
optimizer.step()

def zero_grad(self):
for optimizer in self.optimizers:
optimizer.zero_grad()

return OptimizersContainer([_build_optimizer(model) for model in model_parts])


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@@ -195,16 +214,21 @@ def loss_fn(pred, labels):
logger.info(
f"Building {model_name} {job_config.model.flavor} with {model_config}"
)
model = model_cls.from_model_args(model_config)
whole_model = model_cls.from_model_args(model_config)

# In 1D/2D cases or PP with simple schedules, model_parts is just one item
# for PP with looped schedules, each item is one stage-model-chunk
# we iterate all model_parts for applying SPMD parallelism, compilation, optimizer, and checkpointing
model_parts = [whole_model]

# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(model, job_config)

# log model size
model_param_count = get_num_params(model)
model_param_count = get_num_params(whole_model)
num_flop_per_token = get_num_flop_per_token(
get_num_params(model, exclude_embedding=True),
get_num_params(whole_model, exclude_embedding=True),
model_config,
job_config.training.seq_len,
)
@@ -219,26 +243,28 @@ def loss_fn(pred, labels):
gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name)

if parallel_dims.pp_enabled:
stage, model = models_pipelining_fns[model_name](
model, world_mesh, parallel_dims, job_config, device, model_config
stages, model_parts = models_pipelining_fns[model_name](
whole_model, world_mesh, parallel_dims, job_config, device, model_config
)

# apply PT-D DP/TP parallelisms and activation checkpointing
model = models_parallelize_fns[model_name](
model, world_mesh, parallel_dims, job_config
model_parts = models_parallelize_fns[model_name](
model_parts, world_mesh, parallel_dims, job_config
)

init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
model.to_empty(device=init_device)
move_to_empty(model_parts, device=init_device)

if parallel_dims.pp_enabled:
pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn)
pp_schedule = build_pipeline_schedule(
job_config, parallel_dims, stages, loss_fn
)
else:
# If PP is enabled, we can't rely on init_weights, because some layers are missing.
# In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.

# allocate sharded model on GPU and initialize weights via DTensor
model.init_weights()
whole_model.init_weights()

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
logger.info(
@@ -248,8 +274,8 @@ def loss_fn(pred, labels):
)

# build optimizer after applying parallelisms to the model
optimizer = build_optimizer(model, job_config)
scheduler = get_lr_scheduler(optimizer, job_config)
optimizers = build_optimizers(model_parts, job_config)
lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)

metric_logger = build_metric_logger(
job_config, metrics_log_rank=get_metrics_rank(world_mesh, parallel_dims)
@@ -258,12 +284,13 @@ def loss_fn(pred, labels):
train_state = TrainState()

# train loop
model.train()
for model in model_parts:
model.train()

checkpoint = CheckpointManager(
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
model_parts=model_parts,
optimizers=optimizers.optimizers,
lr_schedulers=lr_schedulers.schedulers,
dataloader=data_loader,
states={"train_state": train_state},
job_config=job_config,
@@ -325,7 +352,7 @@ def loss_fn(pred, labels):

input_ids = input_ids.cuda()
labels = labels.cuda()
optimizer.zero_grad()
optimizers.zero_grad()

if parallel_dims.pp_enabled:
# pipeline parallel forward / backward inside step() call
@@ -354,14 +381,15 @@ def loss_fn(pred, labels):
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
for model in model_parts:
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)

# optimizer step
checkpoint.wait_for_staging()
optimizer.step()
scheduler.step()
optimizers.step()
lr_schedulers.step()

losses_since_last_log.append(loss)

