float8 training: clean up recipe names
Summary:

Originally these recipe names were added with the intention of polishing
the API later. Later is now: this PR cleans up the names to make them
easier to understand. Specifically, `ALL_TENSORWISE` becomes `TENSORWISE`,
`ALL_AXISWISE` becomes `ROWWISE`, and `LW_AXISWISE_WITH_GW_HP` becomes
`ROWWISE_WITH_GW_HP`.
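
For illustration, a minimal usage sketch of the renamed recipe API, mirroring the
conversion pattern in the roofline benchmark below. The `torchao.float8` import path
for `convert_to_float8_training` and the toy model are assumptions, not part of this diff:

```python
import copy

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

# toy model, standing in for a real training model
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256)).cuda().bfloat16()

# build a config from one of the renamed recipes, swap the nn.Linear modules
# for float8 training, and compile
config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
model_fp8 = convert_to_float8_training(copy.deepcopy(model), config=config)
model_fp8 = torch.compile(model_fp8)
```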

Test Plan:

```
./test/float8/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3b94ff31a06b0ed554e4cdf764be424fd2d3a3cf
ghstack-comment-id: 2666973711
Pull Request resolved: #1730
vkuzo committed Feb 18, 2025
1 parent 988c5c9 commit 169f112
Showing 6 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions benchmarks/float8/float8_roofline.py
```diff
@@ -349,7 +349,7 @@ def run(
 
     # get the float8 dynamic axiswise scaling gpu kernel time
     torch._dynamo.reset()
-    config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+    config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
     m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
     m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
     fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
@@ -358,7 +358,7 @@ def run(
     # TODO(future PR): enable below once basic performance issues
     # are fixed
     # torch._dynamo.reset()
-    # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+    # config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)
     # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
     # m_fp8_lw = torch.compile(m_fp8_lw)
     # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
```
4 changes: 2 additions & 2 deletions test/float8/test_base.py
```diff
@@ -420,8 +420,8 @@ def test_linear_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
```
4 changes: 2 additions & 2 deletions test/float8/test_compile.py
```diff
@@ -218,8 +218,8 @@ def test_inductor_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @unittest.skipIf(
```
2 changes: 1 addition & 1 deletion test/float8/test_dtensor.py
```diff
@@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
     device = mesh.device_type
 
     if rowwise:
-        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
         # hack around config being frozen
         # TODO(future PR): we should make this nicer at the config level
         object.__setattr__(config, "emulate", True)
```
4 changes: 2 additions & 2 deletions test/float8/test_numerics_integration.py
```diff
@@ -198,8 +198,8 @@ def test_encoder_fw_bw_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @pytest.mark.skipif(
```
12 changes: 6 additions & 6 deletions torchao/float8/config.py
```diff
@@ -326,9 +326,9 @@ def __post_init__(self):
 # TODO(future PR): go through a round of design on this, and eventually expose
 # as a top level public API.
 class Float8LinearRecipeName(enum.Enum):
-    ALL_TENSORWISE = "all_tensorwise"
-    ALL_AXISWISE = "all_axiswise"
-    LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp"
+    TENSORWISE = "tensorwise"
+    ROWWISE = "rowwise"
+    ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
 
 
 def recipe_name_to_linear_config(
@@ -339,11 +339,11 @@ def recipe_name_to_linear_config(
     Output: a `Float8LinearConfig` configured to implement the recipe
     """
 
-    if recipe_name is Float8LinearRecipeName.ALL_TENSORWISE:
+    if recipe_name is Float8LinearRecipeName.TENSORWISE:
         # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
         return Float8LinearConfig()
 
-    elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
+    elif recipe_name is Float8LinearRecipeName.ROWWISE:
         # dynamic axiswise scaling with the CUTLASS rowwise kernel
         cc_i = CastConfig(
             scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
@@ -363,7 +363,7 @@ def recipe_name_to_linear_config(
             round_scales_to_power_of_2=True,
         )
 
-    elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
+    elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
         # lw's recipe for a modification on all-axiswise:
         #
         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
```
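
For reference, a short sketch enumerating the three renamed recipes via the helper above.
The comments paraphrase the code comments visible in this diff; the description of the
`gw_hp` variant is inferred from its name rather than stated in the hunk:

```python
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

# default: dynamic per-tensor scaling with the cuBLAS tensorwise kernel
tensorwise = recipe_name_to_linear_config(Float8LinearRecipeName.TENSORWISE)

# dynamic axiswise (rowwise) scaling with the CUTLASS rowwise kernel
rowwise = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)

# rowwise variant that keeps part of the backward in high precision ("gw_hp")
rowwise_gw_hp = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)
```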
