diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py
index 2b3f631d8c..9bd4206d76 100644
--- a/benchmarks/float8/float8_roofline.py
+++ b/benchmarks/float8/float8_roofline.py
@@ -349,7 +349,7 @@ def run(
 
             # get the float8 dynamic axiswise scaling gpu kernel time
             torch._dynamo.reset()
-            config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+            config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
             m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
             m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
             fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
@@ -358,7 +358,7 @@ def run(
             # TODO(future PR): enable below once basic performance issues
             # are fixed
             # torch._dynamo.reset()
-            # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+            # config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)
             # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
             # m_fp8_lw = torch.compile(m_fp8_lw)
             # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index b537c7ab9f..055b3f3054 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -420,8 +420,8 @@ def test_linear_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index d9c71f7395..83ec188192 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -218,8 +218,8 @@ def test_inductor_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @unittest.skipIf(
diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index d0f34da0a9..d71e23b6b2 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
     device = mesh.device_type
 
     if rowwise:
-        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
         # hack around config being frozen
         # TODO(future PR): we should make this nicer at the config level
         object.__setattr__(config, "emulate", True)
diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py
index 311964d831..e47d4310b4 100644
--- a/test/float8/test_numerics_integration.py
+++ b/test/float8/test_numerics_integration.py
@@ -198,8 +198,8 @@ def test_encoder_fw_bw_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @pytest.mark.skipif(
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index b971ff31b0..c1720ea70c 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -326,9 +326,9 @@ def __post_init__(self):
 # TODO(future PR): go through a round of design on this, and eventually expose
 # as a top level public API.
 class Float8LinearRecipeName(enum.Enum):
-    ALL_TENSORWISE = "all_tensorwise"
-    ALL_AXISWISE = "all_axiswise"
-    LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp"
+    TENSORWISE = "tensorwise"
+    ROWWISE = "rowwise"
+    ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
 
 
 def recipe_name_to_linear_config(
@@ -339,11 +339,11 @@ def recipe_name_to_linear_config(
     Output: a `Float8LinearConfig` configured to implement the recipe
     """
 
-    if recipe_name is Float8LinearRecipeName.ALL_TENSORWISE:
+    if recipe_name is Float8LinearRecipeName.TENSORWISE:
         # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
         return Float8LinearConfig()
 
-    elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
+    elif recipe_name is Float8LinearRecipeName.ROWWISE:
         # dynamic axiswise scaling with the CUTLASS rowwise kernel
         cc_i = CastConfig(
             scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
@@ -363,7 +363,7 @@ def recipe_name_to_linear_config(
             round_scales_to_power_of_2=True,
         )
 
-    elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
+    elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
         # lw's recipe for a modification on all-axiswise:
         #
         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1