diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py
index 9bd4206d76..684ed0af2a 100644
--- a/benchmarks/float8/float8_roofline.py
+++ b/benchmarks/float8/float8_roofline.py
@@ -63,7 +63,6 @@
     ScalingType,
     convert_to_float8_training,
 )
-from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
 from torchao.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -349,7 +348,7 @@ def run(
 
     # get the float8 dynamic axiswise scaling gpu kernel time
     torch._dynamo.reset()
-    config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
+    config = Float8LinearConfig.from_recipe_name("rowwise")
     m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
     m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
     fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
@@ -358,7 +357,7 @@ def run(
     # TODO(future PR): enable below once basic performance issues
     # are fixed
     # torch._dynamo.reset()
-    # config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)
+    # config = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp")
    # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
    # m_fp8_lw = torch.compile(m_fp8_lw)
    # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py
index 5045956954..687684d4e2 100644
--- a/benchmarks/float8/profile_linear_float8.py
+++ b/benchmarks/float8/profile_linear_float8.py
@@ -39,9 +39,8 @@
 from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
-    Float8LinearRecipeName,
+    Float8LinearConfig,
     ScalingType,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
 )
@@ -311,8 +310,7 @@ def main(
             emulate=False,
         )
     elif recipe_name is not None:
-        recipe_name = Float8LinearRecipeName(recipe_name)
-        config = recipe_name_to_linear_config(recipe_name)
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
 
     scaling_repr = "_".join(
         [
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 055b3f3054..156c8abe87 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -32,7 +32,6 @@
     ScalingType,
     e4m3_dtype,
     e5m2_dtype,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -442,7 +441,7 @@ def test_linear_from_recipe(
         linear_dtype = torch.bfloat16
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
-        config = recipe_name_to_linear_config(recipe_name)
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
         self._test_linear_impl(
             x,
             m_ref,
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index 83ec188192..0c02db26a6 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -33,7 +33,6 @@
     Float8LinearRecipeName,
     ScalingType,
     e4m3_dtype,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -227,7 +226,7 @@ def test_inductor_from_config_params(
 )
 def test_inductor_from_recipe(recipe_name):
     torch._dynamo.reset()
-    config = recipe_name_to_linear_config(recipe_name)
+    config = Float8LinearConfig.from_recipe_name(recipe_name)
     fullgraph = True
     dtype = torch.bfloat16
     _test_compile_base(
diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index d71e23b6b2..886cc2a504 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -41,7 +41,6 @@
     Float8LinearRecipeName,
     ScalingType,
     e4m3_dtype,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear_utils import convert_to_float8_training
 from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
@@ -198,7 +197,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
     device = mesh.device_type
 
     if rowwise:
-        config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
+        config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
         # hack around config being frozen
         # TODO(future PR): we should make this nicer at the config level
         object.__setattr__(config, "emulate", True)
diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py
index e47d4310b4..01e4cbb20d 100644
--- a/test/float8/test_numerics_integration.py
+++ b/test/float8/test_numerics_integration.py
@@ -28,7 +28,6 @@
     Float8LinearConfig,
     Float8LinearRecipeName,
     ScalingType,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
@@ -210,7 +209,7 @@ def test_encoder_fw_bw_from_recipe(
         self,
         recipe_name: str,
     ):
-        config = recipe_name_to_linear_config(recipe_name)
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
         self._test_impl(config)
 
 
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index c1720ea70c..f2be4849a8 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -7,7 +7,7 @@
 import enum
 import logging
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
@@ -146,6 +146,15 @@ class Float8GemmConfig:
     use_fast_accum: bool = False
 
 
+# Pre-made recipes for common configurations
+# TODO(future PR): go through a round of design on this, and eventually expose
+# as a top level public API.
+class Float8LinearRecipeName(enum.Enum):
+    TENSORWISE = "tensorwise"
+    ROWWISE = "rowwise"
+    ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -321,86 +330,83 @@ def __post_init__(self):
                 "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see /~https://github.com/pytorch/ao/issues/1680 for more details."
             )
 
+    @staticmethod
+    def from_recipe_name(
+        recipe_name: Union[Float8LinearRecipeName, str],
+    ) -> "Float8LinearConfig":
+        """
+        Input: `Float8LinearRecipeName` value, or a string representing a `Float8LinearRecipeName` value
+        Output: a `Float8LinearConfig` configured to implement the specified recipe
+        """
+        if type(recipe_name) == str:
+            valid_names = [n.value for n in Float8LinearRecipeName]
+            assert (
+                recipe_name in valid_names
+            ), f"recipe_name {recipe_name} not in valid names {valid_names}"
+            recipe_name = Float8LinearRecipeName(recipe_name)
+
+        if recipe_name is Float8LinearRecipeName.TENSORWISE:
+            # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
+            return Float8LinearConfig()
+
+        elif recipe_name is Float8LinearRecipeName.ROWWISE:
+            # dynamic axiswise scaling with the CUTLASS rowwise kernel
+            cc_i = CastConfig(
+                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+            )
+            cc_w = CastConfig(
+                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+            )
+            cc_go = CastConfig(
+                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+            )
 
-# Pre-made recipes for common configurations
-# TODO(future PR): go through a round of design on this, and eventually expose
-# as a top level public API.
-class Float8LinearRecipeName(enum.Enum):
-    TENSORWISE = "tensorwise"
-    ROWWISE = "rowwise"
-    ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
+            return Float8LinearConfig(
+                cast_config_input=cc_i,
+                cast_config_weight=cc_w,
+                cast_config_grad_output=cc_go,
+                # enable power of 2 scaling factors by default for row-wise scaling
+                round_scales_to_power_of_2=True,
+            )
 
+        elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
+            # lw's recipe for a modification on all-axiswise:
+            #
+            #   output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
+            #   grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
+            #   grad_weight_hp = input_t_hp @ grad_output_hp
+            #
+            # key characteristics:
+            #   * increased accuracy for grad_weight
+            #   * `input`, `weight` and `grad_output` now only need to be scaled
+            #     axiswise across a single dim compared to vanilla all-axiswise,
+            #     which is more amenable to fast kernels
+            #   * the e4m3 dtype is used across the board, including for gradients
+
+            # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
+            cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+            cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+
+            # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
+            cc_go = CastConfig(
+                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+            )
+            cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
 
-def recipe_name_to_linear_config(
-    recipe_name: Float8LinearRecipeName,
-) -> Float8LinearConfig:
-    """
-    Input: `Float8LinearRecipeName` value
-    Output: a `Float8LinearConfig` configured to implement the recipe
-    """
+            # grad_weight_hp = input_t_hp @ grad_output_hp
+            cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+            cc_go_gw = CastConfig(
+                scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
+            )
 
-    if recipe_name is Float8LinearRecipeName.TENSORWISE:
-        # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
-        return Float8LinearConfig()
-
-    elif recipe_name is Float8LinearRecipeName.ROWWISE:
-        # dynamic axiswise scaling with the CUTLASS rowwise kernel
-        cc_i = CastConfig(
-            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
-        )
-        cc_w = CastConfig(
-            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
-        )
-        cc_go = CastConfig(
-            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
-        )
-
-        return Float8LinearConfig(
-            cast_config_input=cc_i,
-            cast_config_weight=cc_w,
-            cast_config_grad_output=cc_go,
-            # enable power of 2 scaling factors by default for row-wise scaling
-            round_scales_to_power_of_2=True,
-        )
-
-    elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
-        # lw's recipe for a modification on all-axiswise:
-        #
-        #   output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
-        #   grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        #   grad_weight_hp = input_t_hp @ grad_output_hp
-        #
-        # key characteristics:
-        #   * increased accuracy for grad_weight
-        #   * `input`, `weight` and `grad_output` now only need to be scaled
-        #     axiswise across a single dim compared to vanilla all-axiswise,
-        #     which is more amenable to fast kernels
-        #   * the e4m3 dtype is used across the board, including for gradients
-
-        # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-
-        # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(
-            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
-        )
-        cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
-
-        # grad_weight_hp = input_t_hp @ grad_output_hp
-        cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(
-            scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
-        )
-
-        return Float8LinearConfig(
-            cast_config_input=cc_i,
-            cast_config_weight=cc_w,
-            cast_config_grad_output=cc_go,
-            cast_config_input_for_grad_weight=cc_i_gw,
-            cast_config_weight_for_grad_input=cc_w_gi,
-            cast_config_grad_output_for_grad_weight=cc_go_gw,
-        )
-
-    else:
-        raise AssertionError(f"unknown recipe_name {recipe_name}")
+            return Float8LinearConfig(
+                cast_config_input=cc_i,
+                cast_config_weight=cc_w,
+                cast_config_grad_output=cc_go,
+                cast_config_input_for_grad_weight=cc_i_gw,
+                cast_config_weight_for_grad_input=cc_w_gi,
+                cast_config_grad_output_for_grad_weight=cc_go_gw,
+            )
+
+        else:
+            raise AssertionError(f"unknown recipe_name {recipe_name}")
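
Usage sketch: with this change, callers build a config directly via Float8LinearConfig.from_recipe_name(...) (either the enum value or its string name) instead of recipe_name_to_linear_config(...). A minimal example, assuming a float8-capable GPU; the toy module below is illustrative only:

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

# build a config from a pre-made recipe name ("tensorwise", "rowwise", "rowwise_with_gw_hp")
config = Float8LinearConfig.from_recipe_name("rowwise")

# swap nn.Linear modules for float8 training modules using the recipe's cast configs
m = nn.Sequential(nn.Linear(64, 64, bias=False)).to("cuda", torch.bfloat16)
m_fp8 = convert_to_float8_training(m, config=config)

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
y = m_fp8(x)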