Commit

Used per-parameter FSDP
awgu committed Mar 26, 2024
1 parent 8dd5798 commit e9a9c11
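
For context, this commit moves torchtrain from the FSDP module wrapper (FullyShardedDataParallel with enable_wrap/wrap) to the per-parameter-sharding fully_shard API. The sketch below shows the basic fully_shard pattern on a toy module; the toy model, mesh setup, and torchrun launch assumptions are illustrative only and not taken from the repository.

import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh


def shard_toy_model(dim: int = 128, num_blocks: int = 4) -> nn.Module:
    # Assumes launch via torchrun, which sets WORLD_SIZE and LOCAL_RANK.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh(
        "cuda", (int(os.environ["WORLD_SIZE"]),), mesh_dim_names=("dp",)
    )
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    model = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(num_blocks)]).cuda()
    # Shard each block so its parameters form one all-gather group, then shard
    # the root module to cover any parameters not owned by a block.
    for block in model:
        fully_shard(block, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)
    return model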
Showing 4 changed files with 92 additions and 194 deletions.
48 changes: 0 additions & 48 deletions torchtrain/meta_init.py

This file was deleted.

89 changes: 34 additions & 55 deletions torchtrain/models/llama/model.py
@@ -45,9 +45,7 @@ def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim))

# re-enable if not using meta-init
# self.reset_parameters()
self.reset_parameters()

def _norm(self, x: torch.Tensor):
"""
@@ -207,19 +205,10 @@ def __init__(self, model_args: ModelArgs):
model_args.n_heads * self.head_dim, model_args.dim, bias=False
)

def reset_parameters(self, init_std):
for item in (self.wq, self.wk, self.wv):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=0.02,
)

nn.init.trunc_normal_(
self.wo.weight,
mean=0.0,
std=init_std,
)
def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

def forward(
self,
@@ -309,19 +298,10 @@ def __init__(
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))

def reset_parameters(self, init_std):
nn.init.trunc_normal_(
self.w1.weight,
mean=0.0,
std=0.02,
)

for item in (self.w2, self.w3):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=init_std,
)
def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
for linear in (self.w2, self.w3):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class RotaryEmbedding(nn.Module):
@@ -333,13 +313,13 @@ def __init__(self, model_args: ModelArgs):
super().__init__()
self.model_args = model_args
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
self.init_weights()

self.freqs_cis = precompute_freqs_cis(
# Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
# of models is 4096.
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
# or fine-tuning.
def _precompute_freqs_cis(self):
return precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
)

@@ -359,6 +339,16 @@ def forward(self, tokens: torch.Tensor):
freqs_cis = self.freqs_cis[0:seqlen]
return h, freqs_cis

def init_weights(self):
if hasattr(self, "freqs_cis"):
with torch.device(self.freqs_cis.device):
self.freqs_cis = self._precompute_freqs_cis()
else:
self.register_buffer(
"freqs_cis", self._precompute_freqs_cis(), persistent=False
)
nn.init.normal_(self.tok_embeddings.weight)


class TransformerBlock(nn.Module):
"""
@@ -400,6 +390,7 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
else:
self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
self.init_weights()

def forward(
self,
@@ -421,13 +412,11 @@ def forward(
out = h + self.feed_forward(self.ffn_norm(h))
return out

def reset_parameters(self):
"""reset params and norms for entire block"""
self.attention_norm.reset_parameters()
self.ffn_norm.reset_parameters()

self.attention.reset_parameters(self.weight_init_std)
self.feed_forward.reset_parameters(self.weight_init_std)
def init_weights(self):
for norm in (self.attention_norm, self.ffn_norm):
norm.reset_parameters()
self.attention.init_weights(self.weight_init_std)
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
@@ -457,29 +446,19 @@ def __init__(self, model_args: ModelArgs):
self.model_dim = model_args.dim

self.embeddings = RotaryEmbedding(model_args)

self.layers = torch.nn.ModuleList()
for layer_id in range(model_args.n_layers):
self.layers.append(TransformerBlock(layer_id, model_args))

self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
self.init_weights()

# init model weights

# we are doing meta_init, which will call reset_parameters() after
# the model is moved to actual device.
# If you modify and are not using meta_init, you will need to call
# reset_parameters() manually as below:

# self.reset_parameters()

def reset_parameters(
self,
):
def init_weights(self):
for layer in self.layers:
layer.reset_parameters()
layer.init_weights()
self.norm.reset_parameters()
self.embeddings.init_weights()
final_out_std = self.model_dim**-0.5
cutoff_factor = 3
nn.init.trunc_normal_(
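The model.py changes above replace the reset_parameters() hooks (and the deleted meta_init helper) with per-module init_weights() methods that can be re-run once parameters have real storage. A minimal sketch of that usage, assuming the model is built on the meta device and materialized after sharding; the exact call site is outside this diff, so the to_empty/init_weights sequence here is an assumption.

import torch

from torchtrain.models.llama.model import ModelArgs, Transformer  # module path inferred from this diff


def build_and_materialize(args: ModelArgs) -> Transformer:
    # Construct on the meta device: parameters and buffers have no storage, so
    # the init_weights() calls inside __init__ are effectively no-ops.
    with torch.device("meta"):
        model = Transformer(args)
    # ... apply fully_shard / tensor parallelism to the meta-device model here ...
    # Allocate uninitialized storage on GPU, then re-run the init_weights()
    # methods defined in model.py (freqs_cis is recomputed on the real device).
    model.to_empty(device="cuda")
    model.init_weights()
    return model
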
105 changes: 39 additions & 66 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -8,19 +8,13 @@
from typing import Tuple

import torch
from torch.distributed._tensor import Replicate, Shard

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.fsdp import (
BackwardPrefetch,
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -33,7 +27,6 @@

from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import logger
from torchtrain.meta_init import meta_to_real_init_fn


# for selective AC
@@ -75,7 +68,6 @@ def selective_checkpointing_context_fn():
preserve_rng_state=False,
)
elif config.mode == "full":
# full AC
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
@@ -136,28 +128,23 @@ def get_tp_parallel_strategy(

def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms to the model, including PTD parallelisms, and AC.
Apply parallelisms and activation checkpointing to the model.
NOTE: the model passed in preferrablably shoule be a meta device model,
otherwise the model needs to be small enough on GPU or can fit into CPU.
NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
"""
# apply PTD parallelisms
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")

# First we apply Tensor Parallelism if it's enabled
if parallel_dims.tp_enabled:
tp_mesh = world_mesh["tp"]
tp_degree = job_config.training.tensor_parallel_degree

row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
job_config
)

# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. parallelize the root norm layer by sequence dim
# 3. shard the first layer of transformer block
# 1. Parallelize the first embedding and the last linear proj layer
# 2. Parallelize the root norm layer over the sequence dim
# 3. Shard the first transformer block's inputs
model = parallelize_module(
model,
tp_mesh,
@@ -167,9 +154,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
),
"output": col_parallel_strategy(
input_layouts=Shard(0),
output_layouts=Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate(),
output_layouts=(
Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate()
),
use_local_output=not parallel_dims.loss_parallel_enabled,
),
"norm": SequenceParallel(sequence_dim=0),
@@ -181,7 +170,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
},
)

# apply tensor + sequence parallelism to every transformer block
# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in enumerate(model.layers):
layer_plan = {
"attention": PrepareModuleInput(
@@ -203,62 +192,46 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"ffn_norm": SequenceParallel(sequence_dim=0),
}

# adjust num_heads in attention layer to local heads
# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_degree
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_degree
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)

logger.info("Applied Sequence Parallelism to the model")
logger.info("Applied Tensor Parallelism to the model")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]

fsdp_config = {
"mixed_precision": MixedPrecision(
param_dtype=torch.bfloat16,
# TODO: see whether we should expose a option to user
reduce_dtype=torch.float32,
),
"sharding_strategy": ShardingStrategy.FULL_SHARD,
"backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
# When torch.compile is active, it requires us to set use_orig_params=True
"use_orig_params": True,
"device_mesh": dp_mesh,
"param_init_fn": meta_to_real_init_fn,
}

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
)
ac_mode = job_config.activation_checkpoint.mode
with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
for layer_id, transformer_block in enumerate(model.layers):
# apply AC to the transformer block
if ac_mode in ("full", "selective"):
# wrap the transformer block with checkpoint wrapper, using config settings
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)

# wrap the rest layers with FSDP
model = wrap(model)

fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
model = fully_shard(model, **fsdp_config)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
logger.info("Applied FSDP to the model")
else:
meta_to_real_init_fn(model)
model.cuda()

# we have now moved from meta to device,
# reset parameters for proper initialization
model.reset_parameters()
logger.info("Model fully initialized via reset_parameters")

return model
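
For reference, the non-reentrant checkpoint wrapper used by the activation-checkpointing branches above can be exercised standalone; the toy block and shapes below are illustrative assumptions, not code from the repository.

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
    CheckpointImpl,
)

# Toy stand-in for a TransformerBlock.
block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
block = ptd_checkpoint_wrapper(
    block,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    preserve_rng_state=False,
)

x = torch.randn(8, 64, requires_grad=True)
out = block(x)        # intermediate activations inside the block are not saved
out.sum().backward()  # they are recomputed during the backward pass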