diff --git a/test_runner.py b/test_runner.py
index ab120ea57..ad5b41317 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -123,12 +123,14 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_split_mode tracer",
                     "--model.norm_type rmsnorm",  # fused_rmsnorm not yet compatible with tracer
                 ],
             ],
             "PP tracer frontend test",
             "pp_tracer",
             requires_seed_checkpoint=True,
+            ngpu=2,
         ),
         OverrideDefinitions(
             [
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 876f64b63..2d8e21507 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -261,23 +261,23 @@ def pipeline_llama_tracer(
 
     pp_mesh = world_mesh["pp"]
     pp_rank = pp_mesh.get_local_rank()
+    microbatches = (
+        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
+    )
+    (input,) = _llama_trace_input(job_config, model_config, device=device)
     stage_idx = pp_rank
-    layers_per_rank = len(model.layers) // parallel_dims.pp
     split_spec = {
-        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
-        for i in range(1, parallel_dims.pp)
+        layer_name: SplitPoint.BEGINNING
+        for layer_name in job_config.experimental.pipeline_parallel_split_points
     }
-
     pipe = pipeline(
         model,
-        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
-        example_args=_llama_trace_input(job_config, model_config),
+        mb_args=(input.chunk(microbatches)[0],),
         split_spec=split_spec,
     )
     model = pipe.get_stage_module(stage_idx)
-    stage = PipelineStage(
-        pipe,
-        stage_index=stage_idx,
+    stage = pipe.build_stage(
+        stage_idx,
         device=device,
         group=pp_mesh.get_group(),
     )