Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style(isort): equinox and jaxtyping are first-party #47

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/examples/custom_rules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
"import functools as ft\n",
"from typing import Union\n",
"\n",
"import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jaxtyping import ArrayLike # /~https://github.com/patrick-kidger/jaxtyping\n",
"\n",
"import quax"
"import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n",
"import quax\n",
"from jaxtyping import ArrayLike # /~https://github.com/patrick-kidger/jaxtyping"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions docs/examples/default_rules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
"from collections.abc import Sequence\n",
"from typing import Union\n",
"\n",
"import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import jax.tree_util as jtu\n",
"from jaxtyping import ArrayLike # /~https://github.com/patrick-kidger/quax\n",
"\n",
"import quax"
"import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n",
"import quax\n",
"from jaxtyping import ArrayLike # /~https://github.com/patrick-kidger/quax"
]
},
{
Expand Down
8 changes: 4 additions & 4 deletions docs/examples/redispatch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
"source": [
"from typing import Union\n",
"\n",
"import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import equinox as eqx # /~https://github.com/patrick-kidger/equinox\n",
"import quax\n",
"from jaxtyping import ( # /~https://github.com/patrick-kidger/quax\n",
" Array,\n",
" ArrayLike,\n",
" Int,\n",
" Shaped,\n",
")\n",
"\n",
"import quax"
")"
]
},
{
Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ dependencies = [
"plum-dispatch>=2.2.1",
]

[dependency-groups]
dev = [
"ipykernel>=6.29.5",
"pre-commit>=4.1.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand All @@ -38,6 +44,7 @@ build-backend = "hatchling.build"
include = ["quax/*"]

[tool.pytest.ini_options]
minversion = 8.0
addopts = "--jaxtyping-packages=quax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"

[tool.pytest_env]
Expand All @@ -55,6 +62,7 @@ fixable = ["I001", "F401"]
combine-as-imports = true
lines-after-imports = 2
extra-standard-library = ["typing_extensions"]
known-first-party = ["equinox", "jaxtyping"]
order-by-type = false

[tool.pyright]
Expand Down
3 changes: 2 additions & 1 deletion quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any, cast, Generic, TypeVar, Union
from typing_extensions import TypeGuard

import equinox as eqx
import jax
import jax._src
import jax.core as core
Expand All @@ -15,6 +14,8 @@
import jax.tree_util as jtu
import plum
from jax.custom_derivatives import SymbolicZero as SZ

import equinox as eqx
from jaxtyping import ArrayLike, PyTree


Expand Down
4 changes: 2 additions & 2 deletions quax/examples/lora/_core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Union

import equinox as eqx
import jax.core
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, PRNGKeyArray, PyTree, Shaped

import equinox as eqx
import quax
from jaxtyping import Array, ArrayLike, PRNGKeyArray, PyTree, Shaped


class LoraArray(quax.ArrayValue):
Expand Down
4 changes: 2 additions & 2 deletions quax/examples/named/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from collections.abc import Callable
from typing import Any, Generic, Optional, TypeVar, Union

import equinox as eqx
import jax.core
import jax.extend as jex
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import ArrayLike

import equinox as eqx
import quax
from jaxtyping import ArrayLike


@dataclasses.dataclass(frozen=True, eq=False)
Expand Down
4 changes: 2 additions & 2 deletions quax/examples/prng/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from typing import Any, TypeVar
from typing_extensions import Self, TYPE_CHECKING, TypeAlias

import equinox as eqx
import jax
import jax._src.prng
import jax.core
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jaxtyping import Array, ArrayLike, Float, Integer, UInt, UInt32

import equinox as eqx
import quax
from jaxtyping import Array, ArrayLike, Float, Integer, UInt, UInt32


RealArray: TypeAlias = ArrayLike
Expand Down
4 changes: 2 additions & 2 deletions quax/examples/sparse/_core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import get_args

import equinox as eqx
import jax.core
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike, Integer, Shaped

import equinox as eqx
import quax
from jaxtyping import Array, ArrayLike, Integer, Shaped


class BCOO(quax.ArrayValue):
Expand Down
4 changes: 2 additions & 2 deletions quax/examples/structured_matrices/_core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Union

import equinox as eqx
import jax.core
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike, Shaped

import equinox as eqx
import quax
from jaxtyping import Array, ArrayLike, Shaped


class TridiagonalMatrix(quax.ArrayValue):
Expand Down
4 changes: 2 additions & 2 deletions quax/examples/unitful/_core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Union

import equinox as eqx # /~https://github.com/patrick-kidger/equinox
import jax
import jax.core as core
import jax.numpy as jnp
from jaxtyping import ArrayLike # /~https://github.com/patrick-kidger/jaxtyping

import equinox as eqx # /~https://github.com/patrick-kidger/equinox
import quax
from jaxtyping import ArrayLike # /~https://github.com/patrick-kidger/jaxtyping


class Dimension:
Expand Down
4 changes: 2 additions & 2 deletions quax/examples/zero/_core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import functools as ft
from typing import Any, get_args, Union

import equinox as eqx
import jax.core
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike

import equinox as eqx
import quax
from jaxtyping import Array, ArrayLike


class Zero(quax.ArrayValue):
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import equinox.internal as eqxi
import jax.random as jr
import pytest


@pytest.fixture()
def getkey():
return eqxi.GetKey()
return lambda: jr.PRNGKey(0)
4 changes: 2 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import cast

import equinox as eqx
import jax
import jax.core
import jax.lax as lax
import jax.numpy as jnp
import pytest
from jaxtyping import Array

import equinox as eqx
import quax
from jaxtyping import Array


def test_jit_inline():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
import pytest
from jaxtyping import TypeCheckError
from plum import NotFoundLookupError

import equinox as eqx
import quax
import quax.examples.lora as lora
from jaxtyping import TypeCheckError


def test_linear(getkey):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_named.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import cast

import equinox as eqx
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
import pytest
from jaxtyping import Array

import equinox as eqx
import quax
import quax.examples.named as named
from jaxtyping import Array


def test_init(getkey):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sparse.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

import equinox as eqx
import quax
import quax.examples.sparse as sparse

Expand Down
2 changes: 1 addition & 1 deletion tests/test_zero.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
import pytest

import equinox as eqx
import quax
import quax.examples.zero as zero

Expand Down