Used per-parameter FSDP (#165)
**Numeric Parity**
1D FSDP
- Eager: 1k steps of minipile on 8 H100 GPUs, local batch size 8,
sequence length 2048, AC/SAC, bf16 mixed precision, fp32 reduce-scatter
- FSDP1 (AC): 24.81% peak active, 33.82% peak reserved, 6100-6200 WPS
- FSDP1 (SAC): 52.98% peak active, 67.23% peak reserved, 6500-6700 WPS
- FSDP2 (AC): 23.92% peak active, 32.64% peak reserved, 6100-6300 WPS
- FSDP2 (SAC): 52.13% peak active, 62.51% peak reserved, 6600-6800 WPS
    - Loss curves match between FSDP1 and FSDP2
- Memory numbers are reported as percentages since that is how they are
logged; they can be converted against the 95.0396 GiB of total GPU memory
(see the snippet after this list)
- Compile: same setup as eager
- FSDP2 (AC), buffer reuse disabled: 28.72 GiB (30.22%) peak reserved,
7200-7500 WPS, 33% MFU
- FSDP2 (AC), buffer reuse enabled: 28.90 GiB (30.40%) peak reserved,
7200-7500 WPS, 33% MFU
- FSDP2 (SAC), buffer reuse enabled: 53.83 GiB (56.64%) peak reserved,
8100-8400 WPS, 36% MFU
    - Loss curves slightly better than eager
    - For fun -- how much can we push MFU?
- If we use FSDP2 (SAC) with local batch size 16 (doubled), we get 88.23
GiB (92.84%) peak reserved, 8600 WPS, 38% MFU.
- If we use FSDP2 (no AC) with local batch size 8, we get 90.28 GiB
(94.99%) peak reserved, 9100-9300 WPS, 40% MFU.
- Why is FSDP2 faster? (1) The fp32 reduce-scatter uses only one division
kernel instead of two, and (2) `reshard_after_forward=False` is used for
the last transformer block
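
For reference, a minimal sketch (not code from this PR; the helper name is
illustrative) of converting the logged memory percentages back to GiB using
the 95.0396 GiB total mentioned above:

```python
# Convert a logged memory percentage back to GiB, given the total GPU memory
# reported above (95.0396 GiB on these H100s). Illustrative helper only.
TOTAL_GIB = 95.0396

def pct_to_gib(pct: float) -> float:
    return pct / 100.0 * TOTAL_GIB

# e.g. FSDP1 (AC) peak reserved: 33.82% -> ~32.14 GiB
print(f"{pct_to_gib(33.82):.2f} GiB")
```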

2D FSDP
- Eager (2-way SP, 4-way FSDP): 1k steps of minipile on 8 H100 GPUs,
local batch size 16 (to preserve the global batch size), sequence length
2048, bf16 mixed precision, fp32 reduce-scatter (see the mesh sketch after
this list)
- FSDP2 (AC): 50.12% peak active, 60.97% peak reserved, 5800-5900 WPS
- FSDP2 (SAC): 76.49% peak active, 90.14% peak reserved, 6100-6300 WPS
- Loss curves match 8-way FSDP
- FSDP1 + SP has incorrect numerics because `FSDP.clip_grad_norm_` does not
all-reduce over the TP mesh dimension
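
As a point of reference, here is a minimal sketch (not code from this PR;
assumes an 8-GPU launch under torchrun) of how the 2D mesh for the 2-way SP /
4-way FSDP run could be built, producing the `dp` and `tp` sub-meshes that
`parallelize_llama` slices out of `world_mesh`:

```python
# Build a 2D device mesh: 4-way data parallel (FSDP) x 2-way tensor/sequence
# parallel on 8 GPUs. Sketch only; the real mesh construction lives in
# torchtrain's parallelism setup.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
dp_mesh = world_mesh["dp"]  # handed to fully_shard(...)
tp_mesh = world_mesh["tp"]  # handed to parallelize_module(...)
```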

<details>
<summary> Loss curves </summary>

<img width="732" alt="Screenshot 2024-03-26 at 3 31 19 PM"
src="/~https://github.com/pytorch/torchtrain/assets/31054793/59ec71cc-ad0a-4dd1-b5c6-a8cbf9ab5e85">

</details>


**Meta-Device Initialization**
- The PyTorch Core guideline is for `module.reset_parameters()` to only
initialize parameters/buffers immediately owned by `module` (i.e.
`module.parameters(recurse=False)` and `module.buffers(recurse=False)`).
- This makes it challenging to specify custom initializations for core
modules like `nn.Linear` and `nn.Embedding`. For example, in
@lessw2020's depth-wise truncated normal initialization, the
`trunc_normal_` standard deviation depends on the layer ID, which is a
property of the `TransformerBlock` but affects the child `nn.Linear`s.
- To disambiguate, I suggest avoiding the name `reset_parameters()` whenever
we would violate the PyTorch Core guideline and instead using a different
name (e.g. `init_weights`); a sketch of this pattern follows below.
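
A simplified sketch of the convention described above (not the exact
torchtrain code; the module names and the std formula are illustrative): the
parent block owns the depth-dependent std and pushes it into its child
`nn.Linear`s from `init_weights`, leaving the children's `reset_parameters()`
to follow the PyTorch Core guideline.

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, layer_id: int, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        # Depth-dependent std (illustrative formula): deeper layers get a
        # smaller std, which only the parent block knows how to compute.
        self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5

    def init_weights(self):
        # Custom initialization lives in init_weights(), not in an overloaded
        # reset_parameters() on the child modules.
        nn.init.trunc_normal_(self.proj.weight, mean=0.0, std=self.weight_init_std)
```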

**DCP & Save/Load**
- Tested 1D and 2D by specifying `checkpoint_folder =
"/tmp/checkpoint_andgu"` in the `.toml`, training until a checkpoint was
saved, terminating the run, and restarting training to load the checkpoint --
the loss after loading looks reasonable (a minimal DCP save/load sketch
follows below)
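
For context, a minimal sketch (not the torchtrain checkpoint manager; assumes
a recent PyTorch where `dcp.save`/`dcp.load` exist and an initialized process
group) of the DCP save/load round trip exercised above:

```python
import torch.distributed.checkpoint as dcp

CHECKPOINT_DIR = "/tmp/checkpoint_andgu"  # matches the .toml setting above

def save_checkpoint(state_dict: dict) -> None:
    dcp.save(state_dict, storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR))

def load_checkpoint(state_dict: dict) -> None:
    # Loads in place into the (possibly sharded) tensors of `state_dict`.
    dcp.load(state_dict, storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR))
```
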
awgu authored Mar 28, 2024
1 parent ef7f67c commit 6d3d906
Showing 4 changed files with 97 additions and 194 deletions.
48 changes: 0 additions & 48 deletions torchtrain/meta_init.py

This file was deleted.

97 changes: 41 additions & 56 deletions torchtrain/models/llama/model.py
@@ -45,9 +45,7 @@ def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim))

# re-enable if not using meta-init
# self.reset_parameters()
self.reset_parameters()

def _norm(self, x: torch.Tensor):
"""
@@ -207,19 +205,10 @@ def __init__(self, model_args: ModelArgs):
model_args.n_heads * self.head_dim, model_args.dim, bias=False
)

def reset_parameters(self, init_std):
for item in (self.wq, self.wk, self.wv):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=0.02,
)

nn.init.trunc_normal_(
self.wo.weight,
mean=0.0,
std=init_std,
)
def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

def forward(
self,
@@ -309,19 +298,10 @@ def __init__(
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))

def reset_parameters(self, init_std):
nn.init.trunc_normal_(
self.w1.weight,
mean=0.0,
std=0.02,
)

for item in (self.w2, self.w3):
nn.init.trunc_normal_(
item.weight,
mean=0.0,
std=init_std,
)
def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
for linear in (self.w2, self.w3):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class RotaryEmbedding(nn.Module):
@@ -333,13 +313,15 @@ def __init__(self, model_args: ModelArgs):
super().__init__()
self.model_args = model_args
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
self.register_buffer(
"freqs_cis", self._precompute_freqs_cis(), persistent=False
)

self.freqs_cis = precompute_freqs_cis(
# Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
# of models is 4096.
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
# or fine-tuning.
def _precompute_freqs_cis(self):
return precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
)

@@ -355,10 +337,14 @@ def forward(self, tokens: torch.Tensor):
"""
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[0:seqlen]
return h, freqs_cis

def init_weights(self):
with torch.device(self.freqs_cis.device):
self.freqs_cis = self._precompute_freqs_cis()
nn.init.normal_(self.tok_embeddings.weight)


class TransformerBlock(nn.Module):
"""
@@ -421,13 +407,11 @@ def forward(
out = h + self.feed_forward(self.ffn_norm(h))
return out

def reset_parameters(self):
"""reset params and norms for entire block"""
self.attention_norm.reset_parameters()
self.ffn_norm.reset_parameters()

self.attention.reset_parameters(self.weight_init_std)
self.feed_forward.reset_parameters(self.weight_init_std)
def init_weights(self):
for norm in (self.attention_norm, self.ffn_norm):
norm.reset_parameters()
self.attention.init_weights(self.weight_init_std)
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
@@ -457,28 +441,29 @@ def __init__(self, model_args: ModelArgs):
self.model_dim = model_args.dim

self.embeddings = RotaryEmbedding(model_args)

self.layers = torch.nn.ModuleList()
for layer_id in range(model_args.n_layers):
self.layers.append(TransformerBlock(layer_id, model_args))

self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
self.init_weights()

# init model weights

# we are doing meta_init, which will call reset_parameters() after
# the model is moved to actual device.
# If you modify and are not using meta_init, you will need to call
# reset_parameters() manually as below:

# self.reset_parameters()

def reset_parameters(
self,
):
def init_weights(self):
"""
[Note: On ``init_weights`` vs. ``reset_parameters``]
Modules may define ``reset_parameters`` to initialize parameter values.
``reset_parameters`` is meant to only initialize directly owned
parameters/buffers, not those of their child modules, and it can be
used to give the initial values for these tensors.
Separately, users may want custom initialization for their modules,
different from that in ``reset_parameters``. For this, we define
``init_weights``. We only call it in the constructor of this
``Transformer`` root module to avoid reinitializing tensors.
"""
self.embeddings.init_weights()
for layer in self.layers:
layer.reset_parameters()
layer.init_weights()
self.norm.reset_parameters()
final_out_std = self.model_dim**-0.5
cutoff_factor = 3
106 changes: 40 additions & 66 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -8,19 +8,13 @@
from typing import Tuple

import torch
from torch.distributed._tensor import Replicate, Shard

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.fsdp import (
BackwardPrefetch,
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@@ -33,7 +27,6 @@

from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import logger
from torchtrain.meta_init import meta_to_real_init_fn


# for selective AC
@@ -75,7 +68,6 @@ def selective_checkpointing_context_fn():
preserve_rng_state=False,
)
elif config.mode == "full":
# full AC
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
@@ -136,28 +128,23 @@ def get_tp_parallel_strategy(

def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms to the model, including PTD parallelisms, and AC.
Apply parallelisms and activation checkpointing to the model.
NOTE: the model passed in preferrablably shoule be a meta device model,
otherwise the model needs to be small enough on GPU or can fit into CPU.
NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
"""
# apply PTD parallelisms
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")

# First we apply Tensor Parallelism if it's enabled
if parallel_dims.tp_enabled:
tp_mesh = world_mesh["tp"]
tp_degree = job_config.training.tensor_parallel_degree

row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
job_config
)

# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. parallelize the root norm layer by sequence dim
# 3. shard the first layer of transformer block
# 1. Parallelize the first embedding and the last linear proj layer
# 2. Parallelize the root norm layer over the sequence dim
# 3. Shard the first transformer block's inputs
model = parallelize_module(
model,
tp_mesh,
@@ -167,9 +154,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
),
"output": col_parallel_strategy(
input_layouts=Shard(0),
output_layouts=Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate(),
output_layouts=(
Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate()
),
use_local_output=not parallel_dims.loss_parallel_enabled,
),
"norm": SequenceParallel(sequence_dim=0),
@@ -181,7 +170,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
},
)

# apply tensor + sequence parallelism to every transformer block
# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in enumerate(model.layers):
layer_plan = {
"attention": PrepareModuleInput(
@@ -203,62 +192,47 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"ffn_norm": SequenceParallel(sequence_dim=0),
}

# adjust num_heads in attention layer to local heads
# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_degree
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_degree
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)

logger.info("Applied Sequence Parallelism to the model")
logger.info("Applied Tensor Parallelism to the model")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]

fsdp_config = {
"mixed_precision": MixedPrecision(
param_dtype=torch.bfloat16,
# TODO: see whether we should expose a option to user
reduce_dtype=torch.float32,
),
"sharding_strategy": ShardingStrategy.FULL_SHARD,
"backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
# When torch.compile is active, it requires us to set use_orig_params=True
"use_orig_params": True,
"device_mesh": dp_mesh,
"param_init_fn": meta_to_real_init_fn,
}

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
# TODO: Expose `reduce_dtype` as a config option.
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
)
ac_mode = job_config.activation_checkpoint.mode
with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
for layer_id, transformer_block in enumerate(model.layers):
# apply AC to the transformer block
if ac_mode in ("full", "selective"):
# wrap the transformer block with checkpoint wrapper, using config settings
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)

# wrap the rest layers with FSDP
model = wrap(model)

fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
model = fully_shard(model, **fsdp_config)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
logger.info("Applied FSDP to the model")
else:
meta_to_real_init_fn(model)
model.cuda()

# we have now moved from meta to device,
# reset parameters for proper initialization
model.reset_parameters()
logger.info("Model fully initialized via reset_parameters")

return model