diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index 8487096e6c..ddc717f953 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -65,6 +65,8 @@ for _ in range(10):
 
 ## float8 linear with delayed scaling
 
+:warning: We plan to deprecate delayed scaling in a future release; see /~https://github.com/pytorch/ao/issues/1680 for more details.
+
 This is theoretically the most performant recipe as it minimizes memory reads.
 
 ```python
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index c7f32cd3fa..fb306e0fb7 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -304,6 +304,16 @@ def __post_init__(self):
                 "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
             )
 
+        # Future deprecation warning for delayed scaling
+        if (
+            self.cast_config_input.scaling_type != ScalingType.DYNAMIC
+            or self.cast_config_weight.scaling_type != ScalingType.DYNAMIC
+            or self.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
+        ):
+            logger.warning(
+                "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see /~https://github.com/pytorch/ao/issues/1680 for more details."
+            )
+
 
 # Pre-made recipes for common configurations
 # TODO(future PR): go through a round of design on this, and eventually expose
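
For context, here is a minimal sketch of how the new warning would surface to users. Only `ScalingType.DYNAMIC` and the `cast_config_input` / `cast_config_weight` / `cast_config_grad_output` fields appear in the diff itself; the `Float8LinearConfig` and `CastConfig` constructors and `ScalingType.DELAYED` member are assumptions about the surrounding `torchao.float8.config` API, so the exact signatures may differ.

```python
# Sketch (not part of the patch): triggering the new deprecation warning.
# Assumes Float8LinearConfig / CastConfig / ScalingType.DELAYED from
# torchao.float8.config; exact constructor signatures are an assumption.
import logging

from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType

logging.basicConfig(level=logging.WARNING)

# Any of input/weight/grad_output using a non-DYNAMIC scaling type makes
# __post_init__ log the warning added in the diff above.
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
# Expected log line (wording from the diff):
#   Note: delayed and static scaling will be deprecated in a future release
#   of torchao. Please see /~https://github.com/pytorch/ao/issues/1680 ...
```

Because the check lives in `__post_init__`, the warning fires once at config construction time rather than on every forward pass, and a fully dynamic config stays silent.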