float8 training: clean up recipe names #1730

Merged · 1 commit · Feb 19, 2025
4 changes: 2 additions & 2 deletions benchmarks/float8/float8_roofline.py
@@ -349,7 +349,7 @@ def run(

# get the float8 dynamic axiswise scaling gpu kernel time
torch._dynamo.reset()
-config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
@@ -358,7 +358,7 @@ def run(
# TODO(future PR): enable below once basic performance issues
# are fixed
# torch._dynamo.reset()
-# config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+# config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)
# m_fp8_lw = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw = torch.compile(m_fp8_lw)
# fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
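For readers following along, here is a minimal, self-contained sketch of the pattern the benchmark above uses with the renamed `ROWWISE` recipe: build a config from the recipe name, swap the model's `nn.Linear` layers to float8, and compile. The toy model, tensor shapes, and import paths are illustrative assumptions (the enum and helper live in `torchao/float8/config.py` in this PR); the benchmark's own `get_gpu_kernel_time` helper is not reproduced here.

```python
import copy

import torch
import torch.nn as nn

# assumed import locations, matching the file paths touched in this PR
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# toy model standing in for the benchmark's m_orig; float8 training targets nn.Linear
m_orig = nn.Sequential(
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
).cuda().bfloat16()
x = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)

# dynamic rowwise (axiswise) scaling, previously named ALL_AXISWISE
config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)

# one forward/backward step; the real benchmark times the GPU kernels instead
m_fp8_dyn_axs(x).sum().backward()
```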
4 changes: 2 additions & 2 deletions test/float8/test_base.py
@@ -420,8 +420,8 @@ def test_linear_from_config_params(
@pytest.mark.parametrize(
"recipe_name",
[
-Float8LinearRecipeName.ALL_AXISWISE,
-Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+Float8LinearRecipeName.ROWWISE,
+Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
],
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
4 changes: 2 additions & 2 deletions test/float8/test_compile.py
@@ -218,8 +218,8 @@ def test_inductor_from_config_params(
@pytest.mark.parametrize(
"recipe_name",
[
-Float8LinearRecipeName.ALL_AXISWISE,
-Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+Float8LinearRecipeName.ROWWISE,
+Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
],
)
@unittest.skipIf(
2 changes: 1 addition & 1 deletion test/float8/test_dtensor.py
@@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
device = mesh.device_type

if rowwise:
-config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
# hack around config being frozen
# TODO(future PR): we should make this nicer at the config level
object.__setattr__(config, "emulate", True)
4 changes: 2 additions & 2 deletions test/float8/test_numerics_integration.py
@@ -198,8 +198,8 @@ def test_encoder_fw_bw_from_config_params(
@pytest.mark.parametrize(
"recipe_name",
[
-Float8LinearRecipeName.ALL_AXISWISE,
-Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+Float8LinearRecipeName.ROWWISE,
+Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
],
)
@pytest.mark.skipif(
12 changes: 6 additions & 6 deletions torchao/float8/config.py
@@ -326,9 +326,9 @@ def __post_init__(self):
# TODO(future PR): go through a round of design on this, and eventually expose
# as a top level public API.
class Float8LinearRecipeName(enum.Enum):
-ALL_TENSORWISE = "all_tensorwise"
-ALL_AXISWISE = "all_axiswise"
-LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp"
+TENSORWISE = "tensorwise"
+ROWWISE = "rowwise"
+ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"


def recipe_name_to_linear_config(
@@ -339,11 +339,11 @@ def recipe_name_to_linear_config(
Output: a `Float8LinearConfig` configured to implement the recipe
"""

-if recipe_name is Float8LinearRecipeName.ALL_TENSORWISE:
+if recipe_name is Float8LinearRecipeName.TENSORWISE:
# Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
return Float8LinearConfig()

-elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
+elif recipe_name is Float8LinearRecipeName.ROWWISE:
# dynamic axiswise scaling with the CUTLASS rowwise kernel
cc_i = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
@@ -363,7 +363,7 @@ def recipe_name_to_linear_config(
round_scales_to_power_of_2=True,
)

-elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
+elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
# lw's recipe for a modification on all-axiswise:
#
# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
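At the API surface, the rename boils down to new enum member names and string values. Below is a small sketch of what call sites that select recipes by string look like after this PR; the old values are noted only in comments, and the import path is assumed from the file touched above (`torchao/float8/config.py`).

```python
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

# new string values after this PR (old values shown in comments)
tensorwise = Float8LinearRecipeName("tensorwise")              # was "all_tensorwise"
rowwise = Float8LinearRecipeName("rowwise")                    # was "all_axiswise"
rowwise_gw_hp = Float8LinearRecipeName("rowwise_with_gw_hp")   # was "lw_axiswise_with_gw_hp"

# each recipe name maps to a ready-made Float8LinearConfig
config = recipe_name_to_linear_config(rowwise)

# the old strings are gone, so stale configs fail loudly (standard enum.Enum lookup)
try:
    Float8LinearRecipeName("all_axiswise")
except ValueError:
    print("update recipe strings to the new names")
```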