Renamed parallel styles for transformer block weights #448

Merged: 7 commits, Jul 11, 2024
28 changes: 15 additions & 13 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -316,10 +316,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""

tp_mesh = world_mesh["tp"]
# Parallel styles for transformer block linear weights may be different for
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are only used for transformer block weights, so we can rename this to differentiate from the ColwiseParallel etc. for the other parameters.

# float8 linears
(
row_parallel_strategy,
col_parallel_strategy,
prepare_module_input,
rowwise_parallel_weight,
colwise_parallel_weight,
prepare_module_weight_input,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the meaning of module_weight_input? I guess it should still be module_input as it's activation resharding?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

module input to the weight 😅

Copy link
Contributor Author

@awgu awgu Jul 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prepare_linear_module_input? not sure because there are also linears that do not use fp8

but does the general idea make sense that we want to somehow convey that these are meant for linears that may use fp8 if enabled?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that make sense, I feel other renamings looks good, for this in particular, prepare_linear_module_input sounds better!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, I feel like prepare_linear_module_input is a bit confusing still 😆
let me just leave it as prepare_module_input for now

) = get_tp_parallel_strategy(job_config)
loss_parallel = parallel_dims.loss_parallel_enabled
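For context on the fp8 discussion above: the three names unpacked in this hunk come from a selector that returns either the plain tensor-parallel styles or float8-aware variants, depending on the job config. Below is a minimal sketch of that idea only; the float8_enabled flag and the Float8* class names are assumptions, and the actual torchtitan helper and float8 APIs may differ.

    # Sketch only: illustrates the idea behind a get_tp_parallel_strategy-style
    # selector. The float8_enabled flag and the Float8* names are assumptions,
    # not the actual torchtitan / float8 APIs.
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        PrepareModuleInput,
        RowwiseParallel,
    )


    def select_tp_styles(float8_enabled: bool):
        """Return (rowwise, colwise, prepare_input) styles for transformer block linears."""
        if float8_enabled:
            # Hypothetical float8-aware styles that handle the float8 casts around
            # the sharded matmuls; the module path and names are illustrative.
            from float8_experimental.float8_tensor_parallel import (  # assumption
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput,
            )

            return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
        return RowwiseParallel, ColwiseParallel, PrepareModuleInput

The unpacked callables are then used when building the TP plans in the hunks below, which is why the renamed variables read as "weight" styles for the block linears.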

@@ -336,7 +338,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
                 output_layouts=Shard(1),
             ),
             "norm": SequenceParallel(),
-            "output": col_parallel_strategy(
+            "output": colwise_parallel_weight(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
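A note on the loss_parallel branch in this hunk: when loss parallelism is enabled, the output projection keeps its result sharded on the vocabulary dimension (Shard(-1)) and returns a DTensor (use_local_output=False), so the loss can be computed without gathering the full vocabulary. The following is a minimal sketch of that downstream pattern using torch's loss_parallel context manager; it mirrors the usual usage under an already initialized TP mesh and is not a quote of torchtitan's loss code.

    # Sketch only: how vocab-sharded logits from the colwise-parallel output
    # layer can feed a parallel cross-entropy. Assumes the TP mesh is already
    # initialized (e.g. via torchrun) and logits is a DTensor sharded on its
    # last (vocabulary) dimension.
    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel


    def compute_loss(logits, labels, loss_parallel_enabled: bool):
        if loss_parallel_enabled:
            # Cross-entropy over vocab-sharded logits, without all-gathering
            # the full vocabulary dimension.
            with loss_parallel():
                return F.cross_entropy(logits.flatten(0, 1), labels.flatten())
        # Without loss parallel, the output layer produced replicated local
        # logits, so plain cross-entropy applies.
        return F.cross_entropy(logits.flatten(0, 1), labels.flatten())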
@@ -351,22 +353,22 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention_norm": SequenceParallel(),
-            "attention": prepare_module_input(
+            "attention": prepare_module_weight_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention.wq": colwise_parallel_weight(),
+            "attention.wk": colwise_parallel_weight(),
+            "attention.wv": colwise_parallel_weight(),
+            "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
+            "feed_forward": prepare_module_weight_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
-            "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
-            "feed_forward.w3": col_parallel_strategy(),
+            "feed_forward.w1": colwise_parallel_weight(),
+            "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel_weight(),
         }

         # Adjust attention module to use the local number of heads
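For readers unfamiliar with how a plan dict like layer_plan is consumed: it is passed to parallelize_module for each transformer block (the tail of the diff is collapsed here). Below is a self-contained sketch of the same colwise/rowwise pattern on a toy feed-forward block, assuming a 2-GPU run launched with torchrun; the toy module and file names are illustrative, not torchtitan's.

    # Sketch only: applies a colwise/rowwise plan to a toy FFN block, mirroring
    # the per-transformer-block loop above. Module and file names are illustrative.
    import os

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class ToyFFN(nn.Module):
        def __init__(self, dim=256, hidden=1024):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden, bias=False)
            self.w2 = nn.Linear(hidden, dim, bias=False)

        def forward(self, x):
            return self.w2(torch.relu(self.w1(x)))


    if __name__ == "__main__":
        # Launch with: torchrun --nproc_per_node=2 toy_tp.py
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
        tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
        block = ToyFFN().cuda()
        # Same shape as layer_plan above: colwise for the "in" projection,
        # rowwise for the "out" projection.
        parallelize_module(
            block,
            tp_mesh,
            {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
        )
        out = block(torch.randn(8, 256, device="cuda"))
        print(out.shape)  # (8, 256), replicated local output on each rank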