Skip to content

Commit

Permalink
add torch.nan_to_num and fix flaky torch.empty test
Browse files Browse the repository at this point in the history
  • Loading branch information
samuela committed Feb 15, 2025
1 parent 1c5cc83 commit 6fc563f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
27 changes: 24 additions & 3 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,30 @@ def test_arange():

def test_empty():
    """Check that torch.empty lowers to JAX despite returning uninitialized memory.

    torch.empty returns uninitialized values, so we multiply by 0 for
    deterministic, testable behavior. Uninitialized memory may contain NaNs
    (and 0 * nan == nan), so we sanitize with torch.nan_to_num first. See
    https://discuss.pytorch.org/t/torch-empty-returns-nan/181389 and
    /~https://github.com/samuela/torch2jax/actions/runs/13348964668/job/37282967463.
    """
    # Scalar, 1-D, and 2-D shapes; the second argument is the (empty) list of test inputs.
    t2j_function_test(lambda: 0 * torch.nan_to_num(torch.empty(())), [])
    t2j_function_test(lambda: 0 * torch.nan_to_num(torch.empty(2)), [])
    t2j_function_test(lambda: 0 * torch.nan_to_num(torch.empty((2, 3))), [])


def test_nan_to_num():
    """Check that the torch.nan_to_num shim matches PyTorch semantics.

    Covers each special float value (nan, +inf, -inf) with the default
    replacements, then all three at once with explicit nan/posinf/neginf
    overrides. Tensors are built inside the lambdas so the conversion
    machinery traces their construction; the second argument is the
    (empty) list of test inputs.
    """
    # Test handling of NaN values
    t2j_function_test(lambda: torch.nan_to_num(torch.tensor([float("nan"), 1.0, 2.0])), [])

    # Test handling of positive infinity
    t2j_function_test(lambda: torch.nan_to_num(torch.tensor([float("inf"), 1.0, 2.0])), [])

    # Test handling of negative infinity
    t2j_function_test(lambda: torch.nan_to_num(torch.tensor([float("-inf"), 1.0, 2.0])), [])

    # Test handling of all special values with custom replacements
    t2j_function_test(
        lambda: torch.nan_to_num(
            torch.tensor([float("nan"), float("inf"), float("-inf")]), nan=0.0, posinf=1.0, neginf=-1.0
        ),
        [],
    )


def test_ones():
Expand Down
1 change: 1 addition & 0 deletions torch2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def fn(*args, **kwargs):


# Register elementwise torch ops whose JAX counterparts share the same call
# signature (torch.nan_to_num and jnp.nan_to_num both take nan=/posinf=/neginf=).
# Kept alphabetical within the torch namespace for easy scanning; registration
# order carries no behavioral meaning.
auto_implements(torch.abs, jnp.abs)
auto_implements(torch.add, jnp.add)
auto_implements(torch.exp, jnp.exp)
auto_implements(torch.nan_to_num, jnp.nan_to_num)
auto_implements(torch.nn.functional.gelu, jax.nn.gelu)
Expand Down

0 comments on commit 6fc563f

Please sign in to comment.