Skip to content

Commit

Permalink
Fix 1D PP tracer test
Browse files Browse the repository at this point in the history
Forgot to enable the tracer for the PP tracer test in the previous PR.

ghstack-source-id: 1cb137911f88daa47b57757346dad55aa736429e
Pull Request resolved: #362
  • Loading branch information
kwen2501 authored and wconstab committed Jun 11, 2024
1 parent 3bb7bf3 commit e858ab4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
2 changes: 2 additions & 0 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,14 @@ def build_test_list():
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_split_mode tracer",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
],
],
"PP tracer frontend test",
"pp_tracer",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
Expand Down
18 changes: 9 additions & 9 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,23 +261,23 @@ def pipeline_llama_tracer(

pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
microbatches = (
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
)
(input,) = _llama_trace_input(job_config, model_config, device=device)
stage_idx = pp_rank
layers_per_rank = len(model.layers) // parallel_dims.pp
split_spec = {
f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
for i in range(1, parallel_dims.pp)
layer_name: SplitPoint.BEGINNING
for layer_name in job_config.experimental.pipeline_parallel_split_points
}

pipe = pipeline(
model,
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
example_args=_llama_trace_input(job_config, model_config),
mb_args=(input.chunk(microbatches)[0],),
split_spec=split_spec,
)
model = pipe.get_stage_module(stage_idx)
stage = PipelineStage(
pipe,
stage_index=stage_idx,
stage = pipe.build_stage(
stage_idx,
device=device,
group=pp_mesh.get_group(),
)
Expand Down

0 comments on commit e858ab4

Please sign in to comment.