
Used per-parameter FSDP #165

Merged (1 commit, Mar 28, 2024)
48 changes: 0 additions & 48 deletions torchtrain/meta_init.py

This file was deleted.

97 changes: 41 additions & 56 deletions torchtrain/models/llama/model.py
@@ -45,9 +45,7 @@ def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim))

# re-enable if not using meta-init
# self.reset_parameters()
self.reset_parameters()

def _norm(self, x: torch.Tensor):
"""
@@ -207,19 +205,10 @@ def __init__(self, model_args: ModelArgs):
model_args.n_heads * self.head_dim, model_args.dim, bias=False
)

def reset_parameters(self, init_std):
Collaborator:

Actually I have some confusion about the reset_parameters guideline: reset_parameters is an optional method on nn.Module, and calling the parent module's reset_parameters() does not "recursively" call into the submodules' reset_parameters().

This means that if the guideline is that each module should ONLY be responsible for its own parameters, the user has to loop over all the submodules in the module tree and call them individually?

And if that's the case, if the user decides not to recursively loop over submodules, one can simply define reset_parameters to re-init its own parameters plus its leaf modules' parameters, just like we did previously (i.e. in Attention we can also re-init the q/k/v linears). Then the user can simply call reset_parameters() on their defined root module and not worry about the attention layer's wq/wk/wv being overridden by the built-in nn.Linear's reset_parameters call, since it would never be called. This might be something users are already doing, as they might want to control how the submodule init works themselves?

Not sure if you get my question haha, am I missing something there?

Contributor Author:

> This means that if the guideline is that each module should ONLY be responsible for its own parameters, the user has to loop over all the submodules in the module tree and call them individually?

This is my understanding.

> And if that's the case, if the user decides not to recursively loop over submodules, one can simply define reset_parameters to re-init its own parameters plus its leaf modules' parameters, just like we did previously (i.e. in Attention we can also re-init the q/k/v linears). Then the user can simply call reset_parameters() on their defined root module and not worry about the attention layer's wq/wk/wv being overridden by the built-in nn.Linear's reset_parameters call, since it would never be called. This might be something users are already doing, as they might want to control how the submodule init works themselves?

I agree with the approach you are mentioning:

  • if we ignore FSDP
  • if we are using FSDP1 and every weight init does not depend on the original tensor shape

It happens to be that the weight init used for the Llama model in torchtrain does not depend on the original tensor shape (namely, the weight init is elementwise). However, this may not be the case for other models (e.g. those that compute fan-in/fan-out), in which case this approach would silently sample from the incorrect distribution.
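
A minimal sketch of that distinction (the shapes and the fan-in rule here are illustrative, not code from this PR):

```python
import torch
import torch.nn as nn

# Illustrative only: an init that depends on the tensor shape (fan-in).
weight = torch.empty(256, 1024)            # (out_features, in_features)
fan_in = weight.shape[1]                   # 1024
nn.init.normal_(weight, mean=0.0, std=fan_in ** -0.5)

# If the parameter had already been flattened into a 1D shard (as FSDP1
# does after wrapping), the same shape-dependent logic would see a
# different "fan-in" and silently sample from the wrong distribution:
flat_shard = torch.empty(256 * 1024 // 8)  # one rank's flat 1D shard
wrong_fan_in = flat_shard.shape[-1]        # 32768, not 1024

# An elementwise init with a fixed std (like the trunc_normal_ used for
# Llama in this PR) does not read the shape, so it is unaffected:
nn.init.trunc_normal_(weight, mean=0.0, std=0.02)
```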

FSDP1 calls reset_parameters() before sharding.

  • The current approach is aligned with the core guideline, so for FullyShardedDataParallel(module), FSDP1 calls submodule.reset_parameters() for each managed submodule in module.modules() (managed is defined by excluding any nested FullyShardedDataParallel modules or their children). This is the only way to ensure that each parameter is initialized exactly once.
  • If a parent Attention module re-initialized its Q/K/V linear modules, then FSDP1 would initialize the Q/K/V linears twice (once from Linear.reset_parameters() and once from Attention.reset_parameters()). This can still give a valid probability distribution, but it could give different values for a fixed seed compared to if the Linear.reset_parameters() were skipped (e.g. if not using FSDP and just calling model.reset_parameters() on the root model). This is not a major problem since it does not mean incorrect randomness but is still worth mentioning.
  • If we further call model.reset_parameters() after sharding with FSDP1, then we have 1D flattened sharded tensors, which no longer preserve the original tensor shape. Therefore, calling model.reset_parameters() at this point will give incorrect randomness in cases depending on the shape.

In summary, following the core guideline is the only way to guarantee that each parameter is initialized once and before sharding. The constraint to initialize once is not required for correct randomness but may help reproducibility.
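
To make the guideline concrete, a minimal sketch (MyAttention is a made-up module, not the one in this diff): each reset_parameters() touches only directly owned tensors, and the caller, or FSDP1 before sharding, walks the module tree so that every parameter is initialized exactly once.

```python
import torch
import torch.nn as nn

class MyAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)    # child module
        self.scale = nn.Parameter(torch.empty(dim))  # directly owned

    def reset_parameters(self):
        # Only re-init directly owned tensors; do NOT touch self.wq here,
        # since nn.Linear.reset_parameters() covers it.
        nn.init.ones_(self.scale)

model = MyAttention(16)
# The caller (or FSDP1 before sharding) walks the module tree, so each
# module's directly owned parameters get initialized exactly once.
for module in model.modules():
    if callable(getattr(module, "reset_parameters", None)):
        module.reset_parameters()
```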

Collaborator:

I see, ok this makes sense, so it is critical to initialize each parameter only once for reproducibility when starting from a fixed seed.

Contributor Author:

At the same time though, the DTensor RNG will be different from the local RNG, so I am not sure this reproducibility argument holds. We would not be able to ensure the same results for FSDP2 compared to a single-GPU non-DTensor setup.

for item in (self.wq, self.wk, self.wv):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=0.02,
)

nn.init.trunc_normal_(
self.wo.weight,
mean=0.0,
std=init_std,
)
def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

def forward(
self,
@@ -309,19 +298,10 @@ def __init__(
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))

def reset_parameters(self, init_std):
nn.init.trunc_normal_(
self.w1.weight,
mean=0.0,
std=0.02,
)

for item in (self.w2, self.w3):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=init_std,
)
def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
for linear in (self.w2, self.w3):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class RotaryEmbedding(nn.Module):
@@ -333,13 +313,15 @@ def __init__(self, model_args: ModelArgs):
super().__init__()
self.model_args = model_args
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
self.register_buffer(
"freqs_cis", self._precompute_freqs_cis(), persistent=False
)

self.freqs_cis = precompute_freqs_cis(
# Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
# of models is 4096.
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
# or fine-tuning.
def _precompute_freqs_cis(self):
return precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
)

@@ -355,10 +337,14 @@ def forward(self, tokens: torch.Tensor):
"""
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[0:seqlen]
return h, freqs_cis

def init_weights(self):
with torch.device(self.freqs_cis.device):
self.freqs_cis = self._precompute_freqs_cis()
nn.init.normal_(self.tok_embeddings.weight)


class TransformerBlock(nn.Module):
"""
@@ -421,13 +407,11 @@ def forward(
out = h + self.feed_forward(self.ffn_norm(h))
return out

def reset_parameters(self):
"""reset params and norms for entire block"""
self.attention_norm.reset_parameters()
self.ffn_norm.reset_parameters()

self.attention.reset_parameters(self.weight_init_std)
self.feed_forward.reset_parameters(self.weight_init_std)
def init_weights(self):
for norm in (self.attention_norm, self.ffn_norm):
norm.reset_parameters()
self.attention.init_weights(self.weight_init_std)
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
@@ -457,28 +441,29 @@ def __init__(self, model_args: ModelArgs):
self.model_dim = model_args.dim

self.embeddings = RotaryEmbedding(model_args)

self.layers = torch.nn.ModuleList()
for layer_id in range(model_args.n_layers):
self.layers.append(TransformerBlock(layer_id, model_args))

self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
self.init_weights()

# init model weights

# we are doing meta_init, which will call reset_parameters() after
# the model is moved to actual device.
# If you modify and are not using meta_init, you will need to call
# reset_parameters() manually as below:

# self.reset_parameters()

def reset_parameters(
self,
):
def init_weights(self):
"""
[Note: On ``init_weights`` vs. ``reset_parameters``]
Modules may define ``reset_parameters`` to initialize parameter values.
``reset_parameters`` is meant to only initialize directly owned
parameters/buffers, not those of their child modules, and it can be
used to give the initial values for these tensors.
Separately, users may want custom initialization for their modules,
different from that in ``reset_parameters``. For this, we define
``init_weights``. We only call it in the constructor of this
``Transformer`` root module to avoid reinitializing tensors.
"""
self.embeddings.init_weights()
for layer in self.layers:
layer.reset_parameters()
layer.init_weights()
self.norm.reset_parameters()
final_out_std = self.model_dim**-0.5
cutoff_factor = 3
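
Stepping back from the diff: the docstring above distinguishes init_weights from reset_parameters. Below is a hedged sketch of how this pattern is typically combined with meta-device construction and per-parameter FSDP (fully_shard); this is an assumption following the FSDP2 meta-init recipe, not this repo's training script, and ToyModel is only a stand-in for the Llama Transformer. It assumes a torchrun launch so a default process group can be initialized.

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard

class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim, bias=False)

    def init_weights(self):
        nn.init.trunc_normal_(self.linear.weight, mean=0.0, std=0.02)

class ToyModel(nn.Module):
    def __init__(self, dim: int, n_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def init_weights(self):
        for layer in self.layers:
            layer.init_weights()

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

with torch.device("meta"):
    model = ToyModel(dim=256, n_layers=4)  # no real storage allocated yet

for block in model.layers:
    fully_shard(block)                     # shard each block's parameters
fully_shard(model)                         # shard any remaining parameters

model.to_empty(device="cuda")              # materialize sharded storage
model.init_weights()                       # now assign real initial values
```
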
106 changes: 40 additions & 66 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -8,19 +8,13 @@
from typing import Tuple

import torch
from torch.distributed._tensor import Replicate, Shard

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.fsdp import (
BackwardPrefetch,
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -33,7 +27,6 @@

from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import logger
from torchtrain.meta_init import meta_to_real_init_fn


# for selective AC
@@ -75,7 +68,6 @@ def selective_checkpointing_context_fn():
preserve_rng_state=False,
)
elif config.mode == "full":
# full AC
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
@@ -136,28 +128,23 @@ def get_tp_parallel_strategy(

def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms to the model, including PTD parallelisms, and AC.
Apply parallelisms and activation checkpointing to the model.

NOTE: the model passed in preferrablably shoule be a meta device model,
otherwise the model needs to be small enough on GPU or can fit into CPU.
NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
"""
# apply PTD parallelisms
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")

# First we apply Tensor Parallelism if it's enabled
if parallel_dims.tp_enabled:
tp_mesh = world_mesh["tp"]
tp_degree = job_config.training.tensor_parallel_degree

row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
job_config
)

# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. parallelize the root norm layer by sequence dim
# 3. shard the first layer of transformer block
# 1. Parallelize the first embedding and the last linear proj layer
# 2. Parallelize the root norm layer over the sequence dim
# 3. Shard the first transformer block's inputs
model = parallelize_module(
model,
tp_mesh,
@@ -167,9 +154,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
),
"output": col_parallel_strategy(
input_layouts=Shard(0),
output_layouts=Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate(),
output_layouts=(
Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate()
),
use_local_output=not parallel_dims.loss_parallel_enabled,
),
"norm": SequenceParallel(sequence_dim=0),
@@ -181,7 +170,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
},
)

# apply tensor + sequence parallelism to every transformer block
# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in enumerate(model.layers):
layer_plan = {
"attention": PrepareModuleInput(
@@ -203,62 +192,47 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"ffn_norm": SequenceParallel(sequence_dim=0),
}

# adjust num_heads in attention layer to local heads
# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_degree
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_degree
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)

logger.info("Applied Sequence Parallelism to the model")
logger.info("Applied Tensor Parallelism to the model")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]

fsdp_config = {
"mixed_precision": MixedPrecision(
param_dtype=torch.bfloat16,
# TODO: see whether we should expose a option to user
reduce_dtype=torch.float32,
),
"sharding_strategy": ShardingStrategy.FULL_SHARD,
"backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
# When torch.compile is active, it requires us to set use_orig_params=True
"use_orig_params": True,
"device_mesh": dp_mesh,
"param_init_fn": meta_to_real_init_fn,
}

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
# TODO: Expose `reduce_dtype` as a config option.
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
)
ac_mode = job_config.activation_checkpoint.mode
with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
for layer_id, transformer_block in enumerate(model.layers):
# apply AC to the transformer block
if ac_mode in ("full", "selective"):
# wrap the transformer block with checkpoint wrapper, using config settings
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)

# wrap the rest layers with FSDP
model = wrap(model)

fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
Contributor Author:

I am open to not including this 'trick' since it might be confusing. The idea is that we can basically set reshard_after_forward=False for the last transformer block for free.

# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
model = fully_shard(model, **fsdp_config)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
logger.info("Applied FSDP to the model")
else:
meta_to_real_init_fn(model)
model.cuda()

# we have now moved from meta to device,
# reset parameters for proper initialization
model.reset_parameters()
logger.info("Model fully initialized via reset_parameters")

return model