Skip to content

Commit

Permalink
Use ParamSpec in jit annotation; bump MyPy to 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Feb 26, 2023
1 parent 008f35a commit ba45b45
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repos:
- id: flake8

- repo: /~https://github.com/pre-commit/mirrors-mypy
rev: 'v0.982'
rev: 'v1.0.1'
hooks:
- id: mypy
files: (jax/|tests/typing_test\.py)
Expand Down
25 changes: 15 additions & 10 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@
from __future__ import annotations

import collections
from contextlib import contextmanager, ExitStack
import functools
from functools import partial
import inspect
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload)
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union,
overload)

import numpy as np
from contextlib import contextmanager, ExitStack
from typing_extensions import ParamSpec

import jax
from jax._src import linear_util as lu
Expand Down Expand Up @@ -105,6 +107,9 @@
F = TypeVar("F", bound=Callable)
T = TypeVar("T")
U = TypeVar("U")
V_co = TypeVar("V_co", covariant=True)
P = ParamSpec("P")


map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -155,7 +160,7 @@ def _update_debug_special_thread_local(_):


def jit(
fun: Callable,
fun: Callable[P, V_co],
*,
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
Expand All @@ -165,7 +170,7 @@ def jit(
inline: bool = False,
keep_unused: bool = False,
abstracted_axes: Optional[Any] = None,
) -> stages.Wrapped:
) -> stages.Wrapped[P, V_co]:
"""Sets up ``fun`` for just-in-time compilation with XLA.
Args:
Expand Down Expand Up @@ -339,7 +344,7 @@ def _prepare_jit(fun, static_argnums, static_argnames, donate_argnums,
PytreeOfAbstractedAxesSpec = Any

def _python_jit(
fun: Callable,
fun: Callable[P, V_co],
*,
static_argnums: Tuple[int, ...],
static_argnames: Tuple[str, ...],
Expand All @@ -349,7 +354,7 @@ def _python_jit(
inline: bool,
keep_unused: bool,
abstracted_axes: Optional[PytreeOfAbstractedAxesSpec],
) -> stages.Wrapped:
) -> stages.Wrapped[P, V_co]:
@wraps(fun)
@api_boundary
def f_jitted(*args, **kwargs):
Expand Down Expand Up @@ -483,7 +488,7 @@ def _device_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
return None

def _cpp_jit(
fun: Callable,
fun: Callable[P, V_co],
*,
static_argnums: Tuple[int, ...],
static_argnames: Tuple[str, ...],
Expand All @@ -492,7 +497,7 @@ def _cpp_jit(
donate_argnums: Tuple[int, ...],
inline: bool,
keep_unused: bool,
) -> stages.Wrapped:
) -> stages.Wrapped[P, V_co]:
# An implementation of `jit` that tries to do as much as possible in C++.
# The goal of this function is to speed up the time it takes to process the
# arguments, find the correct C++ executable, start the transfer of arguments
Expand Down Expand Up @@ -2064,7 +2069,7 @@ def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums,


def _python_pmap(
fun: Callable,
fun: Callable[P, V_co],
axis_name: Optional[AxisName] = None,
*,
in_axes=0,
Expand All @@ -2075,7 +2080,7 @@ def _python_pmap(
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> stages.Wrapped:
) -> stages.Wrapped[P, V_co]:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
Expand Down
12 changes: 9 additions & 3 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
import warnings

from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple
from typing import (Any, Dict, Generic, List, NamedTuple, Optional, Protocol,
Sequence, Tuple, TypeVar)
from typing_extensions import ParamSpec

import jax
from jax import tree_util
Expand Down Expand Up @@ -617,7 +619,11 @@ def compiler_ir(self, dialect: Optional[str] = None) -> Optional[Any]:
return None


class Wrapped(Protocol):
V_co = TypeVar("V_co", covariant=True)
P = ParamSpec("P")


class Wrapped(Protocol, Generic[P, V_co]):
"""A function ready to be specialized, lowered, and compiled.
This protocol reflects the output of functions such as
Expand All @@ -626,7 +632,7 @@ class Wrapped(Protocol):
to compilation, and the result compiled prior to execution.
"""

def __call__(self, *args, **kwargs):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> V_co:
"""Executes the wrapped function, lowering and compiling as needed."""
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def generate_proto(source):
'numpy>=1.20',
'opt_einsum',
'scipy>=1.5',
'typing_extensions>=4.5.0',
],
extras_require={
# Minimum jaxlib version; used in testing.
Expand Down

0 comments on commit ba45b45

Please sign in to comment.