float8 training: make the "config from recipe" API polished
Summary:

This PR polishes the API that takes a recipe name (enum or string) and
returns a `Float8LinearConfig` instance, making it ready for use in the
README.md docs and by partner callsites such as torchtitan and
torchtune.
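
For reference, here is a minimal usage sketch of the new entry point,
`Float8LinearConfig.from_recipe_name`, which replaces the module-level
`recipe_name_to_linear_config` helper (see the `torchao/float8/config.py`
diff below). The toy model and the CUDA/bfloat16 setup are illustrative
assumptions, not part of this PR:

```
import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
from torchao.float8.float8_linear_utils import convert_to_float8_training

# the recipe can be passed as a string ...
config = Float8LinearConfig.from_recipe_name("rowwise")
# ... or as the corresponding enum value
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)

# toy model, assumed here only for illustration; float8 training expects
# a CUDA device and typically bfloat16 weights/activations
m = nn.Sequential(nn.Linear(32, 32, bias=False)).to("cuda", torch.bfloat16)
m = convert_to_float8_training(m, config=config)
```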

Test Plan:

```
./test/float8/test_everything.sh
```

ghstack-source-id: 4f72eeb19603d6e1203fa9bf6ce8235bf431ecad
ghstack-comment-id: 2667010633
Pull Request resolved: #1731
vkuzo committed Feb 18, 2025
1 parent 169f112 commit e1ecae8
Showing 7 changed files with 94 additions and 95 deletions.
5 changes: 2 additions & 3 deletions benchmarks/float8/float8_roofline.py
@@ -63,7 +63,6 @@
ScalingType,
convert_to_float8_training,
)
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
from torchao.float8.roofline_utils import (
get_float8_mem_sympy,
get_gemm_time_sympy,
@@ -349,7 +348,7 @@ def run(

# get the float8 dynamic axiswise scaling gpu kernel time
torch._dynamo.reset()
config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
config = Float8LinearConfig.from_recipe_name("rowwise")
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
@@ -358,7 +357,7 @@
# TODO(future PR): enable below once basic performance issues
# are fixed
# torch._dynamo.reset()
# config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)
# config = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp")
# m_fp8_lw = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw = torch.compile(m_fp8_lw)
# fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
6 changes: 2 additions & 4 deletions benchmarks/float8/profile_linear_float8.py
@@ -39,9 +39,8 @@

from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
from torchao.float8.config import (
Float8LinearRecipeName,
Float8LinearConfig,
ScalingType,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
@@ -311,8 +310,7 @@ def main(
emulate=False,
)
elif recipe_name is not None:
recipe_name = Float8LinearRecipeName(recipe_name)
config = recipe_name_to_linear_config(recipe_name)
config = Float8LinearConfig.from_recipe_name(recipe_name)

scaling_repr = "_".join(
[
3 changes: 1 addition & 2 deletions test/float8/test_base.py
@@ -32,7 +32,6 @@
ScalingType,
e4m3_dtype,
e5m2_dtype,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
@@ -442,7 +441,7 @@ def test_linear_from_recipe(
linear_dtype = torch.bfloat16
x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
config = recipe_name_to_linear_config(recipe_name)
config = Float8LinearConfig.from_recipe_name(recipe_name)
self._test_linear_impl(
x,
m_ref,
3 changes: 1 addition & 2 deletions test/float8/test_compile.py
@@ -33,7 +33,6 @@
Float8LinearRecipeName,
ScalingType,
e4m3_dtype,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
@@ -227,7 +226,7 @@ def test_inductor_from_config_params(
)
def test_inductor_from_recipe(recipe_name):
torch._dynamo.reset()
config = recipe_name_to_linear_config(recipe_name)
config = Float8LinearConfig.from_recipe_name(recipe_name)
fullgraph = True
dtype = torch.bfloat16
_test_compile_base(
3 changes: 1 addition & 2 deletions test/float8/test_dtensor.py
@@ -41,7 +41,6 @@
Float8LinearRecipeName,
ScalingType,
e4m3_dtype,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
@@ -198,7 +197,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
device = mesh.device_type

if rowwise:
config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
# hack around config being frozen
# TODO(future PR): we should make this nicer at the config level
object.__setattr__(config, "emulate", True)
3 changes: 1 addition & 2 deletions test/float8/test_numerics_integration.py
@@ -28,7 +28,6 @@
Float8LinearConfig,
Float8LinearRecipeName,
ScalingType,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
@@ -210,7 +209,7 @@ def test_encoder_fw_bw_from_recipe(
self,
recipe_name: str,
):
config = recipe_name_to_linear_config(recipe_name)
config = Float8LinearConfig.from_recipe_name(recipe_name)
self._test_impl(config)


166 changes: 86 additions & 80 deletions torchao/float8/config.py
@@ -7,7 +7,7 @@
import enum
import logging
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import torch

@@ -146,6 +146,15 @@ class Float8GemmConfig:
use_fast_accum: bool = False


# Pre-made recipes for common configurations
# TODO(future PR): go through a round of design on this, and eventually expose
# as a top level public API.
class Float8LinearRecipeName(enum.Enum):
TENSORWISE = "tensorwise"
ROWWISE = "rowwise"
ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"


@dataclass(frozen=True)
class Float8LinearConfig:
"""
@@ -321,86 +330,83 @@ def __post_init__(self):
"Note: delayed and static scaling will be deprecated in a future release of torchao. Please see /~https://github.com/pytorch/ao/issues/1680 for more details."
)

@staticmethod
def from_recipe_name(
recipe_name: Union[Float8LinearRecipeName, str],
) -> "Float8LinearConfig":
"""
Input: `Float8LinearRecipeName` value, or a string representing a `Float8LinearRecipeName` value
Output: a `Float8LinearConfig` configured to implement the specified recipe
"""
if type(recipe_name) == str:
valid_names = [n.value for n in Float8LinearRecipeName]
assert (
recipe_name in valid_names
), f"recipe_name {recipe_name} not in valid names {valid_names}"
recipe_name = Float8LinearRecipeName(recipe_name)

if recipe_name is Float8LinearRecipeName.TENSORWISE:
# Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
return Float8LinearConfig()

elif recipe_name is Float8LinearRecipeName.ROWWISE:
# dynamic axiswise scaling with the CUTLASS rowwise kernel
cc_i = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_w = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)

# Pre-made recipes for common configurations
# TODO(future PR): go through a round of design on this, and eventually expose
# as a top level public API.
class Float8LinearRecipeName(enum.Enum):
TENSORWISE = "tensorwise"
ROWWISE = "rowwise"
ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
return Float8LinearConfig(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
# enable power of 2 scaling factors by default for row-wise scaling
round_scales_to_power_of_2=True,
)

elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
# lw's recipe for a modification on all-axiswise:
#
# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
# grad_weight_hp = input_t_hp @ grad_output_hp
#
# key characteristics:
# * increased accuracy for grad_weight
# * `input`, `weight` and `grad_output` now only need to be scaled
# axiswise across a single dim compared to vanilla all-axiswise,
# which is more amenable to fast kernels
# * the e4m3 dtype is used across the board, including for gradients

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

def recipe_name_to_linear_config(
recipe_name: Float8LinearRecipeName,
) -> Float8LinearConfig:
"""
Input: `Float8LinearRecipeName` value
Output: a `Float8LinearConfig` configured to implement the recipe
"""
# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(
scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
)

if recipe_name is Float8LinearRecipeName.TENSORWISE:
# Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
return Float8LinearConfig()

elif recipe_name is Float8LinearRecipeName.ROWWISE:
# dynamic axiswise scaling with the CUTLASS rowwise kernel
cc_i = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_w = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)

return Float8LinearConfig(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
# enable power of 2 scaling factors by default for row-wise scaling
round_scales_to_power_of_2=True,
)

elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
# lw's recipe for a modification on all-axiswise:
#
# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
# grad_weight_hp = input_t_hp @ grad_output_hp
#
# key characteristics:
# * increased accuracy for grad_weight
# * `input`, `weight` and `grad_output` now only need to be scaled
# axiswise across a single dim compared to vanilla all-axiswise,
# which is more amenable to fast kernels
# * the e4m3 dtype is used across the board, including for gradients

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(
scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
)

return Float8LinearConfig(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
cast_config_input_for_grad_weight=cc_i_gw,
cast_config_weight_for_grad_input=cc_w_gi,
cast_config_grad_output_for_grad_weight=cc_go_gw,
)

else:
raise AssertionError(f"unknown recipe_name {recipe_name}")
return Float8LinearConfig(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
cast_config_input_for_grad_weight=cc_i_gw,
cast_config_weight_for_grad_input=cc_w_gi,
cast_config_grad_output_for_grad_weight=cc_go_gw,
)

else:
raise AssertionError(f"unknown recipe_name {recipe_name}")
