diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index da61833bbe5b..626584bc3a20 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -6209,11 +6209,16 @@ class PolyTranslator(TypeTranslator): See docstring for apply_poly() for details. """ - def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None: + def __init__( + self, + poly_tvars: Iterable[TypeVarLikeType], + bound_tvars: frozenset[TypeVarLikeType] = frozenset(), + seen_aliases: frozenset[TypeInfo] = frozenset(), + ) -> None: self.poly_tvars = set(poly_tvars) # This is a simplified version of TypeVarScope used during semantic analysis. - self.bound_tvars: set[TypeVarLikeType] = set() - self.seen_aliases: set[TypeInfo] = set() + self.bound_tvars = bound_tvars + self.seen_aliases = seen_aliases def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]: found_vars = [] @@ -6289,10 +6294,11 @@ def visit_instance(self, t: Instance) -> Type: if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]: if t.type in self.seen_aliases: raise PolyTranslationError() - self.seen_aliases.add(t.type) call = find_member("__call__", t, t, is_operator=True) assert call is not None - return call.accept(self) + return call.accept( + PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type}) + ) return super().visit_instance(t) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 6c98ba2088b1..953855e502d6 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3788,3 +3788,28 @@ def func2(arg: T) -> List[Union[T, str]]: reveal_type(func2) # N: Revealed type is "def [S] (S`4) -> Union[S`4, builtins.str]" reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]" [builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallbackProtoMultiple] +from typing import Callable, Protocol, TypeVar +from typing_extensions import Concatenate, ParamSpec + +V_co = TypeVar("V_co", covariant=True) +class Metric(Protocol[V_co]): + def __call__(self) -> V_co: ... + +T = TypeVar("T") +P = ParamSpec("P") +def simple_metric(func: Callable[Concatenate[int, P], T]) -> Callable[P, T]: ... + +@simple_metric +def Negate(count: int, /, metric: Metric[float]) -> float: ... +@simple_metric +def Combine(count: int, m1: Metric[T], m2: Metric[T], /, *more: Metric[T]) -> T: ... + +reveal_type(Negate) # N: Revealed type is "def (metric: __main__.Metric[builtins.float]) -> builtins.float" +reveal_type(Combine) # N: Revealed type is "def [T] (def () -> T`4, def () -> T`4, *more: def () -> T`4) -> T`4" + +def m1() -> float: ... +def m2() -> float: ... +reveal_type(Combine(m1, m2)) # N: Revealed type is "builtins.float" +[builtins fixtures/list.pyi]