Renamed parallel styles for transformer block weights #448
Changes from 6 commits
@@ -316,10 +316,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """

     tp_mesh = world_mesh["tp"]
+    # Parallel styles for transformer block linear weights may be different for
+    # float8 linears
     (
-        row_parallel_strategy,
-        col_parallel_strategy,
-        prepare_module_input,
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_module_weight_input,
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled

Review thread on the renamed prepare_module_weight_input:

- "What's the meaning of prepare_module_weight_input?"
- "Module input to the weight 😅"
- "But does the general idea make sense, that we want to somehow convey that these are meant for linears that may use fp8 if enabled?"
- "Yeah, that makes sense. I feel the other renamings look good; for this one in particular, …"
- "Actually, I feel like …"
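For readers outside the repo, the renamed tuple comes from get_tp_parallel_strategy. Below is a minimal sketch of what that selection could look like, assuming a float8 toggle on the job config and float8-aware parallel styles provided by a float8 library; the flag name and the import path are assumptions for illustration, not code from this PR.

```python
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def get_tp_parallel_strategy(job_config):
    """Pick parallel styles for transformer block linear weights."""
    # Hypothetical flag name; the real config key may differ.
    if getattr(job_config.training, "enable_float8_linear", False):
        # Float8 linears want float8-aware styles so the TP collectives on the
        # linear inputs can run in float8. Assumed import path.
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
```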
@@ -336,7 +338,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             output_layouts=Shard(1),
         ),
         "norm": SequenceParallel(),
-        "output": col_parallel_strategy(
+        "output": colwise_parallel_weight(
             input_layouts=Shard(1),
             output_layouts=Shard(-1) if loss_parallel else Replicate(),
             use_local_output=not loss_parallel,
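For context on the "output" entry above: with loss parallel enabled, the output projection keeps its logits sharded on the class dimension (Shard(-1), use_local_output=False) and the loss is computed directly on that sharded tensor. A hedged sketch of the downstream loss computation follows; the function name and flattening are illustrative, not taken from this diff.

```python
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel


def sharded_cross_entropy(pred, labels):
    # `pred` is a DTensor sharded on the vocab dimension when loss parallel is
    # on; loss_parallel() lets cross_entropy consume it without first
    # all-gathering the full logits.
    with loss_parallel():
        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
```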
@@ -351,22 +353,22 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention_norm": SequenceParallel(),
-            "attention": prepare_module_input(
+            "attention": prepare_module_weight_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention.wq": colwise_parallel_weight(),
+            "attention.wk": colwise_parallel_weight(),
+            "attention.wv": colwise_parallel_weight(),
+            "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
+            "feed_forward": prepare_module_weight_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
-            "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
-            "feed_forward.w3": col_parallel_strategy(),
+            "feed_forward.w1": colwise_parallel_weight(),
+            "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel_weight(),
         }

         # Adjust attention module to use the local number of heads
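To show where these per-block plans end up, here is a hedged sketch of how such a plan is typically applied with parallelize_module; build_layer_plan is a hypothetical helper standing in for the dict literal in the diff above.

```python
from torch.distributed.tensor.parallel import parallelize_module


def apply_block_plans(model, tp_mesh, build_layer_plan):
    # build_layer_plan(transformer_block) would return the dict shown above,
    # built from the renamed colwise/rowwise weight styles.
    for layer_id, transformer_block in model.layers.items():
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=build_layer_plan(transformer_block),
        )
```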
Review comment: These are only used for transformer block weights, so we can rename them to differentiate from the ColwiseParallel etc. used for the other parameters.
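To illustrate the differentiation this comment describes, a hedged sketch: parameters outside the transformer blocks keep the stock styles, while the block linear weights use the renamed, possibly float8-aware styles. The embedding layouts below are assumptions for illustration, not copied from this PR.

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import RowwiseParallel

# Non-block parameters: plain styles, regardless of float8.
embedding_plan = {
    "tok_embeddings": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
}
# Block weights: the styles returned by get_tp_parallel_strategy, e.g.
# "attention.wq": colwise_parallel_weight(), which may resolve to a
# float8-aware ColwiseParallel variant when float8 linears are enabled.
```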