Add selective layer activation checkpointing and a single control for turning AC on or off #125

Merged: 10 commits, Mar 14, 2024
Changes from 9 commits
15 changes: 12 additions & 3 deletions torchtrain/config_manager.py
@@ -249,9 +249,18 @@ def init_args_from_command_line(
], # TODO: add "delayed" option back in when supported
help="Type of fp8 linear quantization to apply to the model",
)

# activation checkpointing
parser.add_argument(
"--training.enable_selective_ac",
action="store_true",
help="whether to enable selective activation checkpointing",
"--activation_checkpoint.mode",
type=str,
default="selective",
help=" ['none', 'full', 'selective'] = type of activation checkpointing to use",
)
parser.add_argument(
"--activation_checkpoint.selective_ac_option",
type=str,
default="2", # 2 = checkpoint every other layer
help="['int', 'op'] = selective activation checkpointing options, 'int' for every nth layer, or 'op' for op level ac.",
)
return parser.parse_args(args_list)
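
For context (not part of the diff): the new options live under a dotted "--activation_checkpoint.*" prefix, and plain argparse stores such names verbatim as attribute names. A minimal, standard-library-only sketch of how the two flags parse; torchtrain's own JobConfig plumbing is not shown here:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--activation_checkpoint.mode", type=str, default="selective")
parser.add_argument("--activation_checkpoint.selective_ac_option", type=str, default="2")

args = parser.parse_args(
    ["--activation_checkpoint.mode", "selective",
     "--activation_checkpoint.selective_ac_option", "op"]
)

# argparse keeps the dotted option names as-is, so they are read back with getattr()
print(getattr(args, "activation_checkpoint.mode"))                 # selective
print(getattr(args, "activation_checkpoint.selective_ac_option"))  # op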
64 changes: 51 additions & 13 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -2,7 +2,7 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# this file applies the PTD parallelisms and various training techniques to the
-# llama model, i.e. activation checkpoint, etc.
+# llama model, i.e. activation checkpointing, etc.

import logging
from collections import defaultdict
@@ -33,6 +33,9 @@
PrepareModuleInput,
RowwiseParallel,
)

from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint

from torchtrain.config_manager import JobConfig
from torchtrain.logging_utils import rank0_log
from torchtrain.meta_init import meta_to_real_init_fn
@@ -77,12 +80,9 @@ def partition_fn(name, module, device_mesh):
}

# Uses PTD FSDP AC wrapper
-def checkpoint_wrapper(module, enable_selective_ac):
-if enable_selective_ac:
-from torch.utils.checkpoint import (
-_pt2_selective_checkpoint_context_fn_gen,
-checkpoint,
-)
+# currently, selective per-op and per-layer checkpointing are supported
+def checkpoint_wrapper(module, config):
+if config.mode == "selective" and config.selective_ac_option == "op":

def _get_custom_policy(meta):
def _custom_policy(mode, func, *args, **kwargs):
@@ -108,13 +108,50 @@ def selective_checkpointing_context_fn():
use_reentrant=False,
preserve_rng_state=False,
)
-else:
+elif config.mode == "full":
+# full AC
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
checkpoint_fn=checkpoint,
use_reentrant=False,
preserve_rng_state=False,
)

elif config.mode == "selective" and config.selective_ac_option.isdigit():

"""enables selective checkpointing of candidate layers.
Usage:
'selective_ac_option' with an 'int' value in config controls which layers to checkpoint.
None, 0 == checkpointing filtering not active, checkpoint all instances
1 == checkpointing every one (all).
2 == checkpoint every 2nd one
"""
every_x_layer = int(config.selective_ac_option)
assert (
every_x_layer >= 0
), f"selective layer AC policy (every_x_layer) expects a positive integer, received {every_x_layer}"

checkpoint_wrapper.__dict__.setdefault("_count", 0)

checkpoint_wrapper._count += 1
if not every_x_layer or checkpoint_wrapper._count % every_x_layer == 0:
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
checkpoint_fn=checkpoint,
use_reentrant=False,
preserve_rng_state=False,
)
# skip activation checkpointing and store activations for this layer
else:
return module

else:
raise NotImplementedError(
"Unknown AC type or AC config. Only selective op and selective layer ac implemented currently."
)
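
As an aside (not code from this PR): the "full" branch simply hands the whole module to the PTD checkpoint wrapper. Below is a minimal sketch of that behavior on a toy block, assuming the usual ptd_checkpoint_wrapper import path (torch.distributed.algorithms._checkpoint.checkpoint_wrapper); the file's actual import sits outside the hunks shown above, and the extra checkpoint_fn kwargs are omitted:

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
wrapped = ptd_checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.NO_REENTRANT)

x = torch.randn(4, 16, requires_grad=True)
loss = wrapped(x).sum()
loss.backward()  # activations inside `block` are recomputed in backward instead of stored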


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
@@ -216,11 +253,12 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
for layer_id, transformer_block in enumerate(model.layers):

-# apply AC/selective AC
-transformer_block = checkpoint_wrapper(
-transformer_block, job_config.training.enable_selective_ac
-)
+# apply AC to the transformer block
+if job_config.activation_checkpoint.mode in ("full", "selective"):
+# wrap the transformer block with checkpoint wrapper, using config settings
+transformer_block = checkpoint_wrapper(
+transformer_block, job_config.activation_checkpoint
+)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)
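
To make the per-layer selection rule concrete (illustration only, not part of the PR): checkpoint_wrapper keeps a running count of wrapped blocks and checkpoints a block whenever count % every_x_layer == 0, or every block when the option is 0. A stripped-down, hypothetical helper that only reports which layer indices would be checkpointed:

# Hypothetical, simplified restatement of the layer-selection rule; it only
# reports which layers would be wrapped, it does not wrap anything.
def selected_layers(num_layers: int, selective_ac_option: str) -> list[int]:
    every_x_layer = int(selective_ac_option)
    picked, count = [], 0
    for layer_id in range(num_layers):
        count += 1  # mirrors checkpoint_wrapper._count incrementing per call
        if not every_x_layer or count % every_x_layer == 0:
            picked.append(layer_id)
    return picked

print(selected_layers(8, "2"))  # [1, 3, 5, 7] -> every other block is checkpointed
print(selected_layers(8, "1"))  # all 8 layers
print(selected_layers(8, "0"))  # all 8 layers (0 disables filtering)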
4 changes: 4 additions & 0 deletions train_configs/debug_model.toml
@@ -38,3 +38,7 @@ checkpoint_interval = 3600
checkpoint_interval_type = "steps"
checkpoint_folder = ""
dataset = "alpaca" # supported datasets = alpaca (52K), minipile (1M), c4 (177M)

[activation_checkpoint]
mode = 'selective' # ['none', 'full', 'selective']
selective_ac_option = '2' # 'int' = AC every Nth layer, or 'op' = AC based on per-op policy
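
For reference (not part of the PR): the new section is plain TOML, so it can be sanity-checked with the standard-library parser; how torchtrain itself loads and merges the file may differ. A quick check with tomllib (Python 3.11+):

import tomllib

cfg = tomllib.loads("""
[activation_checkpoint]
mode = 'selective'          # ['none', 'full', 'selective']
selective_ac_option = '2'   # digit string = every Nth layer, or 'op'
""")

section = cfg["activation_checkpoint"]
assert section["mode"] == "selective"
assert section["selective_ac_option"] == "2"  # note: a string, matching the argparse default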