Skip to content

Commit

Permalink
fix(typing): Resolve multiple @utils.use_signature issues (#3565)
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned authored Sep 6, 2024
1 parent 111d6e7 commit a7c227b
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 40 deletions.
91 changes: 62 additions & 29 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
from copy import deepcopy
from itertools import groupby
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterator,
Literal,
TypeVar,
cast,
overload,
)

import jsonschema
import narwhals.stable.v1 as nw
Expand All @@ -22,13 +31,13 @@
from altair.utils.schemapi import SchemaBase, Undefined

if sys.version_info >= (3, 12):
from typing import Protocol, runtime_checkable
from typing import Protocol, TypeAliasType, runtime_checkable
else:
from typing_extensions import Protocol, runtime_checkable
from typing_extensions import Protocol, TypeAliasType, runtime_checkable
if sys.version_info >= (3, 10):
from typing import ParamSpec
from typing import Concatenate, ParamSpec
else:
from typing_extensions import ParamSpec
from typing_extensions import Concatenate, ParamSpec


if TYPE_CHECKING:
Expand All @@ -40,9 +49,21 @@
from altair.utils._dfi_types import DataFrame as DfiDataFrame
from altair.vegalite.v5.schema._typing import StandardType_T as InferredVegaLiteType

V = TypeVar("V")
P = ParamSpec("P")
TIntoDataFrame = TypeVar("TIntoDataFrame", bound=IntoDataFrame)
T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R")

WrapsFunc = TypeAliasType("WrapsFunc", Callable[..., R], type_params=(R,))
WrappedFunc = TypeAliasType("WrappedFunc", Callable[P, R], type_params=(P, R))
# NOTE: Requires stringized form to avoid `< (3, 11)` issues
# See: /~https://github.com/vega/altair/actions/runs/10667859416/job/29567290871?pr=3565
WrapsMethod = TypeAliasType(
"WrapsMethod", "Callable[Concatenate[T, ...], R]", type_params=(T, R)
)
WrappedMethod = TypeAliasType(
"WrappedMethod", Callable[Concatenate[T, P], R], type_params=(T, P, R)
)


@runtime_checkable
Expand Down Expand Up @@ -708,31 +729,43 @@ def infer_vegalite_type_for_narwhals(
raise ValueError(msg)


def use_signature(obj: Callable[P, Any]): # -> Callable[..., Callable[P, V]]:
"""Apply call signature and documentation of `obj` to the decorated method."""
def use_signature(tp: Callable[P, Any], /):
"""
Use the signature and doc of ``tp`` for the decorated callable ``cb``.
def decorate(func: Callable[..., V]) -> Callable[P, V]:
# call-signature of func is exposed via __wrapped__.
# we want it to mimic obj.__init__
- **Overload 1**: Decorating method
- **Overload 2**: Decorating function
# error: Accessing "__init__" on an instance is unsound,
# since instance.__init__ could be from an incompatible subclass [misc]
wrapped = (
obj.__init__ if (isinstance(obj, type) and issubclass(obj, object)) else obj # type: ignore [misc]
)
func.__wrapped__ = wrapped # type: ignore[attr-defined]
func._uses_signature = obj # type: ignore[attr-defined]

# Supplement the docstring of func with information from obj
if doc_in := obj.__doc__:
doc_lines = doc_in.splitlines(keepends=True)[1:]
# Patch in a reference to the class this docstring is copied from,
# to generate a hyperlink.
line_1 = f"{func.__doc__ or f'Refer to :class:`{obj.__name__}`'}\n"
func.__doc__ = "".join((line_1, *doc_lines))
return func
Returns
-------
**Adding the annotation breaks typing**:
Overload[Callable[[WrapsMethod[T, R]], WrappedMethod[T, P, R]], Callable[[WrapsFunc[R]], WrappedFunc[P, R]]]
"""

@overload
def decorate(cb: WrapsMethod[T, R], /) -> WrappedMethod[T, P, R]: ...

@overload
def decorate(cb: WrapsFunc[R], /) -> WrappedFunc[P, R]: ...

def decorate(cb: WrapsFunc[R], /) -> WrappedMethod[T, P, R] | WrappedFunc[P, R]:
"""
Raises when no doc was found.
Notes
-----
- Reference to ``tp`` is stored in ``cb.__wrapped__``.
- The doc for ``cb`` will have a ``.rst`` link added, referring to ``tp``.
"""
cb.__wrapped__ = getattr(tp, "__init__", tp) # type: ignore[attr-defined]

if doc_in := tp.__doc__:
line_1 = f"{cb.__doc__ or f'Refer to :class:`{tp.__name__}`'}\n"
cb.__doc__ = "".join((line_1, *doc_in.splitlines(keepends=True)[1:]))
return cb
else:
msg = f"Found no doc for {obj!r}"
msg = f"Found no doc for {tp!r}"
raise AttributeError(msg)

return decorate
Expand Down
34 changes: 25 additions & 9 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,26 @@ def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]:


