Fix DDP with nf4 #1684

Merged · 6 commits · Feb 18, 2025

Changes from 5 commits
40 changes: 40 additions & 0 deletions test/dtypes/ddp/check_ddp_nf4.py
@@ -0,0 +1,40 @@
import argparse
from pathlib import Path

import torch

from torchao.dtypes.nf4tensor import NF4Tensor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_checkpoint_dir", type=str, required=True)
    parser.add_argument("--test_checkpoints_dir", type=str, required=True)

    args = parser.parse_args()

    ref_checkpoints = list(Path(args.ref_checkpoint_dir).glob("*.pt"))
    assert len(ref_checkpoints) == 1, "Expected exactly one reference checkpoint"
    ref_checkpoint = ref_checkpoints[0]
    ref_state_dict = torch.load(ref_checkpoint, weights_only=True, map_location="cpu")
    print(f"Ref checkpoint: {ref_checkpoint}")

    for path in Path(args.test_checkpoints_dir).glob("*.pt"):
        print(f"Checking {path}")
        state_dict = torch.load(path, weights_only=True, map_location="cpu")
        assert ref_state_dict.keys() == state_dict.keys()
        for name in ref_state_dict.keys():
            ref_param = ref_state_dict[name]
            test_param = state_dict[name]
            print(f"Checking {name} {type(ref_param)} {type(test_param)}")

            if isinstance(ref_param, NF4Tensor):
                ref_param = ref_param.get_original_weight()
                assert isinstance(test_param, NF4Tensor)
                test_param = test_param.get_original_weight()

            if not torch.allclose(ref_param, test_param, atol=1e-4, rtol=1e-4):
                diff = (ref_param - test_param).abs().max()
                print(f" \u2718 Param {name} differs by {diff}")
            else:
                print(f" \u2713 Param {name} is consistent")
    print("Passed!")
155 changes: 155 additions & 0 deletions test/dtypes/ddp/ddp_nf4.py
@@ -0,0 +1,155 @@
import argparse
import math
import os
import time
from contextlib import contextmanager

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo import config as dynamo_config
from torch.nn.parallel import DistributedDataParallel as DDP

from torchao.dtypes.nf4tensor import linear_nf4, to_nf4


class LoRALinear(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        lora_rank: int = None,
        lora_alpha: float = 16,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        if lora_rank is None:
            lora_rank = hidden_dim // 2

        weight = torch.randn(hidden_dim, hidden_dim, dtype=dtype)
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.register_parameter(
            "weight", nn.Parameter(to_nf4(weight), requires_grad=False)
        )
        self.lora_a = nn.Linear(
            in_features=hidden_dim, out_features=self.lora_rank, bias=False
        )
        self.lora_b = nn.Linear(
            in_features=self.lora_rank, out_features=hidden_dim, bias=False
        )
        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_b.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = linear_nf4(input=x, weight=self.weight)
        lora_out = self.lora_a(x)
        lora_out = (self.lora_alpha / self.lora_rank) * self.lora_b(lora_out)
        return out + lora_out


def _init_model(dim, num_linears, device, dtype) -> nn.Module:
    with torch.device(device):
        modules = []
        for i in range(num_linears):
            modules += [LoRALinear(hidden_dim=dim, dtype=dtype)]
        seq = nn.Sequential(*modules)

    return seq


def dist_print(*args, delay=0.5):
    rank = dist.get_rank()
    time.sleep(delay * rank)
    print(f"[rank{rank}]: ", *args, flush=True)


def make_batch(global_bs, dim, dtype, device):
    batch = torch.randn((global_bs, dim), dtype=dtype, device=device)
    if dist.get_world_size() > 1:
        batch = batch.chunk(dist.get_world_size(), dim=0)[dist.get_rank()]
    return batch


def run_ddp(global_bs, dim, num_linears, device, dtype, num_steps, save_dir, compile):
    os.makedirs(save_dir, exist_ok=True)
    model = _init_model(dim, num_linears, device, dtype)
    model = DDP(model, device_ids=[device])

    if compile:
        model = torch.compile(model)
    optim = torch.optim.Adam(model.parameters(), lr=1e-2)

    losses = []

    for i in range(num_steps):
        inp = make_batch(global_bs, dim, dtype, device)
        loss = model(inp).sum()
        losses.append(loss)
        loss.backward()
        optim.step()
        optim.zero_grad()

    dist.barrier()

    save_path = f"{save_dir}/ddp-{dist.get_rank()}.pt"
    torch.save(model.state_dict(), save_path)
    dist_print("Saved model to", save_path)


def init_dist():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    dist_print("Dist initialized with world size", dist.get_world_size())


def cleanup_dist():
    dist.barrier()
    if dist.get_rank() == 0:
        print("Cleaning up dist")
    dist.destroy_process_group()


@contextmanager
def distributed_context():
    init_dist()
    yield
    cleanup_dist()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--global_bs", type=int, default=8)
    parser.add_argument("--dim", type=int, default=128)
    parser.add_argument("--num_linears", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float32")
    parser.add_argument("--num_steps", type=int, default=3)
    parser.add_argument("--save_dir", type=str, default="checkpoints")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--optimize_ddp", type=str, default="ddp_optimizer")
    args = parser.parse_args()

    args.dtype = getattr(torch, args.dtype)
    dynamo_config.optimize_ddp = args.optimize_ddp

    if args.optimize_ddp == "python_reducer":
        dynamo_config.compiled_autograd = True

    with distributed_context():
        torch.manual_seed(args.seed)
        run_ddp(
            global_bs=args.global_bs,
            dim=args.dim,
            num_linears=args.num_linears,
            device=args.device,
            dtype=args.dtype,
            num_steps=args.num_steps,
            save_dir=args.save_dir,
            compile=args.compile,
        )
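For a quick single-process smoke test of the layer itself (a sketch, not part of the PR; assumes this file is importable as ddp_nf4 and that linear_nf4 runs on CPU):

```python
# Sketch: one LoRALinear forward pass without DDP.
import torch

from ddp_nf4 import LoRALinear

torch.manual_seed(0)
layer = LoRALinear(hidden_dim=128)  # frozen NF4 base weight + trainable LoRA A/B
x = torch.randn(4, 128)
out = layer(x)
print(out.shape)  # expected: torch.Size([4, 128])
```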
48 changes: 48 additions & 0 deletions test/dtypes/ddp/run_ddp_nf4_test.sh
@@ -0,0 +1,48 @@
#!/bin/bash

set -euo pipefail
WORLD_SIZE=${1:-2}


# Test params
GLOBAL_BS=8
DIM=128
NUM_LINEARS=1
NUM_STEPS=3

PARAMS="--global_bs $GLOBAL_BS --dim $DIM --num_linears $NUM_LINEARS --num_steps $NUM_STEPS"
SAVE_DIR="checkpoints"
REF_DIR="${SAVE_DIR}/ref"
TEST_DIR="${SAVE_DIR}/test"
DDP_PROGRAM="ddp_nf4.py"
CHECK_PROGRAM="check_ddp_nf4.py"
REF_CMD="torchrun --nproc_per_node 1 $DDP_PROGRAM $PARAMS --save_dir $REF_DIR"
TEST_CMD="torchrun --nproc_per_node $WORLD_SIZE $DDP_PROGRAM $PARAMS --save_dir $TEST_DIR"
CHECK_CMD="python $CHECK_PROGRAM --ref_checkpoint_dir $REF_DIR --test_checkpoints_dir $TEST_DIR"
CLEANUP_CMD="rm -rf $SAVE_DIR"

echo "Step 1: Generating reference checkpoint..."
echo $REF_CMD
$REF_CMD
echo -e "\n --- \n"
sleep 2

echo "Step 2: Generating test checkpoints..."
echo $TEST_CMD
$TEST_CMD
echo -e "\n --- \n"
sleep 2

# Check params
echo "Step 3: Checking params..."
echo $CHECK_CMD
$CHECK_CMD
echo -e "\n --- \n"
sleep 2

# Cleanup
echo "Step 4: Cleaning up..."
echo $CLEANUP_CMD
$CLEANUP_CMD
echo -e "\n --- \n"
echo "Done!"
84 changes: 57 additions & 27 deletions torchao/dtypes/nf4tensor.py
@@ -193,9 +193,9 @@ def nf4_split(aten_op, args, kwargs=None):
     attr_to_chunks = {}
     for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
         inner_tensor = getattr(nf4tensor, attr)
-        assert (
-            inner_tensor.numel() % num_chunks == 0
-        ), f"{attr}.numel() not divisible by {num_chunks}"
+        assert inner_tensor.numel() % num_chunks == 0, (
+            f"{attr}.numel() not divisible by {num_chunks}"
+        )
         chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs)
         attr_to_chunks[attr] = chunks

@@ -236,9 +236,9 @@ def nf4_new_zeros(aten_op, args, kwargs=None):
     updated_attrs = {}
     for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
         inner_tensor = getattr(nf4tensor, attr)
-        assert (
-            inner_tensor.size(0) % ratio == 0
-        ), f"{attr}.numel() must be divisible by {ratio}"
+        assert inner_tensor.size(0) % ratio == 0, (
+            f"{attr}.numel() must be divisible by {ratio}"
+        )
         inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs)
         updated_attrs[attr] = inner_tensor
     updated_attrs["size"] = new_size
@@ -423,6 +423,35 @@ def nf4_pin_memory(aten_op, args, kwargs=None):
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))


+@implements(
+    [
+        aten.cat.default,
+    ]
+)
+def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
+    tensors_to_cat = args[0]
+    assert all(isinstance(t, torch.Tensor) for t in tensors_to_cat)
+    remaining_args = args[1:]
+
+    ts = []
+    for t in tensors_to_cat:
+        assert isinstance(t, torch.Tensor)
+
+        if isinstance(t, NF4Tensor):
+            ts.append(t.get_original_weight())
+        else:
+            ts.append(t)
+
+    dtype = ts[0].dtype
+    assert all(t.dtype == dtype for t in ts)
+
+    if kwargs is None:
+        kwargs = {}
+
+    tensors = aten_op(ts, *remaining_args, **kwargs)
+    return tensors
+
+
 @dataclass(frozen=True)
 class SubclassTensorArgs:
     original_shape: torch.Size
@@ -444,9 +473,9 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
         torch.Tensor: Tensor of scalers for each block
     """
     assert input_tensor.dim() == 1, "Input tensor must be flattened"
-    assert (
-        (input_tensor.numel() % block_size) == 0
-    ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+    assert (input_tensor.numel() % block_size) == 0, (
+        f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+    )

     n_blocks = input_tensor.numel() // block_size
     blocks = input_tensor.view(n_blocks, block_size)
@@ -529,12 +558,12 @@ def from_tensor(
         block_size: int,
         scaler_block_size: int,
     ):
-        assert (
-            input_tensor.dim() <= 2
-        ), f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
-        assert (
-            input_tensor.numel() % block_size == 0
-        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+        assert input_tensor.dim() <= 2, (
+            f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
+        )
+        assert input_tensor.numel() % block_size == 0, (
+            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+        )
         assert input_tensor.is_contiguous, "Input tensor must be contiguous!"
         # I think I want do this
         # assert not input_tensor.requires_grad, "Input tensor must not require grad"
@@ -615,19 +644,19 @@ def double_quantize_scalers(
                 size: (n_scaler_blocks)
         """
         assert input_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (
-            (input_tensor.numel() % scaler_block_size) == 0
-        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        assert (input_tensor.numel() % scaler_block_size) == 0, (
+            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        )

         # First round of quantization
         # Produces: A tensor of size (n_blocks) of input_tensor.dtype
         scalers_1 = get_block_absmax(input_tensor, block_size)
         scalers_1_mean = scalers_1.mean()
         scalers_1 = scalers_1 - scalers_1_mean
         # Second round of quantization
-        assert (
-            scalers_1.numel() % scaler_block_size == 0
-        ), f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} "
+        assert scalers_1.numel() % scaler_block_size == 0, (
+            f"Number of scalers must be divisible by scaler block size, got {scalers_1.numel()} scaler_block_size {scaler_block_size} "
+        )
         n_scaler_blocks = scalers_1.numel() // scaler_block_size
         scaler_blocks = scalers_1.view(n_scaler_blocks, scaler_block_size)

@@ -669,9 +698,9 @@ def dequantize_scalers(

         """
         assert input_tensor.dim() == 1, "Input tensor must be flattened"
-        assert (
-            (input_tensor.numel() % scaler_block_size) == 0
-        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        assert (input_tensor.numel() % scaler_block_size) == 0, (
+            f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        )
         n_scaler_blocks = input_tensor.numel() // scaler_block_size
         input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size)
         dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
@@ -687,9 +716,9 @@ def convert_to_norm_float_weight(
         flattened_tensor = input_tensor.flatten()
         # Since we are using uint8 we will encode 2 entries per byte
         numel = input_tensor.numel()
-        assert (
-            numel % 2 == 0
-        ), "Number of elements must be even just to not have to think about the end"
+        assert numel % 2 == 0, (
+            "Number of elements must be even just to not have to think about the end"
+        )
         # Reshape the flattened tensor into blocks of size self.block_size
         blocks = flattened_tensor.view(n_blocks, block_size)

@@ -1058,3 +1087,4 @@ def nf4_constructor(

 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals([NF4Tensor])
+    torch.serialization.add_safe_globals([NF4Tensor])
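The new aten.cat override dequantizes NF4 inputs before concatenating, so the result is a plain tensor rather than an NF4Tensor. A minimal sketch of the behavior this enables (illustrative shapes, default to_nf4 block sizes; not part of the PR):

```python
# Sketch: torch.cat with an NF4Tensor input now dispatches to nf4_cat.
import torch

from torchao.dtypes.nf4tensor import to_nf4

a = to_nf4(torch.randn(128, 128))  # NF4Tensor
b = torch.randn(128, 128)          # plain float32 tensor
out = torch.cat([a, b], dim=0)     # NF4 inputs are dequantized first
print(type(out), out.shape)        # plain torch.Tensor, shape (256, 128)
```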