Used per-parameter FSDP #165
Merged
@@ -8,19 +8,13 @@
from typing import Tuple

import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,

@@ -33,7 +27,6 @@
from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import logger
from torchtrain.meta_init import meta_to_real_init_fn


# for selective AC

@@ -75,7 +68,6 @@ def selective_checkpointing_context_fn():
            preserve_rng_state=False,
        )
    elif config.mode == "full":
        # full AC
        return ptd_checkpoint_wrapper(
            module,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,

@@ -136,28 +128,23 @@ def get_tp_parallel_strategy(

def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
    """
    Apply parallelisms to the model, including PTD parallelisms, and AC.
    Apply parallelisms and activation checkpointing to the model.

    NOTE: the model passed in preferrablably shoule be a meta device model,
    otherwise the model needs to be small enough on GPU or can fit into CPU.
    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
    # apply PTD parallelisms
    if parallel_dims.pp_enabled:
        raise NotImplementedError("PP not implemented yet.")

    # First we apply Tensor Parallelism if it's enabled
    if parallel_dims.tp_enabled:
        tp_mesh = world_mesh["tp"]
        tp_degree = job_config.training.tensor_parallel_degree

        row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
            job_config
        )

        # First:
        # 1. parallelize the first embedding and the last linear proj layer
        # 2. parallelize the root norm layer by sequence dim
        # 3. shard the first layer of transformer block
        # 1. Parallelize the first embedding and the last linear proj layer
        # 2. Parallelize the root norm layer over the sequence dim
        # 3. Shard the first transformer block's inputs
        model = parallelize_module(
            model,
            tp_mesh,

@@ -167,9 +154,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                ),
                "output": col_parallel_strategy(
                    input_layouts=Shard(0),
                    output_layouts=Shard(-1)
                    if parallel_dims.loss_parallel_enabled
                    else Replicate(),
                    output_layouts=(
                        Shard(-1)
                        if parallel_dims.loss_parallel_enabled
                        else Replicate()
                    ),
                    use_local_output=not parallel_dims.loss_parallel_enabled,
                ),
                "norm": SequenceParallel(sequence_dim=0),

@@ -181,7 +170,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
            },
        )

        # apply tensor + sequence parallelism to every transformer block
        # Apply tensor + sequence parallelism to every transformer block
        for layer_id, transformer_block in enumerate(model.layers):
            layer_plan = {
                "attention": PrepareModuleInput(

@@ -203,62 +192,47 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                "ffn_norm": SequenceParallel(sequence_dim=0),
            }

            # adjust num_heads in attention layer to local heads
            # Adjust attention module to use the local number of heads
            attn_layer = transformer_block.attention
            attn_layer.n_heads = attn_layer.n_heads // tp_degree
            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_degree
            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

            parallelize_module(
                module=transformer_block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

        logger.info("Applied Sequence Parallelism to the model")
        logger.info("Applied Tensor Parallelism to the model")

    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]

        fsdp_config = {
            "mixed_precision": MixedPrecision(
                param_dtype=torch.bfloat16,
                # TODO: see whether we should expose a option to user
                reduce_dtype=torch.float32,
            ),
            "sharding_strategy": ShardingStrategy.FULL_SHARD,
            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
            # When torch.compile is active, it requires us to set use_orig_params=True
            "use_orig_params": True,
            "device_mesh": dp_mesh,
            "param_init_fn": meta_to_real_init_fn,
        }

        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
        # TODO: Expose `reduce_dtype` as a config option.
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
        )
        ac_mode = job_config.activation_checkpoint.mode
        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            for layer_id, transformer_block in enumerate(model.layers):
                # apply AC to the transformer block
                if ac_mode in ("full", "selective"):
                    # wrap the transformer block with checkpoint wrapper, using config settings
                    transformer_block = checkpoint_wrapper(
                        transformer_block, job_config.activation_checkpoint
                    )

                # Wraps each layer with FSDP
                model.layers[layer_id] = wrap(transformer_block)

            # wrap the rest layers with FSDP
            model = wrap(model)

        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
        for layer_id, transformer_block in enumerate(model.layers):
            if job_config.activation_checkpoint.mode in ("full", "selective"):
                transformer_block = checkpoint_wrapper(
                    transformer_block, job_config.activation_checkpoint
                )
            # As an optimization, do not reshard after forward for the last
            # transformer block since FSDP would prefetch it immediately
            reshard_after_forward = layer_id < len(model.layers) - 1
            fully_shard(
                transformer_block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )
            model.layers[layer_id] = transformer_block
        model = fully_shard(model, **fsdp_config)
        if ac_mode in ("full", "selective"):
            logger.info(f"Applied {ac_mode} activation checkpointing to the model")
        logger.info("Applied FSDP to the model")
    else:
        meta_to_real_init_fn(model)
        model.cuda()

    # we have now moved from meta to device,
    # reset parameters for proper initialization
    model.reset_parameters()
    logger.info("Model fully initialized via reset_parameters")

    return model

Inline review comment on the reshard_after_forward optimization: I am open to not including this 'trick' since it might be confusing. The idea is that we basically can skip resharding the last transformer block after forward, since FSDP would prefetch (all-gather) it again immediately at the start of backward anyway.
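For context on the new API used above, here is a minimal, self-contained sketch of the per-parameter `fully_shard` pattern this PR adopts. It is only an illustration: the toy `Block` model, the `apply_fsdp` helper, and the mesh setup are hypothetical, and `torch.distributed._composable.fsdp` is a private, still-evolving namespace whose import path may change.

```python
# Sketch only: run under torchrun, e.g. `torchrun --nproc_per_node=2 sketch.py`.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh


class Block(nn.Module):
    """Toy stand-in for a transformer block."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ffn(x)


def apply_fsdp(model: nn.Sequential) -> nn.Module:
    # One-dimensional data-parallel mesh over all ranks.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    fsdp_config = {"mesh": mesh, "mp_policy": mp_policy}

    blocks = list(model)
    for i, block in enumerate(blocks):
        # Same trick as in the diff: keep the last block unsharded after forward,
        # since backward would all-gather it again right away.
        fully_shard(block, **fsdp_config, reshard_after_forward=i < len(blocks) - 1)
    # A final call on the root groups any remaining (non-block) parameters.
    return fully_shard(model, **fsdp_config)


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = apply_fsdp(nn.Sequential(*[Block() for _ in range(4)]).cuda())
    model(torch.randn(8, 256, device="cuda")).sum().backward()
    dist.destroy_process_group()
```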
Actually, I have some confusion about the `reset_parameters` guideline. `reset_parameters` is an optional method on `nn.Module`, and calling the parent module's `reset_parameters()` does not "recursively" call into the submodules' `reset_parameters`. This means that if the guideline is that each module should ONLY be responsible for its own parameters, the user has to loop over all submodules in the module tree and call them individually?

And if that's the case, if the user decides not to recursively loop over submodules, one can simply define `reset_parameters` to re-init its own parameters plus its leaf modules' parameters, just like we did previously (i.e. in `Attention` we can also re-init the Q/K/V linears). Then the user can simply call `reset_parameters()` on their root module and not worry about the attention layer's `wq`/`wk`/`wv` being overridden by the built-in `nn.Linear.reset_parameters` call, since that would never be called. This might be something users are already doing, since they may want to control how submodule init works themselves.

Not sure if you get my question, haha. Am I missing something?
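To make the two options in this question concrete, here is a small hypothetical sketch (the `Attention` module and init choices below are illustrative, not the torchtrain model). In option (a), each module resets only the parameters it owns directly and the caller visits every submodule; in option (b), the parent also re-initializes its child linears, so a single call on the parent suffices, at the cost of overlapping with `nn.Linear.reset_parameters()` if something else also calls it.

```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Illustrative module with child Linears plus one directly-owned parameter."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.gate = nn.Parameter(torch.empty(dim))  # owned directly by Attention

    # Option (a): reset only what this module owns directly; the child
    # nn.Linear modules keep their own built-in reset_parameters().
    def reset_parameters(self):
        nn.init.ones_(self.gate)


model = Attention()

# Option (a): the caller loops over the module tree so that every module
# (including the nn.Linear children) is visited exactly once.
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()


# Option (b): a parent-level init that also re-inits the child linears, so a
# single call is enough -- but it overlaps with nn.Linear's own init if a
# wrapper (such as FSDP1) also loops over submodules and calls their resets.
def reset_attention_and_children(attn: Attention) -> None:
    nn.init.ones_(attn.gate)
    for lin in (attn.wq, attn.wk, attn.wv):
        nn.init.trunc_normal_(lin.weight, std=0.02)
```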
This is my understanding:

- I agree with the approach you are describing. It happens to be that the weight init used for the Llama model in torchtrain does not depend on the original tensor shape (namely, the weight init is elementwise). However, this may not be the case for other models (e.g. those that compute fan-in/fan-out), in which case this approach would silently sample from the incorrect distribution.
- FSDP1 calls `reset_parameters()` before sharding. For `FullyShardedDataParallel(module)`, FSDP1 calls `submodule.reset_parameters()` for each managed `submodule` in `module.modules()` (managed meaning excluding any nested `FullyShardedDataParallel` modules or their children). This is the only way to ensure that each parameter is initialized exactly once.
- If the `Attention` module re-initialized its Q/K/V linear modules, then FSDP1 would initialize the Q/K/V linears twice (once from `Linear.reset_parameters()` and once from `Attention.reset_parameters()`). This can still give a valid probability distribution, but it could give different values for a fixed seed compared to if the `Linear.reset_parameters()` were skipped (e.g. if not using FSDP and just calling `model.reset_parameters()` on the root `model`). This is not a major problem, since it does not mean incorrect randomness, but it is still worth mentioning.
- If we call `model.reset_parameters()` after sharding with FSDP1, then we have 1D flattened sharded tensors, which no longer preserve the original tensor shape. Therefore, calling `model.reset_parameters()` at this point will give incorrect randomness in the cases that depend on the shape.

In summary, following the core guideline is the only way to guarantee that each parameter is initialized once and before sharding. The constraint to initialize once is not required for correct randomness but may help reproducibility.
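A small illustration of the shape-dependence point (hypothetical numbers, not the torchtrain init): fan-in/fan-out style inits compute their sampling bound from the tensor's shape, so initializing a per-rank shard samples from a different distribution than initializing the full weight before sharding, and on FSDP1's 1D flattened views the fan values cannot be computed at all.

```python
import torch
import torch.nn as nn

# Full (unsharded) weight: Xavier bound = sqrt(6 / (fan_in + fan_out)) = sqrt(6 / 2048)
full_weight = torch.empty(1024, 1024)
nn.init.xavier_uniform_(full_weight)

# A per-rank shard of the same weight (say 128 of the 1024 rows): the init now
# sees fan_out = 128, so the bound becomes sqrt(6 / 1152) -- a different distribution.
shard = torch.empty(128, 1024)
nn.init.xavier_uniform_(shard)

print(full_weight.abs().max().item(), shard.abs().max().item())

# For a 1D flattened FSDP1 parameter, fan_in/fan_out cannot be computed at all:
# nn.init.xavier_uniform_(torch.empty(128 * 1024))  # raises ValueError
```

This is the sense in which calling `reset_parameters()` after FSDP1 has flattened and sharded the parameters can silently change the init distribution.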
I see, OK, this makes sense: it is critical to initialize each parameter only once for reproducibility when starting from a fixed seed.
At the same time, though, the DTensor RNG will be different from the local RNG, so I am not sure this reproducibility argument makes sense. We would not be able to ensure the same results for FSDP2 compared to a single-GPU non-DTensor setup.