Skip to content

Commit

Permalink
style(pre-commit): add pyupgrade
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Feb 4, 2025
1 parent 527f3b9 commit 51e6a79
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 58 deletions.
3 changes: 1 addition & 2 deletions docs/examples/custom_rules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"outputs": [],
"source": [
"import functools as ft\n",
"from typing import Union\n",
"\n",
"import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n",
"import jax\n",
Expand Down Expand Up @@ -72,7 +71,7 @@
"seconds = Dimension(\"s\")\n",
"\n",
"\n",
"def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]:\n",
"def _dim_to_unit(x: Dimension | dict[Dimension, int]) -> dict[Dimension, int]:\n",
" if isinstance(x, Dimension):\n",
" return {x: 1}\n",
" else:\n",
Expand Down
3 changes: 1 addition & 2 deletions docs/examples/default_rules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"source": [
"import functools as ft\n",
"from collections.abc import Sequence\n",
"from typing import Union\n",
"\n",
"import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n",
"import jax\n",
Expand Down Expand Up @@ -64,7 +63,7 @@
" @staticmethod\n",
" def default(\n",
" primitive: jax.extend.core.Primitive,\n",
" values: Sequence[Union[ArrayLike, quax.Value]],\n",
" values: Sequence[ArrayLike | quax.Value],\n",
" params: dict,\n",
" ):\n",
" raw_values: list[ArrayLike] = []\n",
Expand Down
8 changes: 2 additions & 6 deletions docs/examples/redispatch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Union\n",
"\n",
"import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -77,9 +75,7 @@
"\n",
"\n",
"@quax.register(jax.lax.dot_general_p)\n",
"def _(\n",
" x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params\n",
"):\n",
"def _(x: LoraArray, y: ArrayLike | quax.ArrayValue, *, dimension_numbers, **params):\n",
" ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers\n",
" if jnp.ndim(x) != 2 and jnp.ndim(y) != 1:\n",
" raise NotImplementedError(\n",
Expand Down Expand Up @@ -254,7 +250,7 @@
"source": [
"@quax.register(jax.lax.dot_general_p)\n",
"def _(\n",
" x: Union[ArrayLike, quax.ArrayValue],\n",
" x: ArrayLike | quax.ArrayValue,\n",
" y: SomeKindOfSparseVector,\n",
" *,\n",
" dimension_numbers,\n",
Expand Down
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ addopts = "--jaxtyping-packages=quax,beartype.beartype(conf=beartype.BeartypeCon
JAX_CHECK_TRACER_LEAKS = 1

[tool.ruff.lint]
select = ["E", "F", "I001"]
select = ["E", "F", "I001", "UP"]
ignore = ["E402", "E721", "E731", "E741", "F722"]
fixable = ["I001", "F401"]
fixable = ["I001", "F401", "UP"]

[tool.ruff.lint.flake8-import-conventions.extend-aliases]
"jax.extend" = "jex"
Expand All @@ -60,3 +60,8 @@ order-by-type = false
[tool.pyright]
reportIncompatibleMethodOverride = true
include = ["quax", "tests"]

[dependency-groups]
dev = [
"pre-commit>=4.1.0",
]
15 changes: 7 additions & 8 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import functools as ft
import itertools as it
from collections.abc import Callable, Sequence
from typing import Any, cast, Generic, TypeVar, Union
from typing_extensions import TypeGuard
from typing import Any, cast, Generic, TypeGuard, TypeVar, Union

import equinox as eqx
import jax
Expand Down Expand Up @@ -285,7 +284,7 @@ def _unwrap_tracer(trace, x):

class _Quaxify(eqx.Module, Generic[CT]):
fn: CT
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]]
filter_spec: PyTree[bool | Callable[[Any], bool]]
dynamic: bool = eqx.field(static=True)

@property
Expand All @@ -310,15 +309,15 @@ def __call__(self, *args, **kwargs):
out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out)
return out

def __get__(self, instance: Union[object, None], owner: Any):
def __get__(self, instance: object | None, owner: Any):
if instance is None:
return self
return eqx.Partial(self, instance)


