Skip to content

Commit

Permalink
Fix inference bug
Browse files Browse the repository at this point in the history
  • Loading branch information
A5rocks committed Feb 25, 2023
1 parent 86993a0 commit 8642f3d
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 35 deletions.
20 changes: 20 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,26 @@ def infer_function_type_arguments(
callee_type, args, arg_kinds, formal_to_actual, inferred_args, context
)

return_type = get_proper_type(callee_type.ret_type)
if isinstance(return_type, CallableType):
# fixup:
# def [T] () -> def (T) -> T
# into
# def () -> def [T] (T) -> T
for i, argument in enumerate(inferred_args):
if isinstance(get_proper_type(argument), UninhabitedType):
inferred_args[i] = callee_type.variables[i]

# handle multiple type variables
return_type = return_type.copy_modified(
variables=[*return_type.variables, callee_type.variables[i]]
)

callee_type = callee_type.copy_modified(
# am I allowed to assign the get_proper_type'd thing?
ret_type=return_type
)

if (
callee_type.special_sig == "dict"
and len(inferred_args) == 2
Expand Down
63 changes: 43 additions & 20 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, List, Sequence
from typing import TYPE_CHECKING, Iterable, List, Sequence, Union
from typing_extensions import Final

import mypy.subtypes
Expand Down Expand Up @@ -713,26 +713,37 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
suffix = suffix.copy_modified(from_concatenate=from_concat)


prefix = mapped_arg.prefix
length = len(prefix.arg_types)
if isinstance(suffix, Parameters) or isinstance(suffix, CallableType):
# no such thing as variance for ParamSpecs
# TODO: is there a case I am missing?
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix.copy_modified(
arg_types=suffix.arg_types[length:],
arg_kinds=suffix.arg_kinds[length:],
arg_names=suffix.arg_names[length:],
)))
res.append(
Constraint(
mapped_arg,
SUPERTYPE_OF,
suffix.copy_modified(
arg_types=suffix.arg_types[length:],
arg_kinds=suffix.arg_kinds[length:],
arg_names=suffix.arg_names[length:],
),
)
)
elif isinstance(suffix, ParamSpecType):
suffix_prefix = suffix.prefix
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix.copy_modified(
prefix=suffix_prefix.copy_modified(
arg_types=suffix_prefix.arg_types[length:],
arg_kinds=suffix_prefix.arg_kinds[length:],
arg_names=suffix_prefix.arg_names[length:]
res.append(
Constraint(
mapped_arg,
SUPERTYPE_OF,
suffix.copy_modified(
prefix=suffix_prefix.copy_modified(
arg_types=suffix_prefix.arg_types[length:],
arg_kinds=suffix_prefix.arg_kinds[length:],
arg_names=suffix_prefix.arg_names[length:],
)
),
)
)))
)
else:
# This case should have been handled above.
assert not isinstance(tvar, TypeVarTupleType)
Expand Down Expand Up @@ -947,12 +958,15 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
prefix_len = len(prefix.arg_types)
cactual_ps = cactual.param_spec()

cactual_prefix: Union[Parameters, CallableType]
if cactual_ps:
cactual_prefix = cactual_ps.prefix
else:
cactual_prefix = cactual

max_prefix_len = len([k for k in cactual_prefix.arg_kinds if k in (ARG_POS, ARG_OPT)])
max_prefix_len = len(
[k for k in cactual_prefix.arg_kinds if k in (ARG_POS, ARG_OPT)]
)
prefix_len = min(prefix_len, max_prefix_len)

# we could check the prefixes match here, but that should be caught elsewhere.
Expand All @@ -970,13 +984,22 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
)
)
else:
res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps.copy_modified(
prefix=cactual_prefix.copy_modified(
arg_types=cactual_prefix.arg_types[prefix_len:],
arg_kinds=cactual_prefix.arg_kinds[prefix_len:],
arg_names=cactual_prefix.arg_names[prefix_len:]
# guaranteed due to if conditions
assert isinstance(cactual_prefix, Parameters)

res.append(
Constraint(
param_spec,
SUBTYPE_OF,
cactual_ps.copy_modified(
prefix=cactual_prefix.copy_modified(
arg_types=cactual_prefix.arg_types[prefix_len:],
arg_kinds=cactual_prefix.arg_kinds[prefix_len:],
arg_names=cactual_prefix.arg_names[prefix_len:],
)
),
)
)))
)

