Support power of 2 scaling factors in float8 training and use e4m3 everywhere (#1670)
danielvegamyhre authored Feb 10, 2025
1 parent bae41d1 commit 32a51ec
Showing 7 changed files with 145 additions and 21 deletions.
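
For context, a minimal usage sketch of the new flag (not part of this commit): it constructs the config the same way the updated tests below do, and the import paths, tensor shape, device, and dtype are illustrative assumptions.

import torch

from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig

# Opt in to power-of-2 scale rounding via the new config field.
config = Float8LinearConfig(
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    round_scales_to_power_of_2=True,
)

weight = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
weight_fp8 = hp_tensor_to_float8_dynamic(
    weight,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    gemm_input_role=GemmInputRole.WEIGHT,
    round_scales_to_power_of_2=config.round_scales_to_power_of_2,
)
# weight_fp8._scale is now an exact power of 2 rather than an arbitrary float32 value.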
6 changes: 5 additions & 1 deletion test/float8/test_base.py
@@ -164,7 +164,10 @@ def test_transpose(self):

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("axiswise_dim", [0, -1])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
def test_axiswise_dynamic_cast(
self, shape, axiswise_dim, round_scales_to_power_of_2
):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
@@ -173,6 +176,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
round_scales_to_power_of_2=round_scales_to_power_of_2,
)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
20 changes: 14 additions & 6 deletions test/float8/test_compile.py
@@ -45,11 +45,7 @@
hp_tensor_to_float8_delayed,
hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
from torchao.float8.float8_utils import config_has_stateful_scaling
from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -420,13 +416,23 @@ def test_sync_amax_func_cuda_graph_success():
torch.float16,
],
)
def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
@pytest.mark.parametrize(
"round_scales_to_power_of_2",
[
True,
False,
],
)
def test_dynamic_scale_numeric_parity(
dtype: torch.dtype, round_scales_to_power_of_2: bool
):
scaling_type_weight = ScalingType.DYNAMIC
torch.manual_seed(42)
hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
hp_tensor2 = hp_tensor1.detach().clone()
float8_config = Float8LinearConfig(
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
round_scales_to_power_of_2=round_scales_to_power_of_2,
)
linear_mm_config = LinearMMConfig(
# output
@@ -456,13 +462,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
torch._dynamo.reset()
float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
hp_tensor2,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
assert torch.equal(float8_eager._scale, float8_compile._scale)
assert torch.equal(float8_eager._data, float8_compile._data)
65 changes: 65 additions & 0 deletions test/float8/test_float8_utils.py
@@ -0,0 +1,65 @@
import unittest

import pytest
import torch

from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


# source for notable single-precision cases:
# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@pytest.mark.parametrize(
"test_case",
[
# ("test_case_name", input, expected result)
("one", 1.0, 1.0),
("inf", float("inf"), float("inf")),
("nan", float("nan"), float("nan")),
("smallest positive subnormal number", 2**-126 * 2**-23, 2**-126 * 2**-23),
("largest normal number", 2**127 * (2 - 2**-23), float("inf")),
("smallest positive normal number", 2**-126, 2**-126),
("largest number less than one", 1.0 - 2**-24, 0.5),
("smallest number larger than one", 1.0 + 2**-23, 1.0),
# TODO(danielvegamyhre): debug why creating a tensor with largest
# subnormal value in CI env for pytorch 2.5.1 truncates the value to 0.
# ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
],
)
def test_round_scale_down_to_power_of_2_valid_inputs(
test_case: dict,
):
test_case_name, input, expected_result = test_case
input_tensor, expected_tensor = (
torch.tensor(input, dtype=torch.float32).cuda(),
torch.tensor(expected_result, dtype=torch.float32).cuda(),
)
result = _round_scale_down_to_power_of_2(input_tensor)

assert (
torch.equal(result, expected_tensor)
or (result.isnan() and expected_tensor.isnan())
), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"


@pytest.mark.parametrize(
"invalid_dtype",
[
torch.bfloat16,
torch.float16,
torch.float64,
torch.int8,
torch.uint8,
torch.int32,
torch.uint32,
torch.int64,
],
)
def test_non_float32_input(invalid_dtype: torch.dtype):
non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
with pytest.raises(AssertionError, match="scale must be float32 tensor"):
_round_scale_down_to_power_of_2(non_float32_tensor)
21 changes: 18 additions & 3 deletions torchao/float8/config.py
@@ -234,6 +234,13 @@ class Float8LinearConfig:
# tests so that the warning does not spam the CI stdout.
force_recompute_fp8_weight_in_bwd: bool = False

# If this option is enabled, the scaling factor used for float8 quantization
# will be rounded down to the nearest power of 2. This has been shown to help
# reduce quantization error by avoiding rounding errors when multiplying/dividing
# by the scaling factor, as well as ensuring large values are quantized to the
# same value in the forward pass as in the backward pass.
round_scales_to_power_of_2: bool = False

def __post_init__(self):
# Populate the additional cast overrides, if the user did not specify them
# Note: this hacks around the frozen-ness of this dataclass
@@ -338,14 +345,22 @@ def recipe_name_to_linear_config(

elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
# dynamic axiswise scaling with the CUTLASS rowwise kernel
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_i = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_w = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)

return Float8LinearConfig(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
# enable power of 2 scaling factors by default for row-wise scaling
round_scales_to_power_of_2=True,
)

elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
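
Side note (not from the commit): after this change the row-wise recipe enables power-of-2 scales by default, which can be confirmed with a quick check, assuming the public names shown in the config.py hunk above.

from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
assert config.round_scales_to_power_of_2  # enabled by default for row-wise scaling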
6 changes: 6 additions & 0 deletions torchao/float8/float8_linear.py
@@ -96,6 +96,7 @@ def forward(
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_input.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

if tensor_already_casted_to_fp8(weight_hp_t):
@@ -112,6 +113,7 @@ def forward(
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_weight.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

# the reshapes are needed in order to make the shapes compatible with
@@ -151,6 +153,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_grad_output.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

if tensor_already_casted_to_fp8(weight_hp_t):
@@ -181,6 +184,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_weight_for_grad_input.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

grad_input = torch.mm(
@@ -216,6 +220,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_grad_output_for_grad_weight.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

if tensor_already_casted_to_fp8(input_hp_reshaped):
@@ -233,6 +238,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_input_for_grad_weight.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

grad_weight = torch.mm(
4 changes: 4 additions & 0 deletions torchao/float8/float8_scaling_utils.py
@@ -27,6 +27,7 @@
)


# TODO(danielvegamyhre): refactor to accept Float8LinearConfig directly
def hp_tensor_to_float8_dynamic(
hp_tensor: torch.Tensor,
float8_dtype: torch.dtype,
@@ -36,6 +37,7 @@ def hp_tensor_to_float8_dynamic(
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
round_scales_to_power_of_2: bool = False,
) -> Float8Tensor:
"""
Given a high precision tensor `hp_tensor`,
@@ -51,6 +53,7 @@ def hp_tensor_to_float8_dynamic(
the 3 fwd/bwd gemms of linear
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
"""
scale = tensor_to_scale(
hp_tensor,
@@ -59,6 +62,7 @@ def hp_tensor_to_float8_dynamic(
device_mesh,
scaling_granularity,
axiswise_dim,
round_scales_to_power_of_2,
)
return hp_tensor_and_scale_to_float8(
hp_tensor,
44 changes: 33 additions & 11 deletions torchao/float8/float8_utils.py
@@ -10,11 +10,7 @@
import torch.distributed as dist
from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce

from torchao.float8.config import (
Float8LinearConfig,
ScalingGranularity,
ScalingType,
)
from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType

# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -33,21 +29,28 @@


@torch.no_grad()
def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
def amax_to_scale(
amax: torch.Tensor,
float8_dtype: torch.dtype,
round_scales_to_power_of_2: bool = False,
):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: The float8 dtype.
round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
"""
# torch.compile and eager show different numerics for 1.0 / float32,
# upcast to float64 to ensure same numeric between compile and eager
amax = amax.to(torch.float64)
if float8_dtype in FP8_TYPES:
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
res = res.to(torch.float32)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

return res.to(torch.float32)
if round_scales_to_power_of_2:
res = _round_scale_down_to_power_of_2(res)
return res
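
For intuition (not part of the commit), a worked example of the scale computation above; the EPS clamp is omitted and the amax value is illustrative.

import torch

amax = torch.tensor(3.0, dtype=torch.float64)      # max(abs(x)) of some tensor
e4m3_max = torch.finfo(torch.float8_e4m3fn).max    # 448.0
scale = (e4m3_max / amax).to(torch.float32)        # ~149.33
# With round_scales_to_power_of_2=True this would be rounded down to 128.0.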


@torch.no_grad()
@@ -119,21 +122,35 @@ def tensor_to_amax(

@torch.no_grad()
def tensor_to_scale(
x: torch.Tensor,
hp_tensor: torch.Tensor,
float8_dtype: torch.dtype,
reduce_amax: bool = False,
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
round_scales_to_power_of_2: bool = False,
) -> torch.Tensor:
"""
Compute scaling factor for the given high precision tensor.
Args:
hp_tensor: high precision tensor
float8_dtype: the float8 dtype to use
reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
"""
amax = tensor_to_amax(
x,
hp_tensor,
reduce_amax,
device_mesh,
scaling_granularity,
axiswise_dim,
)
return amax_to_scale(amax, float8_dtype)
return amax_to_scale(
amax, float8_dtype, round_scales_to_power_of_2=round_scales_to_power_of_2
)


def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
@@ -266,3 +283,8 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC
or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
)


def _round_scale_down_to_power_of_2(scale: torch.Tensor):
assert scale.dtype == torch.float32, "scale must be float32 tensor"
return torch.exp2(torch.floor(torch.log2(scale)))
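
A standalone illustration of the helper above (not from the commit; values chosen arbitrarily): rounding a float32 scale down to a power of 2 clears its mantissa bits, so multiplying or dividing by the scale only shifts the exponent.

import torch

scales = torch.tensor([0.3, 1.7, 100.0], dtype=torch.float32)
print(torch.exp2(torch.floor(torch.log2(scales))))  # tensor([ 0.2500,  1.0000, 64.0000])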