def quaxify(
fn: CT,
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = True,
filter_spec: PyTree[bool | Callable[[Any], bool]] = True,
) -> _Quaxify[CT]:
"""'Quaxifies' a function, so that it understands custom array-ish objects like
[`quax.examples.lora.LoraArray`][]. When this function is called, multiple dispatch
Expand Down Expand Up @@ -522,7 +521,7 @@ def aval(self) -> core.ShapedArray:


@register(jax._src.pjit.pjit_p) # pyright: ignore
def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs):
def _(*args: ArrayLike | ArrayValue, jaxpr, inline, **kwargs):
del kwargs
fun = quaxify(jex.core.jaxpr_as_fun(jaxpr))
if inline:
Expand All @@ -535,7 +534,7 @@ def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs):

@register(jax.lax.while_p)
def _(
*args: Union[ArrayValue, ArrayLike],
*args: ArrayValue | ArrayLike,
cond_nconsts: int,
cond_jaxpr,
body_nconsts: int,
Expand Down Expand Up @@ -574,7 +573,7 @@ def _(
@register(jax.lax.cond_p)
def _(
index: ArrayLike,
*args: Union[ArrayValue, ArrayLike],
*args: ArrayValue | ArrayLike,
branches: tuple,
linear=_sentinel,
):
Expand Down
10 changes: 4 additions & 6 deletions quax/examples/lora/_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Union

import equinox as eqx
import jax.core
import jax.lax as lax
Expand Down Expand Up @@ -190,11 +188,11 @@ def _lora_array_matmul_impl(w, a, b, rhs, lhs_batch, ndim, dimension_numbers, kw
@quax.register(lax.dot_general_p)
def _lora_array_matmul(
lhs: LoraArray,
rhs: Union[ArrayLike, quax.ArrayValue],
rhs: ArrayLike | quax.ArrayValue,
*,
dimension_numbers,
**kwargs,
) -> Union[ArrayLike, quax.ArrayValue]:
) -> ArrayLike | quax.ArrayValue:
((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers
[ndim] = {lhs.a.ndim, lhs.b.ndim, lhs.w.ndim}
if lhs_contract == (ndim - 1,) and (ndim - 2 not in lhs_batch):
Expand Down Expand Up @@ -223,12 +221,12 @@ def _lora_array_matmul(

@quax.register(lax.dot_general_p)
def _(
lhs: Union[ArrayLike, quax.ArrayValue],
lhs: ArrayLike | quax.ArrayValue,
rhs: LoraArray,
*,
dimension_numbers,
**kwargs,
) -> Union[ArrayLike, quax.ArrayValue]:
) -> ArrayLike | quax.ArrayValue:
((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers
dimension_numbers_flipped = ((rhs_contract, lhs_contract), (rhs_batch, lhs_batch))
out = _lora_array_matmul(
Expand Down
8 changes: 4 additions & 4 deletions quax/examples/named/_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
from collections.abc import Callable
from typing import Any, Generic, Optional, TypeVar, Union
from typing import Any, Generic, TypeVar

import equinox as eqx
import jax.core
Expand All @@ -18,7 +18,7 @@ class Axis:
axis.
"""

size: Optional[int]
size: int | None