class FacetMapping(core.FacetMapping):
"""
FacetMapping schema wrapper.
Parameters
----------
column : str, :class:`FacetFieldDef`, :class:`Column`
A field definition for the horizontal facet of trellis plots.
row : str, :class:`FacetFieldDef`, :class:`Row`
A field definition for the vertical facet of trellis plots.
"""

_class_is_valid_at_instantiation = False

@utils.use_signature(core.FacetMapping)
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def __init__(
self,
column: Optional[str | FacetFieldDef | Column] = Undefined,
row: Optional[str | FacetFieldDef | Row] = Undefined,
**kwargs: Any,
) -> None:
super().__init__(column=column, row=row, **kwargs) # type: ignore[arg-type]

def to_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
copy = self.copy(deep=False)
Expand Down Expand Up @@ -3606,13 +3621,14 @@ def facet(
self = _top_schema_base(self).copy(deep=False)
data, self.data = self.data, Undefined

if facet_specified:
f: Facet | FacetMapping
if not utils.is_undefined(facet):
f = channels.Facet(facet) if isinstance(facet, str) else facet
else:
r: Any = row
f = FacetMapping(row=r, column=column)

return FacetChart(spec=self, facet=f, data=data, columns=columns, **kwargs)
return FacetChart(spec=self, facet=f, data=data, columns=columns, **kwargs) # pyright: ignore[reportArgumentType]


class Chart(
Expand Down Expand Up @@ -4162,7 +4178,7 @@ def add_selection(self, *selections) -> Self: # noqa: ANN002

def concat(*charts: ConcatType, **kwargs: Any) -> ConcatChart:
"""Concatenate charts horizontally."""
return ConcatChart(concat=charts, **kwargs) # pyright: ignore
return ConcatChart(concat=charts, **kwargs)


class HConcatChart(TopLevelMixin, core.TopLevelHConcatSpec):
Expand Down Expand Up @@ -4266,7 +4282,7 @@ def add_selection(self, *selections) -> Self: # noqa: ANN002

def hconcat(*charts: ConcatType, **kwargs: Any) -> HConcatChart:
"""Concatenate charts horizontally."""
return HConcatChart(hconcat=charts, **kwargs) # pyright: ignore
return HConcatChart(hconcat=charts, **kwargs)


class VConcatChart(TopLevelMixin, core.TopLevelVConcatSpec):
Expand Down Expand Up @@ -4372,7 +4388,7 @@ def add_selection(self, *selections) -> Self: # noqa: ANN002

def vconcat(*charts: ConcatType, **kwargs: Any) -> VConcatChart:
"""Concatenate charts vertically."""
return VConcatChart(vconcat=charts, **kwargs) # pyright: ignore
return VConcatChart(vconcat=charts, **kwargs)


class LayerChart(TopLevelMixin, _EncodingMixin, core.TopLevelLayerSpec):
Expand Down Expand Up @@ -4498,7 +4514,7 @@ def add_selection(self, *selections) -> Self: # noqa: ANN002

def layer(*charts: LayerType, **kwargs: Any) -> LayerChart:
"""Layer multiple charts."""
return LayerChart(layer=charts, **kwargs) # pyright: ignore
return LayerChart(layer=charts, **kwargs)


class FacetChart(TopLevelMixin, core.TopLevelFacetSpec):
Expand Down
1 change: 0 additions & 1 deletion altair/vegalite/v5/schema/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@
"ExtentTransform",
"FacetEncodingFieldDef",
"FacetFieldDef",
"FacetMapping",
"FacetSpec",
"FacetedEncoding",
"FacetedUnitSpec",
Expand Down
2 changes: 1 addition & 1 deletion tools/generate_schema_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def generate_vegalite_schema_wrapper(schema_file: Path) -> str:
# of exported classes which are also defined in the channels or api modules which takes
# precedent in the generated __init__.py files one and two levels up.
# Importing these classes from multiple modules confuses type checkers.
EXCLUDE = {"Color", "Text", "LookupData", "Dict"}
EXCLUDE = {"Color", "Text", "LookupData", "Dict", "FacetMapping"}
it = (c for c in definitions.keys() - EXCLUDE if not c.startswith("_"))
all_ = [*sorted(it), "Root", "VegaLiteSchema", "SchemaBase", "load_schema"]

Expand Down

0 comments on commit a7c227b

Please sign in to comment.