Commit 32a51ec: Support power of 2 scaling factors in float8 training and use e4m3 everywhere (#1670)
1 parent: bae41d1
Showing 7 changed files with 145 additions and 21 deletions.
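The commit title states what changed but not why. One common motivation for power-of-2 scaling factors (an assumption here, not spelled out in this commit) is that multiplying and dividing by a power of two only shifts the floating-point exponent, so the scale/unscale round trip is exact, which is not true for an arbitrary scale. A quick self-contained illustration in PyTorch:

import torch

x = torch.rand(1024, dtype=torch.float32)
pow2_scale = torch.tensor(2.0**-7)       # power of two: exponent shift only
arbitrary_scale = torch.tensor(0.0073)   # arbitrary scale: multiplication can round

print(torch.equal((x * pow2_scale) / pow2_scale, x))            # True: exact round trip
print(torch.equal((x * arbitrary_scale) / arbitrary_scale, x))  # typically False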
@@ -0,0 +1,65 @@
import unittest

import pytest
import torch

from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)


# source for notable single-precision cases:
# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@pytest.mark.parametrize(
    "test_case",
    [
        # ("test_case_name", input, expected result)
        ("one", 1.0, 1.0),
        ("inf", float("inf"), float("inf")),
        ("nan", float("nan"), float("nan")),
        ("smallest positive subnormal number", 2**-126 * 2**-23, 2**-126 * 2**-23),
        ("largest normal number", 2**127 * (2 - 2**-23), float("inf")),
        ("smallest positive normal number", 2**-126, 2**-126),
        ("largest number less than one", 1.0 - 2**-24, 0.5),
        ("smallest number larger than one", 1.0 + 2**-23, 1.0),
        # TODO(danielvegamyhre): debug why creating a tensor with largest
        # subnormal value in CI env for pytorch 2.5.1 truncates the value to 0.
        # ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
    ],
)
def test_round_scale_down_to_power_of_2_valid_inputs(
    test_case: dict,
):
    test_case_name, input, expected_result = test_case
    input_tensor, expected_tensor = (
        torch.tensor(input, dtype=torch.float32).cuda(),
        torch.tensor(expected_result, dtype=torch.float32).cuda(),
    )
    result = _round_scale_down_to_power_of_2(input_tensor)

    assert (
        torch.equal(result, expected_tensor)
        or (result.isnan() and expected_tensor.isnan())
    ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"


@pytest.mark.parametrize(
    "invalid_dtype",
    [
        torch.bfloat16,
        torch.float16,
        torch.float64,
        torch.int8,
        torch.uint8,
        torch.int32,
        torch.uint32,
        torch.int64,
    ],
)
def test_non_float32_input(invalid_dtype: torch.dtype):
    non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
    with pytest.raises(AssertionError, match="scale must be float32 tensor"):
        _round_scale_down_to_power_of_2(non_float32_tensor)
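The hunk above only adds tests, so the implementation of _round_scale_down_to_power_of_2 is not visible in this file. For orientation, here is a minimal sketch of one way a helper with this contract could behave; the exp2/floor/log2 approach is an assumption inferred from the test cases, not necessarily the actual torchao implementation:

import torch

def round_scale_down_to_power_of_2_sketch(scale: torch.Tensor) -> torch.Tensor:
    # Illustrative only: round each float32 scale down to the nearest power of two.
    # exp2(floor(log2(x))) leaves exact powers of two unchanged, maps 1 - 2**-24
    # to 0.5 and 1 + 2**-23 to 1.0, and propagates inf/nan, matching the
    # parametrized cases in the test above.
    assert scale.dtype == torch.float32, "scale must be float32 tensor"
    return torch.exp2(torch.floor(torch.log2(scale)))

With this sketch, the "largest normal number" case lands on inf because its log2, computed in float32, rounds up to exactly 128.0 before the floor is applied, and exp2(128) overflows float32.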