From 9f838ab7550df578c3788c7d33c8c4d96213f7ba Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Sat, 25 Feb 2023 22:45:18 -0500 Subject: [PATCH] Make typing_extensions a dev-dependency --- jax/_src/api.py | 9 ++++++--- jax/_src/stages.py | 11 ++++++++--- setup.py | 2 +- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 2becf2fc5a6b..c7dcfad33a3a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -28,11 +28,13 @@ from functools import partial import inspect import math -from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast, overload +from typing import (Any, Callable, Literal, NamedTuple, TypeVar, cast, + overload, TYPE_CHECKING) import weakref import numpy as np -from typing_extensions import ParamSpec +if TYPE_CHECKING: + from typing_extensions import ParamSpec from jax._src import linear_util as lu from jax._src import stages @@ -94,7 +96,8 @@ T = TypeVar("T") U = TypeVar("U") V_co = TypeVar("V_co", covariant=True) -P = ParamSpec("P") +if TYPE_CHECKING: + P = ParamSpec("P") map, unsafe_map = safe_map, map diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 0e050817b0e7..2bb677ec40f1 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -32,8 +32,10 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, Protocol, TypeVar, Union -from typing_extensions import ParamSpec +from typing import (Any, Generic, NamedTuple, Protocol, TypeVar, Union, + TYPE_CHECKING) +if TYPE_CHECKING: + from typing_extensions import ParamSpec import jax @@ -648,7 +650,10 @@ def cost_analysis(self) -> Any | None: V_co = TypeVar("V_co", covariant=True) -P = ParamSpec("P") +if TYPE_CHECKING: + P = ParamSpec("P") +else: + P = TypeVar("P") class Wrapped(Protocol, Generic[P, V_co]): diff --git a/setup.py b/setup.py index 577f321d8e56..e85900f253c6 100644 --- a/setup.py +++ b/setup.py @@ -84,9 +84,9 @@ def generate_proto(source): # Python versions < 3.10. Can be dropped when 3.10 is the minimum # required Python version. 'importlib_metadata>=4.6;python_version<"3.10"', - 'typing_extensions>=4.5.0', ], extras_require={ + 'dev': ['typing_extensions>=4.8.0'], # Minimum jaxlib version; used in testing. 'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],