From 51e6a790059dfbf464b1f771aa3442fa3becc246 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 28 Jan 2025 00:01:41 -0500 Subject: [PATCH] style(pre-commit): add pyupgrade Signed-off-by: Nathaniel Starkman --- docs/examples/custom_rules.ipynb | 3 +-- docs/examples/default_rules.ipynb | 3 +-- docs/examples/redispatch.ipynb | 8 ++------ pyproject.toml | 9 +++++++-- quax/_core.py | 15 +++++++-------- quax/examples/lora/_core.py | 10 ++++------ quax/examples/named/_core.py | 8 ++++---- quax/examples/prng/_core.py | 4 ++-- quax/examples/structured_matrices/_core.py | 4 +--- quax/examples/unitful/_core.py | 4 +--- quax/examples/zero/_core.py | 22 +++++++++------------- tests/test_cond.py | 10 +++------- 12 files changed, 42 insertions(+), 58 deletions(-) diff --git a/docs/examples/custom_rules.ipynb b/docs/examples/custom_rules.ipynb index f82eeab..6d7ab79 100644 --- a/docs/examples/custom_rules.ipynb +++ b/docs/examples/custom_rules.ipynb @@ -20,7 +20,6 @@ "outputs": [], "source": [ "import functools as ft\n", - "from typing import Union\n", "\n", "import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n", "import jax\n", @@ -72,7 +71,7 @@ "seconds = Dimension(\"s\")\n", "\n", "\n", - "def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]:\n", + "def _dim_to_unit(x: Dimension | dict[Dimension, int]) -> dict[Dimension, int]:\n", " if isinstance(x, Dimension):\n", " return {x: 1}\n", " else:\n", diff --git a/docs/examples/default_rules.ipynb b/docs/examples/default_rules.ipynb index 35806d5..c488f92 100644 --- a/docs/examples/default_rules.ipynb +++ b/docs/examples/default_rules.ipynb @@ -24,7 +24,6 @@ "source": [ "import functools as ft\n", "from collections.abc import Sequence\n", - "from typing import Union\n", "\n", "import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n", "import jax\n", @@ -64,7 +63,7 @@ " @staticmethod\n", " def default(\n", " primitive: jax.extend.core.Primitive,\n", - " values: Sequence[Union[ArrayLike, quax.Value]],\n", + " values: Sequence[ArrayLike | quax.Value],\n", " params: dict,\n", " ):\n", " raw_values: list[ArrayLike] = []\n", diff --git a/docs/examples/redispatch.ipynb b/docs/examples/redispatch.ipynb index 75bef2d..998083c 100644 --- a/docs/examples/redispatch.ipynb +++ b/docs/examples/redispatch.ipynb @@ -19,8 +19,6 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Union\n", - "\n", "import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n", "import jax\n", "import jax.numpy as jnp\n", @@ -77,9 +75,7 @@ "\n", "\n", "@quax.register(jax.lax.dot_general_p)\n", - "def _(\n", - " x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params\n", - "):\n", + "def _(x: LoraArray, y: ArrayLike | quax.ArrayValue, *, dimension_numbers, **params):\n", " ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers\n", " if jnp.ndim(x) != 2 and jnp.ndim(y) != 1:\n", " raise NotImplementedError(\n", @@ -254,7 +250,7 @@ "source": [ "@quax.register(jax.lax.dot_general_p)\n", "def _(\n", - " x: Union[ArrayLike, quax.ArrayValue],\n", + " x: ArrayLike | quax.ArrayValue,\n", " y: SomeKindOfSparseVector,\n", " *,\n", " dimension_numbers,\n", diff --git a/pyproject.toml b/pyproject.toml index 66cb424..46c41b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,9 +44,9 @@ addopts = "--jaxtyping-packages=quax,beartype.beartype(conf=beartype.BeartypeCon JAX_CHECK_TRACER_LEAKS = 1 [tool.ruff.lint] -select = ["E", "F", "I001"] +select = ["E", "F", "I001", "UP"] ignore = ["E402", "E721", "E731", "E741", "F722"] -fixable = ["I001", "F401"] +fixable = ["I001", "F401", "UP"] [tool.ruff.lint.flake8-import-conventions.extend-aliases] "jax.extend" = "jex" @@ -60,3 +60,8 @@ order-by-type = false [tool.pyright] reportIncompatibleMethodOverride = true include = ["quax", "tests"] + +[dependency-groups] +dev = [ + "pre-commit>=4.1.0", +] diff --git a/quax/_core.py b/quax/_core.py index 0416b71..af4c2a0 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -2,8 +2,7 @@ import functools as ft import itertools as it from collections.abc import Callable, Sequence -from typing import Any, cast, Generic, TypeVar, Union -from typing_extensions import TypeGuard +from typing import Any, cast, Generic, TypeGuard, TypeVar, Union import equinox as eqx import jax @@ -285,7 +284,7 @@ def _unwrap_tracer(trace, x): class _Quaxify(eqx.Module, Generic[CT]): fn: CT - filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] + filter_spec: PyTree[bool | Callable[[Any], bool]] dynamic: bool = eqx.field(static=True) @property @@ -310,7 +309,7 @@ def __call__(self, *args, **kwargs): out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out) return out - def __get__(self, instance: Union[object, None], owner: Any): + def __get__(self, instance: object | None, owner: Any): if instance is None: return self return eqx.Partial(self, instance) @@ -318,7 +317,7 @@ def __get__(self, instance: Union[object, None], owner: Any): def quaxify( fn: CT, - filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = True, + filter_spec: PyTree[bool | Callable[[Any], bool]] = True, ) -> _Quaxify[CT]: """'Quaxifies' a function, so that it understands custom array-ish objects like [`quax.examples.lora.LoraArray`][]. When this function is called, multiple dispatch @@ -522,7 +521,7 @@ def aval(self) -> core.ShapedArray: @register(jax._src.pjit.pjit_p) # pyright: ignore -def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs): +def _(*args: ArrayLike | ArrayValue, jaxpr, inline, **kwargs): del kwargs fun = quaxify(jex.core.jaxpr_as_fun(jaxpr)) if inline: @@ -535,7 +534,7 @@ def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs): @register(jax.lax.while_p) def _( - *args: Union[ArrayValue, ArrayLike], + *args: ArrayValue | ArrayLike, cond_nconsts: int, cond_jaxpr, body_nconsts: int, @@ -574,7 +573,7 @@ def _( @register(jax.lax.cond_p) def _( index: ArrayLike, - *args: Union[ArrayValue, ArrayLike], + *args: ArrayValue | ArrayLike, branches: tuple, linear=_sentinel, ): diff --git a/quax/examples/lora/_core.py b/quax/examples/lora/_core.py index bd01d85..54a3ad6 100644 --- a/quax/examples/lora/_core.py +++ b/quax/examples/lora/_core.py @@ -1,5 +1,3 @@ -from typing import Union - import equinox as eqx import jax.core import jax.lax as lax @@ -190,11 +188,11 @@ def _lora_array_matmul_impl(w, a, b, rhs, lhs_batch, ndim, dimension_numbers, kw @quax.register(lax.dot_general_p) def _lora_array_matmul( lhs: LoraArray, - rhs: Union[ArrayLike, quax.ArrayValue], + rhs: ArrayLike | quax.ArrayValue, *, dimension_numbers, **kwargs, -) -> Union[ArrayLike, quax.ArrayValue]: +) -> ArrayLike | quax.ArrayValue: ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers [ndim] = {lhs.a.ndim, lhs.b.ndim, lhs.w.ndim} if lhs_contract == (ndim - 1,) and (ndim - 2 not in lhs_batch): @@ -223,12 +221,12 @@ def _lora_array_matmul( @quax.register(lax.dot_general_p) def _( - lhs: Union[ArrayLike, quax.ArrayValue], + lhs: ArrayLike | quax.ArrayValue, rhs: LoraArray, *, dimension_numbers, **kwargs, -) -> Union[ArrayLike, quax.ArrayValue]: +) -> ArrayLike | quax.ArrayValue: ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers dimension_numbers_flipped = ((rhs_contract, lhs_contract), (rhs_batch, lhs_batch)) out = _lora_array_matmul( diff --git a/quax/examples/named/_core.py b/quax/examples/named/_core.py index 691ddd2..314382b 100644 --- a/quax/examples/named/_core.py +++ b/quax/examples/named/_core.py @@ -1,6 +1,6 @@ import dataclasses from collections.abc import Callable -from typing import Any, Generic, Optional, TypeVar, Union +from typing import Any, Generic, TypeVar import equinox as eqx import jax.core @@ -18,7 +18,7 @@ class Axis: axis. """ - size: Optional[int] + size: int | None Axis.__init__.__doc__ = """**Arguments:** @@ -109,14 +109,14 @@ def _(x: NamedArray, y: NamedArray) -> NamedArray: return NamedArray(quax_op(x.array, y.array), axes) @quax.register(prim) - def _(x: Union[ArrayLike, quax.ArrayValue], y: NamedArray) -> NamedArray: + def _(x: ArrayLike | quax.ArrayValue, y: NamedArray) -> NamedArray: if quax.quaxify(jnp.shape)(x) == (): return NamedArray(quax_op(x, y.array), y.axes) else: raise ValueError(f"Cannot apply {op} to non-scalar array and named array.") @quax.register(prim) - def _(x: NamedArray, y: Union[ArrayLike, quax.ArrayValue]) -> NamedArray: + def _(x: NamedArray, y: ArrayLike | quax.ArrayValue) -> NamedArray: if quax.quaxify(jnp.shape)(y) == (): return NamedArray(quax_op(x.array, y), x.axes) else: diff --git a/quax/examples/prng/_core.py b/quax/examples/prng/_core.py index 2603300..e4eb785 100644 --- a/quax/examples/prng/_core.py +++ b/quax/examples/prng/_core.py @@ -1,8 +1,8 @@ import abc import functools as ft from collections.abc import Sequence -from typing import Any, TypeVar -from typing_extensions import Self, TYPE_CHECKING, TypeAlias +from typing import Any, TYPE_CHECKING, TypeAlias, TypeVar +from typing_extensions import Self import equinox as eqx import jax diff --git a/quax/examples/structured_matrices/_core.py b/quax/examples/structured_matrices/_core.py index c40904e..7ca087d 100644 --- a/quax/examples/structured_matrices/_core.py +++ b/quax/examples/structured_matrices/_core.py @@ -1,5 +1,3 @@ -from typing import Union - import equinox as eqx import jax.core import jax.lax as lax @@ -77,7 +75,7 @@ def _tridiagonal_matvec( @quax.register(lax.dot_general_p) def _( lhs: TridiagonalMatrix, - rhs: Union[ArrayLike, quax.ArrayValue], + rhs: ArrayLike | quax.ArrayValue, *, dimension_numbers, **kwargs, diff --git a/quax/examples/unitful/_core.py b/quax/examples/unitful/_core.py index e4e5e0f..75e2e47 100644 --- a/quax/examples/unitful/_core.py +++ b/quax/examples/unitful/_core.py @@ -1,5 +1,3 @@ -from typing import Union - import equinox as eqx # /~https://github.com/patrick-kidger/equinox import jax import jax.core as core @@ -22,7 +20,7 @@ def __repr__(self): seconds = Dimension("s") -def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]: +def _dim_to_unit(x: Dimension | dict[Dimension, int]) -> dict[Dimension, int]: if isinstance(x, Dimension): return {x: 1} else: diff --git a/quax/examples/zero/_core.py b/quax/examples/zero/_core.py index e87c15c..8b94b5a 100644 --- a/quax/examples/zero/_core.py +++ b/quax/examples/zero/_core.py @@ -1,5 +1,5 @@ import functools as ft -from typing import Any, get_args, Union +from typing import Any, get_args import equinox as eqx import jax.core @@ -42,7 +42,7 @@ def materialise(self): @quax.register(lax.broadcast_in_dim_p) def _( value: ArrayLike, *, broadcast_dimensions, shape, sharding=None -) -> Union[ArrayLike, quax.ArrayValue]: +) -> ArrayLike | quax.ArrayValue: # Avoid an infinite loop using ensure_compile_time_eval. with jax.ensure_compile_time_eval(): out = lax.broadcast_in_dim_p.bind( @@ -87,16 +87,12 @@ def _shape_dtype(x, y, value): @quax.register(lax.add_p) -def _( - x: Union[ArrayLike, quax.ArrayValue], y: Zero -) -> Union[ArrayLike, quax.ArrayValue]: +def _(x: ArrayLike | quax.ArrayValue, y: Zero) -> ArrayLike | quax.ArrayValue: return _shape_dtype(x, y, value=x) @quax.register(lax.add_p) -def _( - x: Zero, y: Union[ArrayLike, quax.ArrayValue] -) -> Union[ArrayLike, quax.ArrayValue]: +def _(x: Zero, y: ArrayLike | quax.ArrayValue) -> ArrayLike | quax.ArrayValue: return _shape_dtype(x, y, value=y) @@ -106,12 +102,12 @@ def _(x: Zero, y: Zero) -> Zero: @quax.register(lax.mul_p) -def _(x: Union[ArrayLike, quax.ArrayValue], y: Zero) -> Zero: +def _(x: ArrayLike | quax.ArrayValue, y: Zero) -> Zero: return _shape_dtype(x, y, value=y) @quax.register(lax.mul_p) -def _(x: Zero, y: Union[ArrayLike, quax.ArrayValue]) -> Zero: +def _(x: Zero, y: ArrayLike | quax.ArrayValue) -> Zero: return _shape_dtype(x, y, value=x) @@ -151,12 +147,12 @@ def _zero_matmul(lhs, rhs, kwargs) -> Zero: @quax.register(lax.dot_general_p) -def _(lhs: Zero, rhs: Union[ArrayLike, quax.ArrayValue], **kwargs) -> Zero: +def _(lhs: Zero, rhs: ArrayLike | quax.ArrayValue, **kwargs) -> Zero: return _zero_matmul(lhs, rhs, kwargs) @quax.register(lax.dot_general_p) -def _(lhs: Union[ArrayLike, quax.ArrayValue], rhs: Zero, **kwargs) -> Zero: +def _(lhs: ArrayLike | quax.ArrayValue, rhs: Zero, **kwargs) -> Zero: return _zero_matmul(lhs, rhs, kwargs) @@ -166,7 +162,7 @@ def _(lhs: Zero, rhs: Zero, **kwargs) -> Zero: @quax.register(lax.integer_pow_p) -def _integer_pow(x: Zero, *, y: int) -> Union[Array, Zero]: +def _integer_pow(x: Zero, *, y: int) -> Array | Zero: # Zero is a special case, because 0^0 = 1. if y == 0: return jnp.ones(x.shape, x.dtype) # pyright: ignore diff --git a/tests/test_cond.py b/tests/test_cond.py index c940fb2..52505dd 100644 --- a/tests/test_cond.py +++ b/tests/test_cond.py @@ -1,5 +1,3 @@ -from typing import Union - import jax import jax.numpy as jnp import pytest @@ -8,7 +6,7 @@ from quax.examples.unitful import kilograms, meters, Unitful -def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]): +def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: bool | jax.Array): def _true_fn(a: jax.Array): return a + b @@ -43,9 +41,7 @@ def test_cond_different_units(): def test_cond_different_out_trees(): - def _outer_fn( - a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array] - ): + def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: bool | jax.Array): def _true_fn(a: jax.Array): return a + b @@ -130,7 +126,7 @@ def test_cond_grad_closure(): def outer_fn( outer_var: jax.Array, dummy: jax.Array, - pred: Union[bool, jax.Array], + pred: bool | jax.Array, ): def _true_fn_grad(a: jax.Array): return a + outer_var