Update
[ghstack-poisoned]
vkuzo committed Feb 18, 2025
1 parent 988c5c9 commit 8da2519
Showing 6 changed files with 15 additions and 15 deletions.
4 changes: 2 additions & 2 deletions benchmarks/float8/float8_roofline.py
@@ -349,7 +349,7 @@ def run(
 
     # get the float8 dynamic axiswise scaling gpu kernel time
     torch._dynamo.reset()
-    config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+    config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
     m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
     m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
     fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
@@ -358,7 +358,7 @@ def run(
     # TODO(future PR): enable below once basic performance issues
     # are fixed
     # torch._dynamo.reset()
-    # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+    # config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP)
     # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
     # m_fp8_lw = torch.compile(m_fp8_lw)
     # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
4 changes: 2 additions & 2 deletions test/float8/test_base.py
@@ -420,8 +420,8 @@ def test_linear_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
4 changes: 2 additions & 2 deletions test/float8/test_compile.py
@@ -218,8 +218,8 @@ def test_inductor_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @unittest.skipIf(
2 changes: 1 addition & 1 deletion test/float8/test_dtensor.py
@@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
     device = mesh.device_type
 
     if rowwise:
-        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)
         # hack around config being frozen
         # TODO(future PR): we should make this nicer at the config level
         object.__setattr__(config, "emulate", True)
4 changes: 2 additions & 2 deletions test/float8/test_numerics_integration.py
@@ -198,8 +198,8 @@ def test_encoder_fw_bw_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ALL_AXISWISE,
-        Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
+        Float8LinearRecipeName.ROWWISE,
+        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
 @pytest.mark.skipif(
12 changes: 6 additions & 6 deletions torchao/float8/config.py
@@ -326,9 +326,9 @@ def __post_init__(self):
 # TODO(future PR): go through a round of design on this, and eventually expose
 # as a top level public API.
 class Float8LinearRecipeName(enum.Enum):
-    ALL_TENSORWISE = "all_tensorwise"
-    ALL_AXISWISE = "all_axiswise"
-    LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp"
+    TENSORWISE = "tensorwise"
+    ROWWISE = "rowwise"
+    ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
 
 
 def recipe_name_to_linear_config(
@@ -339,11 +339,11 @@ def recipe_name_to_linear_config(
     Output: a `Float8LinearConfig` configured to implement the recipe
     """
 
-    if recipe_name is Float8LinearRecipeName.ALL_TENSORWISE:
+    if recipe_name is Float8LinearRecipeName.TENSORWISE:
         # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
         return Float8LinearConfig()
 
-    elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
+    elif recipe_name is Float8LinearRecipeName.ROWWISE:
         # dynamic axiswise scaling with the CUTLASS rowwise kernel
         cc_i = CastConfig(
             scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
@@ -363,7 +363,7 @@
             round_scales_to_power_of_2=True,
         )
 
-    elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
+    elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
         # lw's recipe for a modification on all-axiswise:
        #
        # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
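
For reference, a minimal sketch of using the recipe API under its new names, assuming the import locations in this revision (Float8LinearRecipeName and recipe_name_to_linear_config in torchao.float8.config, convert_to_float8_training in torchao.float8.float8_linear_utils) and a GPU with float8 support; the toy model and shapes are illustrative only, not part of this commit:

import copy

import torch
import torch.nn as nn

from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training

# illustrative model; convert_to_float8_training swaps its nn.Linear children
m = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.ReLU(),
    nn.Linear(4096, 1024),
).cuda().bfloat16()

# build the rowwise-scaling recipe config under its new name
# (previously Float8LinearRecipeName.ALL_AXISWISE)
config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE)

# same convert-then-compile pattern as in the roofline benchmark above
m_fp8 = convert_to_float8_training(copy.deepcopy(m), config=config)
m_fp8 = torch.compile(m_fp8)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
y = m_fp8(x)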
