Skip to content

Commit

Permalink
Update some type inference code
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Oct 9, 2024
1 parent 1861342 commit df28a74
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from functools import reduce
from inspect import Parameter, Signature, isabstract, isclass
from re import Pattern
from types import FunctionType
from types import FunctionType, GenericAlias
from typing import (
Any,
AnyStr,
Expand Down Expand Up @@ -1326,6 +1326,13 @@ def from_type_guarded(thing):
strategy = as_strategy(types._global_type_lookup[thing], thing)
if strategy is not NotImplemented:
return strategy
elif (
isinstance(thing, GenericAlias)
and (to := get_origin(thing)) in types._global_type_lookup
):
strategy = as_strategy(types._global_type_lookup[to], thing)
if strategy is not NotImplemented:
return strategy
except TypeError: # pragma: no cover
# This was originally due to a bizarre divergence in behaviour on Python 3.9.0:
# typing.Callable[[], foo] has __args__ = (foo,) but collections.abc.Callable
Expand Down
48 changes: 27 additions & 21 deletions hypothesis-python/src/hypothesis/strategies/_internal/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
import uuid
import warnings
import zoneinfo
from collections.abc import Iterator
from functools import partial
from pathlib import PurePath
from types import FunctionType
from typing import TYPE_CHECKING, Any, Iterator, Tuple, get_args, get_origin
from typing import TYPE_CHECKING, Any, get_args, get_origin

from hypothesis import strategies as st
from hypothesis.errors import HypothesisWarning, InvalidArgument, ResolutionFailed
Expand Down Expand Up @@ -339,7 +340,7 @@ def get_constraints_filter_map():
return {} # pragma: no cover


def _get_constraints(args: Tuple[Any, ...]) -> Iterator["at.BaseMetadata"]:
def _get_constraints(args: tuple[Any, ...]) -> Iterator["at.BaseMetadata"]:
at = sys.modules.get("annotated_types")
for arg in args:
if at and isinstance(arg, at.BaseMetadata):
Expand Down Expand Up @@ -619,7 +620,7 @@ def _networks(bits):
# exposed for it, and NotImplemented itself is typed as Any so that it can be
# returned without being listed in a function signature:
# /~https://github.com/python/mypy/issues/6710#issuecomment-485580032
_global_type_lookup: typing.Dict[
_global_type_lookup: dict[
type, typing.Union[st.SearchStrategy, typing.Callable[[type], st.SearchStrategy]]
] = {
type(None): st.none(),
Expand Down Expand Up @@ -726,8 +727,8 @@ def _networks(bits):
_global_type_lookup[builtins.sequenceiterator] = st.builds(iter, st.tuples()) # type: ignore


_global_type_lookup[type] = st.sampled_from(
[type(None), *sorted(_global_type_lookup, key=str)]
_fallback_type_strategy = st.sampled_from(
sorted(_global_type_lookup, key=type_sorting_key)
)
# subclass of MutableMapping, and so we resolve to a union which
# includes this... but we don't actually ever want to build one.
Expand Down Expand Up @@ -803,15 +804,15 @@ def _networks(bits):
# installed. To avoid the performance hit of importing anything here, we defer
# it until the method is called the first time, at which point we replace the
# entry in the lookup table with the direct call.
def _from_numpy_type(thing: typing.Type) -> typing.Optional[st.SearchStrategy]:
def _from_numpy_type(thing: type) -> typing.Optional[st.SearchStrategy]:
from hypothesis.extra.numpy import _from_type

_global_extra_lookup["numpy"] = _from_type
return _from_type(thing)


_global_extra_lookup: typing.Dict[
str, typing.Callable[[typing.Type], typing.Optional[st.SearchStrategy]]
_global_extra_lookup: dict[
str, typing.Callable[[type], typing.Optional[st.SearchStrategy]]
] = {
"numpy": _from_numpy_type,
}
Expand Down Expand Up @@ -839,26 +840,30 @@ def really_inner(thing):
return fallback
return func(thing)

_global_type_lookup[type_] = really_inner
_global_type_lookup[get_origin(type_) or type_] = really_inner
return really_inner

return inner


@register(typing.Type)
@register(type)
@register("Type")
@register("Type", module=typing_extensions)
def resolve_Type(thing):
if getattr(thing, "__args__", None) is None:
return st.just(type)
elif get_args(thing) == (): # pragma: no cover
return _fallback_type_strategy
args = (thing.__args__[0],)
if is_a_union(args[0]):
args = args[0].__args__
# Duplicate check from from_type here - only paying when needed.
args = list(args)
for i, a in enumerate(args):
if type(a) == typing.ForwardRef:
if type(a) in (typing.ForwardRef, str):
try:
args[i] = getattr(builtins, a.__forward_arg__)
args[i] = getattr(builtins, getattr(a, "__forward_arg__", a))
except AttributeError:
raise ResolutionFailed(
f"Cannot find the type referenced by {thing} - try using "
Expand All @@ -867,12 +872,12 @@ def resolve_Type(thing):
return st.sampled_from(sorted(args, key=type_sorting_key))


@register(typing.List, st.builds(list))
@register("List", st.builds(list))
def resolve_List(thing):
return st.lists(st.from_type(thing.__args__[0]))


@register(typing.Tuple, st.builds(tuple))
@register("Tuple", st.builds(tuple))
def resolve_Tuple(thing):
elem_types = getattr(thing, "__args__", None) or ()
if len(elem_types) == 2 and elem_types[-1] is Ellipsis:
Expand Down Expand Up @@ -906,27 +911,28 @@ def _from_hashable_type(type_):
return st.from_type(type_).filter(_can_hash)


@register(typing.Set, st.builds(set))
@register("Set", st.builds(set))
@register(typing.MutableSet, st.builds(set))
def resolve_Set(thing):
return st.sets(_from_hashable_type(thing.__args__[0]))


@register(typing.FrozenSet, st.builds(frozenset))
@register("FrozenSet", st.builds(frozenset))
def resolve_FrozenSet(thing):
return st.frozensets(_from_hashable_type(thing.__args__[0]))


@register(typing.Dict, st.builds(dict))
@register("Dict", st.builds(dict))
def resolve_Dict(thing):
# If thing is a Collection instance, we need to fill in the values
keys_vals = thing.__args__ * 2
keys, vals, *_ = thing.__args__ * 2
return st.dictionaries(
_from_hashable_type(keys_vals[0]), st.from_type(keys_vals[1])
_from_hashable_type(keys),
st.none() if vals is None else st.from_type(vals),
)


@register(typing.DefaultDict, st.builds(collections.defaultdict))
@register("DefaultDict", st.builds(collections.defaultdict))
@register("DefaultDict", st.builds(collections.defaultdict), module=typing_extensions)
def resolve_DefaultDict(thing):
return resolve_Dict(thing).map(lambda d: collections.defaultdict(None, d))
Expand Down Expand Up @@ -988,9 +994,9 @@ def resolve_Pattern(thing):
return st.just(re.compile(thing.__args__[0]()))


@register( # pragma: no branch # coverage does not see lambda->exit branch
@register(
typing.Match,
st.text().map(lambda c: re.match(".", c, flags=re.DOTALL)).filter(bool),
st.text().map(partial(re.match, ".", flags=re.DOTALL)).filter(bool),
)
def resolve_Match(thing):
if thing.__args__[0] == bytes:
Expand Down
85 changes: 58 additions & 27 deletions hypothesis-python/tests/cover/test_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@
),
key=str,
)
_Type = getattr(typing, "Type", None)
_List = getattr(typing, "List", None)
_Dict = getattr(typing, "Dict", None)
_Set = getattr(typing, "Set", None)
_FrozenSet = getattr(typing, "FrozenSet", None)
_Tuple = getattr(typing, "Tuple", None)


@pytest.mark.parametrize("typ", generics, ids=repr)
Expand Down Expand Up @@ -104,10 +110,13 @@ def test_specialised_scalar_types(data, typ, instance_of):


def test_typing_Type_int():
assert_simple_property(from_type(typing.Type[int]), lambda x: x is int)
for t in (type[int], type["int"], _Type[int], _Type["int"]):
assert_simple_property(from_type(t), lambda x: x is int)


@given(from_type(typing.Type[typing.Union[str, list]]))
@given(
from_type(type[typing.Union[str, list]]) | from_type(_Type[typing.Union[str, list]])
)
def test_typing_Type_Union(ex):
assert ex in (str, list)

Expand Down Expand Up @@ -143,15 +152,21 @@ class Elem:
@pytest.mark.parametrize(
"typ,coll_type",
[
(typing.Set[Elem], set),
(typing.FrozenSet[Elem], frozenset),
(typing.Dict[Elem, None], dict),
(_Set[Elem], set),
(_FrozenSet[Elem], frozenset),
(_Dict[Elem, None], dict),
(set[Elem], set),
(frozenset[Elem], frozenset),
# (dict[Elem, None], dict), # FIXME this should work
(typing.DefaultDict[Elem, None], collections.defaultdict),
(typing.KeysView[Elem], type({}.keys())),
(typing.ValuesView[Elem], type({}.values())),
(typing.List[Elem], list),
(typing.Tuple[Elem], tuple),
(typing.Tuple[Elem, ...], tuple),
(_List[Elem], list),
(_Tuple[Elem], tuple),
(_Tuple[Elem, ...], tuple),
(list[Elem], list),
(tuple[Elem], tuple),
(tuple[Elem, ...], tuple),
(typing.Iterator[Elem], typing.Iterator),
(typing.Sequence[Elem], typing.Sequence),
(typing.Iterable[Elem], typing.Iterable),
Expand Down Expand Up @@ -226,23 +241,24 @@ def test_Optional_minimises_to_None():
assert minimal(from_type(typing.Optional[int]), lambda ex: True) is None


@pytest.mark.parametrize("n", range(10))
def test_variable_length_tuples(n):
type_ = typing.Tuple[int, ...]
@pytest.mark.parametrize("n", [0, 1, 5])
@pytest.mark.parametrize("t", [tuple, _Tuple])
def test_variable_length_tuples(t, n):
type_ = t[int, ...]
check_can_generate_examples(from_type(type_).filter(lambda ex: len(ex) == n))


def test_lookup_overrides_defaults():
sentinel = object()
with temp_registered(int, st.just(sentinel)):

@given(from_type(typing.List[int]))
@given(from_type(list[int]))
def inner_1(ex):
assert all(elem is sentinel for elem in ex)

inner_1()

@given(from_type(typing.List[int]))
@given(from_type(list[int]))
def inner_2(ex):
assert all(isinstance(elem, int) for elem in ex)

Expand All @@ -253,7 +269,7 @@ def test_register_generic_typing_strats():
# I don't expect anyone to do this, but good to check it works as expected
with temp_registered(
typing.Sequence,
types._global_type_lookup[typing.get_origin(typing.Set) or typing.Set],
types._global_type_lookup[set],
):
# We register sets for the abstract sequence type, which masks subtypes
# from supertype resolution but not direct resolution
Expand All @@ -264,9 +280,7 @@ def test_register_generic_typing_strats():
from_type(typing.Container[int]),
lambda ex: not isinstance(ex, typing.Sequence),
)
assert_all_examples(
from_type(typing.List[int]), lambda ex: isinstance(ex, list)
)
assert_all_examples(from_type(list[int]), lambda ex: isinstance(ex, list))


def if_available(name):
Expand Down Expand Up @@ -587,7 +601,7 @@ def test_override_args_for_namedtuple(thing):
assert thing.a is None


@pytest.mark.parametrize("thing", [typing.Optional, typing.List, typing.Type])
@pytest.mark.parametrize("thing", [typing.Optional, list, type, _List, _Type])
def test_cannot_resolve_bare_forward_reference(thing):
t = thing["ConcreteFoo"]
with pytest.raises(InvalidArgument):
Expand Down Expand Up @@ -740,7 +754,7 @@ def test_resolving_recursive_type_with_registered_constraint_not_none():
find_any(s, lambda s: s.next_node is not None)


@given(from_type(typing.Tuple[()]))
@given(from_type(tuple[()]) | from_type(_Tuple[()]))
def test_resolves_empty_Tuple_issue_1583_regression(ex):
# See e.g. /~https://github.com/python/mypy/commit/71332d58
assert ex == ()
Expand Down Expand Up @@ -805,11 +819,17 @@ def test_cannot_resolve_abstract_class_with_no_concrete_subclass(instance):


@fails_with(ResolutionFailed)
@given(st.from_type(typing.Type["ConcreteFoo"]))
@given(st.from_type(type["ConcreteFoo"]))
def test_cannot_resolve_type_with_forwardref(instance):
raise AssertionError("test body unreachable as strategy cannot resolve")


@fails_with(ResolutionFailed)
@given(st.from_type(_Type["ConcreteFoo"]))
def test_cannot_resolve_type_with_forwardref_old(instance):
raise AssertionError("test body unreachable as strategy cannot resolve")


@pytest.mark.parametrize("typ", [typing.Hashable, typing.Sized])
@given(data=st.data())
def test_inference_on_generic_collections_abc_aliases(typ, data):
Expand Down Expand Up @@ -938,9 +958,12 @@ def test_timezone_lookup(type_):
@pytest.mark.parametrize(
"typ",
[
typing.Set[typing.Hashable],
typing.FrozenSet[typing.Hashable],
typing.Dict[typing.Hashable, int],
_Set[typing.Hashable],
_FrozenSet[typing.Hashable],
_Dict[typing.Hashable, int],
set[typing.Hashable],
frozenset[typing.Hashable],
dict[typing.Hashable, int],
],
)
@settings(suppress_health_check=[HealthCheck.data_too_large])
Expand Down Expand Up @@ -973,7 +996,8 @@ def __init__(self, value=-1) -> None:
"typ,repr_",
[
(int, "integers()"),
(typing.List[str], "lists(text())"),
(list[str], "lists(text())"),
(_List[str], "lists(text())"),
("not a type", "from_type('not a type')"),
(random.Random, "randoms()"),
(_EmptyClass, "from_type(tests.cover.test_lookup._EmptyClass)"),
Expand Down Expand Up @@ -1123,15 +1147,22 @@ def test_resolves_forwardrefs_to_builtin_types(t, data):

@pytest.mark.parametrize("t", BUILTIN_TYPES, ids=lambda t: t.__name__)
def test_resolves_type_of_builtin_types(t):
assert_simple_property(st.from_type(typing.Type[t.__name__]), lambda v: v is t)
assert_simple_property(st.from_type(type[t.__name__]), lambda v: v is t)


@given(st.from_type(typing.Type[typing.Union["str", "int"]]))
@given(
st.from_type(type[typing.Union["str", "int"]])
| st.from_type(_Type[typing.Union["str", "int"]])
)
def test_resolves_type_of_union_of_forwardrefs_to_builtins(x):
assert x in (str, int)


@pytest.mark.parametrize("type_", [typing.List[int], typing.Optional[int]])
@pytest.mark.parametrize(
# Old-style `List` because `list[int]() == list()`, so no need for the hint.
"type_",
[getattr(typing, "List", None)[int], typing.Optional[int]],
)
def test_builds_suggests_from_type(type_):
with pytest.raises(
InvalidArgument, match=re.escape(f"try using from_type({type_!r})")
Expand Down
4 changes: 1 addition & 3 deletions tooling/src/hypothesistooling/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,7 @@ def warn(msg):

codespell("--write-changes", *files_to_format, *doc_files_to_format)
pip_tool("ruff", "check", "--fix-only", ".")
pip_tool("shed", *files_to_format, *doc_files_to_format)
# FIXME: work through the typing issues and enable py39 formatting
# pip_tool("shed", "--py39-plus", *files_to_format, *doc_files_to_format)
pip_tool("shed", "--py39-plus", *files_to_format, *doc_files_to_format)


VALID_STARTS = (HEADER.split()[0], "#!/usr/bin/env python")
Expand Down

0 comments on commit df28a74

Please sign in to comment.