Axis.__init__.__doc__ = """**Arguments:**
Expand Down Expand Up @@ -109,14 +109,14 @@ def _(x: NamedArray, y: NamedArray) -> NamedArray:
return NamedArray(quax_op(x.array, y.array), axes)

@quax.register(prim)
def _(x: Union[ArrayLike, quax.ArrayValue], y: NamedArray) -> NamedArray:
def _(x: ArrayLike | quax.ArrayValue, y: NamedArray) -> NamedArray:
if quax.quaxify(jnp.shape)(x) == ():
return NamedArray(quax_op(x, y.array), y.axes)
else:
raise ValueError(f"Cannot apply {op} to non-scalar array and named array.")

@quax.register(prim)
def _(x: NamedArray, y: Union[ArrayLike, quax.ArrayValue]) -> NamedArray:
def _(x: NamedArray, y: ArrayLike | quax.ArrayValue) -> NamedArray:
if quax.quaxify(jnp.shape)(y) == ():
return NamedArray(quax_op(x.array, y), x.axes)
else:
Expand Down
4 changes: 2 additions & 2 deletions quax/examples/prng/_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import abc
import functools as ft
from collections.abc import Sequence
from typing import Any, TypeVar
from typing_extensions import Self, TYPE_CHECKING, TypeAlias
from typing import Any, TYPE_CHECKING, TypeAlias, TypeVar
from typing_extensions import Self

import equinox as eqx
import jax
Expand Down
4 changes: 1 addition & 3 deletions quax/examples/structured_matrices/_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Union

import equinox as eqx
import jax.core
import jax.lax as lax
Expand Down Expand Up @@ -77,7 +75,7 @@ def _tridiagonal_matvec(
@quax.register(lax.dot_general_p)
def _(
lhs: TridiagonalMatrix,
rhs: Union[ArrayLike, quax.ArrayValue],
rhs: ArrayLike | quax.ArrayValue,
*,
dimension_numbers,
**kwargs,
Expand Down
4 changes: 1 addition & 3 deletions quax/examples/unitful/_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Union

import equinox as eqx # /~https://github.com/patrick-kidger/equinox
import jax
import jax.core as core
Expand All @@ -22,7 +20,7 @@ def __repr__(self):
seconds = Dimension("s")


def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]:
def _dim_to_unit(x: Dimension | dict[Dimension, int]) -> dict[Dimension, int]:
if isinstance(x, Dimension):
return {x: 1}
else:
Expand Down
22 changes: 9 additions & 13 deletions quax/examples/zero/_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools as ft
from typing import Any, get_args, Union
from typing import Any, get_args

import equinox as eqx
import jax.core
Expand Down Expand Up @@ -42,7 +42,7 @@ def materialise(self):
@quax.register(lax.broadcast_in_dim_p)
def _(
value: ArrayLike, *, broadcast_dimensions, shape, sharding=None
) -> Union[ArrayLike, quax.ArrayValue]:
) -> ArrayLike | quax.ArrayValue:
# Avoid an infinite loop using ensure_compile_time_eval.
with jax.ensure_compile_time_eval():
out = lax.broadcast_in_dim_p.bind(
Expand Down Expand Up @@ -87,16 +87,12 @@ def _shape_dtype(x, y, value):


@quax.register(lax.add_p)
def _(
x: Union[ArrayLike, quax.ArrayValue], y: Zero
) -> Union[ArrayLike, quax.ArrayValue]:
def _(x: ArrayLike | quax.ArrayValue, y: Zero) -> ArrayLike | quax.ArrayValue:
return _shape_dtype(x, y, value=x)


@quax.register(lax.add_p)
def _(
x: Zero, y: Union[ArrayLike, quax.ArrayValue]
) -> Union[ArrayLike, quax.ArrayValue]:
def _(x: Zero, y: ArrayLike | quax.ArrayValue) -> ArrayLike | quax.ArrayValue:
return _shape_dtype(x, y, value=y)


Expand All @@ -106,12 +102,12 @@ def _(x: Zero, y: Zero) -> Zero:


@quax.register(lax.mul_p)
def _(x: Union[ArrayLike, quax.ArrayValue], y: Zero) -> Zero:
def _(x: ArrayLike | quax.ArrayValue, y: Zero) -> Zero:
return _shape_dtype(x, y, value=y)


@quax.register(lax.mul_p)
def _(x: Zero, y: Union[ArrayLike, quax.ArrayValue]) -> Zero:
def _(x: Zero, y: ArrayLike | quax.ArrayValue) -> Zero:
return _shape_dtype(x, y, value=x)


Expand Down Expand Up @@ -151,12 +147,12 @@ def _zero_matmul(lhs, rhs, kwargs) -> Zero:


@quax.register(lax.dot_general_p)
def _(lhs: Zero, rhs: Union[ArrayLike, quax.ArrayValue], **kwargs) -> Zero:
def _(lhs: Zero, rhs: ArrayLike | quax.ArrayValue, **kwargs) -> Zero:
return _zero_matmul(lhs, rhs, kwargs)


@quax.register(lax.dot_general_p)
def _(lhs: Union[ArrayLike, quax.ArrayValue], rhs: Zero, **kwargs) -> Zero:
def _(lhs: ArrayLike | quax.ArrayValue, rhs: Zero, **kwargs) -> Zero:
return _zero_matmul(lhs, rhs, kwargs)


Expand All @@ -166,7 +162,7 @@ def _(lhs: Zero, rhs: Zero, **kwargs) -> Zero:


@quax.register(lax.integer_pow_p)
def _integer_pow(x: Zero, *, y: int) -> Union[Array, Zero]:
def _integer_pow(x: Zero, *, y: int) -> Array | Zero:
# Zero is a special case, because 0^0 = 1.
if y == 0:
return jnp.ones(x.shape, x.dtype) # pyright: ignore
Expand Down
10 changes: 3 additions & 7 deletions tests/test_cond.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Union

import jax
import jax.numpy as jnp
import pytest
Expand All @@ -8,7 +6,7 @@
from quax.examples.unitful import kilograms, meters, Unitful


def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]):
def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: bool | jax.Array):
def _true_fn(a: jax.Array):
return a + b

Expand Down Expand Up @@ -43,9 +41,7 @@ def test_cond_different_units():


def test_cond_different_out_trees():
def _outer_fn(
a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]
):
def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: bool | jax.Array):
def _true_fn(a: jax.Array):
return a + b

Expand Down Expand Up @@ -130,7 +126,7 @@ def test_cond_grad_closure():
def outer_fn(
outer_var: jax.Array,
dummy: jax.Array,
pred: Union[bool, jax.Array],
pred: bool | jax.Array,
):
def _true_fn_grad(a: jax.Array):
return a + outer_var
Expand Down

0 comments on commit 51e6a79

Please sign in to comment.