# compare prefixes
cactual_prefix = cactual.copy_modified(
Expand Down
2 changes: 1 addition & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
return t.prefix.copy_modified(
arg_types=t.prefix.arg_types + [self.replacement, self.replacement],
arg_kinds=t.prefix.arg_kinds + [ARG_STAR, ARG_STAR2],
arg_names=t.prefix.arg_names + [None, None]
arg_names=t.prefix.arg_names + [None, None],
)
return t

Expand Down
1 change: 0 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def freshen_function_type_vars(callee: F) -> F:
if isinstance(v, TypeVarType):
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
elif isinstance(v, TypeVarTupleType):
assert isinstance(v, TypeVarTupleType)
tv = TypeVarTupleType.new_unification_variable(v)
else:
assert isinstance(v, ParamSpecType)
Expand Down
1 change: 0 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2515,7 +2515,6 @@ def __init__(
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
) -> None:
super().__init__(name, fullname, upper_bound, variance)
assert isinstance(upper_bound, (mypy.types.CallableType, mypy.types.Parameters))

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_paramspec_expr(self)
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5606,7 +5606,7 @@ def top_caller(self) -> Parameters:
return Parameters(
arg_types=[self.object_type(), self.object_type()],
arg_kinds=[ARG_STAR, ARG_STAR2],
arg_names=[None, None]
arg_names=[None, None],
)

def str_type(self) -> Instance:
Expand Down
5 changes: 3 additions & 2 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,9 @@ def visit_paramspec_expr(self, o: mypy.nodes.ParamSpecExpr) -> str:
a += ["Variance(COVARIANT)"]
if o.variance == mypy.nodes.CONTRAVARIANT:
a += ["Variance(CONTRAVARIANT)"]
if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"):
a += [f"UpperBound({o.upper_bound})"]
# ParamSpecs do not have upper bounds!!! (should this be left for future proofing?)
# if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"):
# a += [f"UpperBound({o.upper_bound})"]
return self.dump(a, o)

def visit_type_var_tuple_expr(self, o: mypy.nodes.TypeVarTupleExpr) -> str:
Expand Down
5 changes: 2 additions & 3 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,6 @@ def __init__(
super().__init__(name, fullname, id, upper_bound, line=line, column=column)
self.flavor = flavor
self.prefix = prefix or Parameters([], [], [])
assert flavor != ParamSpecFlavor.BARE or isinstance(upper_bound, (CallableType, Parameters))

@staticmethod
def new_unification_variable(old: ParamSpecType) -> ParamSpecType:
Expand Down Expand Up @@ -1995,8 +1994,8 @@ def param_spec(self) -> ParamSpecType | None:
upper_bound=Parameters(
arg_types=[any_type, any_type],
arg_kinds=[ARG_STAR, ARG_STAR2],
arg_names=[None, None]
)
arg_names=[None, None],
),
)

def expand_param_spec(
Expand Down
5 changes: 4 additions & 1 deletion test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -2956,8 +2956,11 @@ T = TypeVar('T')

def f(x: Optional[T] = None) -> Callable[..., T]: ...

x = f() # E: Need type annotation for "x"
# TODO: should this warn about needed an annotation? This behavior still _works_...
x = f()
reveal_type(x) # N: Revealed type is "def [T] (*Any, **Any) -> T`1"
y = x
reveal_type(y) # N: Revealed type is "def [T] (*Any, **Any) -> T`1"

[case testDontNeedAnnotationForCallable]
from typing import TypeVar, Optional, Callable, NoReturn
Expand Down
8 changes: 3 additions & 5 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1552,9 +1552,7 @@ class Example(Generic[P]):
def test(ex: Example[P]) -> Example[Concatenate[int, P]]:
...

ex: Example[int] = test()(reveal_type(Example())) # N: Revealed type is "__main__.Example[<nothing>]"
# TODO: fix
reveal_type(test()(Example[int]())) # N: Revealed type is "__main__.Example[<nothing>]" \
# E: Argument 1 has incompatible type "Example[[int]]"; expected "Example[<nothing>]"
ex = test()(Example[int]()) # E: Argument 1 has incompatible type "Example[[int]]"; expected "Example[<nothing>]"
ex: Example[int] = test()(reveal_type(Example())) # N: Revealed type is "__main__.Example[[]]"
reveal_type(test()(Example[int]())) # N: Revealed type is "__main__.Example[[builtins.int, builtins.int]]"
ex = test()(Example[int]()) # E: Argument 1 has incompatible type "Example[[int]]"; expected "Example[[]]"
[builtins fixtures/paramspec.pyi]

0 comments on commit 8642f3d

Please sign in to comment.