Add train loop support for looped PP schedules
- refactor some per-model logic into helper functions

ghstack-source-id: a2376627e2864deeb9e4fbf49cecd0990bc434ea
Pull Request resolved: pytorch#358
wconstab committed Jun 21, 2024
1 parent 0016d3c commit 9b23dbe
Showing 5 changed files with 184 additions and 72 deletions.
15 changes: 15 additions & 0 deletions test_runner.py
@@ -239,6 +239,21 @@ def build_test_list():
requires_seed_checkpoint=True,
ngpu=8,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule interleaved_1f1b",
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
],
],
"PP looped 1f1b test",
"pp_looped_1f1b",
requires_seed_checkpoint=True,
ngpu=4,
),
OverrideDefinitions(
[
[
33 changes: 33 additions & 0 deletions torchtitan/checkpoint.py
@@ -137,6 +137,38 @@ def __init__(

if not self.enable_checkpoint:
return
"""
Note: Pipeline Parallelism and Virtual Stages
1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
When saving, these collide and one of them is lost. Then when reloading, only one stage can
restore its optimizer states, others will error.
The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerWrapper.
2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
requiring us to reason about multiple 'optim' objects locally.
We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).
3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
optimizers do, so it's hard to write a generic 'flattener' utility.
TODO: This is currently unsolved and needs a fix.
"""
assert len(model_parts) == len(
optimizers
), "Must pass one optimizer per model part"
assert len(model_parts) == len(
lr_schedulers
), "Must pass one lr_scheduler per model part"

@@ -146,6 +178,7 @@ def __init__(
), "Must pass one lr_scheduler per model part"

self.states = states

self.states.update(
{
"model": ModelWrapper(model_parts),
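The note above describes the flattening in prose; below is a minimal sketch of what a flattened per-stage optimizer save/load could look like. It assumes torch.distributed.checkpoint.state_dict's get_optimizer_state_dict / set_optimizer_state_dict and StateDictOptions(flatten_optimizer_state_dict=True); the helper names are illustrative and not the actual ModelWrapper/OptimizerWrapper code.

from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

def flattened_optim_state_dict(model_parts, optimizers):
    # Flattened state dicts key optimizer entries by parameter FQN rather than
    # positional param_group indices, so stages on different PP ranks (and
    # multiple model chunks on one rank) do not collide when merged.
    sd = {}
    for model, optim in zip(model_parts, optimizers):
        sd.update(
            get_optimizer_state_dict(
                model,
                optim,
                options=StateDictOptions(flatten_optimizer_state_dict=True),
            )
        )
    return sd

def load_flattened_optim_state_dict(model_parts, optimizers, state_dict):
    for model, optim in zip(model_parts, optimizers):
        set_optimizer_state_dict(
            model,
            optim,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )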
5 changes: 3 additions & 2 deletions torchtitan/config_manager.py
@@ -269,14 +269,15 @@ def __init__(self):
self.parser.add_argument(
"--experimental.pipeline_parallel_schedule",
type=str,
choices=["1f1b", "gpipe"],
choices=["1f1b", "gpipe", "interleaved_1f1b"],
default="1f1b",
help="""
Specify the Pipeline Parallel schedule to use.
The schedule must be compatible with the split points and stages_per_rank.
Looped schedules are not yet supported in torchtitan.""",
Looped schedules (e.g. interleaved_1f1b) require specifying pipeline_parallel_degree = number of ranks,
and split_points = number of stages - 1""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_split_mode",
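As a concrete example of the constraint in the new help text (and mirroring the test added in test_runner.py above): an interleaved_1f1b run on 4 PP ranks with 7 split points yields 8 stages, i.e. 2 stages per rank:

--experimental.pipeline_parallel_degree 4
--experimental.pipeline_parallel_schedule interleaved_1f1b
--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7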
165 changes: 99 additions & 66 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -7,6 +7,7 @@
# this file applies the PTD parallelisms and various training techniques to the
# llama model, i.e. activation checkpointing, etc.

import copy
from collections import defaultdict
from typing import Dict, Tuple

@@ -31,6 +32,7 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank

# for selective AC
no_recompute_list = {
@@ -175,7 +177,12 @@ def _mixed_precision_dtype(


def pipeline_llama_manual(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
whole_model,
world_mesh,
parallel_dims,
job_config: JobConfig,
device,
model_config: Dict,
):
"""
This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
@@ -191,67 +198,85 @@ def pipeline_llama_manual(
microbatches = (
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
)
stage_idx = pp_rank

splits = job_config.experimental.pipeline_parallel_split_points
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < pp_size - 1 else None
if pp_rank > 0:
model.tok_embeddings = None

drop_layers = start_layer is not None
for name in list(model.layers.keys()):
# we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
if f"layers.{name}" == start_layer:
drop_layers = False
if f"layers.{name}" == stop_layer:
drop_layers = True
if drop_layers:
del model.layers[name]

if pp_rank < pp_size - 1:
model.norm = None
model.output = None

logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}")

# TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
# get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
# layers of the model that map to this stage, not the whole model.
mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
batch_size = job_config.training.batch_size
local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
layers_io_shape = (batch_size, local_seq_len, model_config.dim)
output_layer_shape = (
batch_size,
job_config.training.seq_len,
model_config.vocab_size,
)
if pp_rank == 0:
# first layer
(input,) = _llama_trace_input(job_config, model_config, device=device)
else:
# later layers (assume all start w/ a transformer layer)
input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)

if pp_rank == pp_size - 1:
# last layer
output = torch.rand(output_layer_shape, dtype=torch.float32, device=device)
else:
# earlier layers (assume all end in a transformer layer)
output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
model = copy.deepcopy(whole_model)
if not is_first:
model.tok_embeddings = None

drop_layers = start_layer is not None
for name in list(model.layers.keys()):
# we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
if f"layers.{name}" == start_layer:
drop_layers = False
if f"layers.{name}" == stop_layer:
drop_layers = True
if drop_layers:
del model.layers[name]

if not is_last:
model.norm = None
model.output = None

# TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
# get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
# layers of the model that map to this stage, not the whole model.
mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
batch_size = job_config.training.batch_size
local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
layers_io_shape = (batch_size, local_seq_len, model_config.dim)
output_layer_shape = (
batch_size,
job_config.training.seq_len,
model_config.vocab_size,
)
if is_first:
(input,) = _llama_trace_input(job_config, model_config, device=device)
else:
# later layers (assume all start w/ a transformer layer)
input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)

model.to_empty(device=device)
stage = PipelineStage(
model,
pp_rank,
pp_size,
device,
input_args=input.chunk(microbatches)[0],
output_args=output.chunk(microbatches)[0],
group=pp_mesh.get_group("pp"),
)
return ((stage,), (model,))
if is_last:
output = torch.rand(output_layer_shape, dtype=torch.float32, device=device)
else:
# earlier layers (assume all end in a transformer layer)
output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)

model.to_empty(device=device)
stage = PipelineStage(
model,
stage_idx,
num_stages,
device,
input_args=input.chunk(microbatches)[0],
output_args=output.chunk(microbatches)[0],
group=pp_mesh.get_group("pp"),
)
return stage, model

num_stages = len(splits) + 1
stage_idx = pp_rank

stages = []
models = []
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
stage, model_chunk = _build_stage(
stage_idx,
start_layer,
stop_layer,
is_first=stage_idx == 0,
is_last=stage_idx == num_stages - 1,
)
logger.info(
f"PP rank {pp_rank} is building stage_idx {stage_idx}"
f" with start_layer {start_layer}, stop_layer {stop_layer}: model chunk \n{model_chunk}"
)
stages.append(stage)
models.append(model_chunk)
return stages, models


def pipeline_llama_tracer(
@@ -272,6 +297,7 @@ def pipeline_llama_tracer(

pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
microbatches = (
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
)
Expand All @@ -281,18 +307,25 @@ def pipeline_llama_tracer(
layer_name: SplitPoint.BEGINNING
for layer_name in job_config.experimental.pipeline_parallel_split_points
}
num_stages = len(split_spec) + 1
pipe = pipeline(
model,
mb_args=(input.chunk(microbatches)[0],),
split_spec=split_spec,
)
model = pipe.get_stage_module(stage_idx)
stage = pipe.build_stage(
stage_idx,
device=device,
group=pp_mesh.get_group(),
)
return ((stage,), (model,))

stages = []
models = []
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
models.append(pipe.get_stage_module(stage_idx))
stages.append(
pipe.build_stage(
stage_idx,
device=device,
group=pp_mesh.get_group(),
)
)
return (stages, models)


def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
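The start-inclusive / stop-exclusive layer pruning used in _build_stage above can be illustrated with a self-contained toy example (an nn.ModuleDict standing in for model.layers; the keys and split boundaries below are made up for illustration):

import torch.nn as nn

# Toy stand-in for the model's layers, keyed "0".."7".
layers = nn.ModuleDict({str(i): nn.Identity() for i in range(8)})
start_layer, stop_layer = "layers.2", "layers.5"

drop_layers = start_layer is not None
for name in list(layers.keys()):
    # keep the contiguous region between start (inclusive) and stop (exclusive)
    if f"layers.{name}" == start_layer:
        drop_layers = False
    if f"layers.{name}" == stop_layer:
        drop_layers = True
    if drop_layers:
        del layers[name]

print(list(layers.keys()))  # ['2', '3', '4'] -> this stage keeps layers 2..4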
38 changes: 34 additions & 4 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -3,17 +3,26 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch.distributed.pipelining import Schedule1F1B, ScheduleGPipe
from typing import Tuple

from torch.distributed.pipelining import (
Schedule1F1B,
ScheduleGPipe,
ScheduleInterleaved1F1B,
)
from torchtitan.logging_utils import logger


def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
assert len(stages) == 1, "Looped schedules not yet supported"
stage = stages[0]

looped_schedule = False
if job_config.experimental.pipeline_parallel_schedule == "1f1b":
schedule_class = Schedule1F1B
elif job_config.experimental.pipeline_parallel_schedule == "gpipe":
schedule_class = ScheduleGPipe
elif job_config.experimental.pipeline_parallel_schedule == "interleaved_1f1b":
schedule_class = ScheduleInterleaved1F1B
looped_schedule = True
else:
raise NotImplementedError(
f"{job_config.experimental.pipeline_parallel_schedule} is not implemented"
@@ -26,7 +35,28 @@ def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
n_microbatches = job_config.experimental.pipeline_parallel_degree

return schedule_class(
stage,
stages if looped_schedule else stages[0],
n_microbatches=n_microbatches,
loss_fn=loss_fn,
)
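How the returned schedule is driven by the train loop is not shown in this diff; a hedged sketch of the intended usage follows (pp_mesh, input_ids, labels, and loss_fn are placeholder names). With a looped schedule, every stage owned by this rank lives inside the single schedule object, so the step call itself is unchanged from the single-stage case.

pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn)

# With the "loop" style, rank 0 owns stage 0 and the last rank owns the final stage.
is_first_rank = pp_mesh.get_local_rank() == 0
is_last_rank = pp_mesh.get_local_rank() == pp_mesh.size() - 1

if is_first_rank:
    pp_schedule.step(input_ids)                      # first stage consumes the batch
elif is_last_rank:
    losses = []
    pp_schedule.step(target=labels, losses=losses)   # last stage computes the loss
else:
    pp_schedule.step()                               # interior ranks only relay activations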


# TODO(whc) should this be a utility inside torch.pipelining?
def stage_ids_this_rank(
pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
) -> Tuple[int]:
"""Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
assert (
num_stages % pp_size == 0
), f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
stages_per_rank = num_stages // pp_size
if style == "loop":
return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
elif style == "v":
assert (
stages_per_rank == 2
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
stage_v_pairs = list(
zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
)
return stage_v_pairs[pp_rank]
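
A quick usage sketch of the helper above, showing which stage ids land on which rank for 8 stages over 4 PP ranks (the values follow directly from the code above):

# "loop" style: rank r owns stages (r, r + pp_size, ...)
#   rank 0 -> (0, 4), rank 1 -> (1, 5), rank 2 -> (2, 6), rank 3 -> (3, 7)
# "v" style (2 stages per rank): rank 0 -> (0, 7), rank 1 -> (1, 6), rank 2 -> (2, 5), rank 3 -> (3, 4)
for rank in range(4):
    print(rank, stage_ids_this_rank(rank, pp_size=4, num_stages=8, style="loop"))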
