diff --git a/src/cattrs/gen/__init__.py b/src/cattrs/gen/__init__.py index 7a562c47..a81e2561 100644 --- a/src/cattrs/gen/__init__.py +++ b/src/cattrs/gen/__init__.py @@ -33,7 +33,7 @@ from ._consts import AttributeOverride, already_generating, neutral from ._generics import generate_mapping from ._lc import generate_unique_filename -from ._shared import find_structure_handler +from ._shared import find_structure_handler, get_fields_annotated_by if TYPE_CHECKING: from ..converters import BaseConverter @@ -260,6 +260,10 @@ def make_dict_unstructure_fn( working_set.add(cl) + # Merge overrides provided via Annotated with kwargs + annotated_overrides = get_fields_annotated_by(cl, AttributeOverride) + annotated_overrides.update(kwargs) + try: return make_dict_unstructure_fn_from_attrs( attrs, @@ -270,7 +274,7 @@ def make_dict_unstructure_fn( _cattrs_use_linecache=_cattrs_use_linecache, _cattrs_use_alias=_cattrs_use_alias, _cattrs_include_init_false=_cattrs_include_init_false, - **kwargs, + **annotated_overrides, ) finally: working_set.remove(cl) diff --git a/src/cattrs/gen/_shared.py b/src/cattrs/gen/_shared.py index 904c7744..809e9b45 100644 --- a/src/cattrs/gen/_shared.py +++ b/src/cattrs/gen/_shared.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar, get_type_hints from attrs import NOTHING, Attribute, Factory @@ -12,6 +12,8 @@ if TYPE_CHECKING: from ..converters import BaseConverter +T = TypeVar("T") + def find_structure_handler( a: Attribute, type: Any, c: BaseConverter, prefer_attrs_converters: bool = False @@ -62,3 +64,34 @@ def handler(v, _, _h=handler): except RecursionError: # This means we're dealing with a reference cycle, so use late binding. return c.structure + + +def get_fields_annotated_by(cls: type, annotation_type: type[T] | T) -> dict[str, T]: + type_hints = get_type_hints(cls, include_extras=True) + # Support for both AttributeOverride and AttributeOverride() + annotation_type_ = ( + annotation_type if isinstance(annotation_type, type) else type(annotation_type) + ) + + # First pass of filtering to get only fields with annotations + fields_with_annotations = ( + (field_name, param_spec.__metadata__) + for field_name, param_spec in type_hints.items() + if hasattr(param_spec, "__metadata__") + ) + + # Now that we have fields with ANY annotations, we need to remove unwanted annotations. + fields_with_specific_annotation = ( + ( + field_name, + next((a for a in annotations if isinstance(a, annotation_type_)), None), + ) + for field_name, annotations in fields_with_annotations + ) + + # We still might have some `None` values from previous filtering. + return { + field_name: annotation + for field_name, annotation in fields_with_specific_annotation + if annotation + } diff --git a/tests/test_annotated_overrides.py b/tests/test_annotated_overrides.py new file mode 100644 index 00000000..7a9846e4 --- /dev/null +++ b/tests/test_annotated_overrides.py @@ -0,0 +1,77 @@ +from typing import Annotated, Union + +import attrs +import pytest + +from cattrs.gen._shared import get_fields_annotated_by + + +class NotThere: ... + + +class IgnoreMe: + def __init__(self, why: Union[str, None] = None): + self.why = why + + +class FindMe: + def __init__(self, taint: str): + self.taint = taint + + +class EmptyClassExample: + pass + + +class PureClassExample: + id: Annotated[int, FindMe("red")] + name: Annotated[str, FindMe] + + +class MultipleAnnotationsExample: + id: Annotated[int, FindMe("red"), IgnoreMe()] + name: Annotated[str, IgnoreMe()] + surface: Annotated[str, IgnoreMe("sorry"), FindMe("shiny")] + + +@attrs.define +class AttrsClassExample: + id: int = attrs.field(default=0) + color: Annotated[str, FindMe("blue")] = attrs.field(default="red") + config: Annotated[dict, FindMe("required")] = attrs.field(factory=dict) + + +class PureClassInheritanceExample(PureClassExample): + include: dict + exclude: Annotated[dict, FindMe("boring things")] + extras: Annotated[dict, FindMe] + + +@pytest.mark.parametrize( + "klass,expected", + [ + (EmptyClassExample, {}), + (PureClassExample, {"id": isinstance}), + (AttrsClassExample, {"color": isinstance, "config": isinstance}), + (MultipleAnnotationsExample, {"id": isinstance, "surface": isinstance}), + (PureClassInheritanceExample, {"id": isinstance, "exclude": isinstance}), + ], +) +@pytest.mark.parametrize("instantiate", [True, False]) +def test_gets_annotated_types(klass, expected, instantiate: bool): + annotated = get_fields_annotated_by( + klass, FindMe("irrelevant") if instantiate else FindMe + ) + + assert set(annotated.keys()) == set( + expected.keys() + ), "Too many or too few annotations" + assert all( + assertion_func(annotated[field_name], FindMe) + for field_name, assertion_func in expected.items() + ), "Unexpected type of annotation" + + +def test_empty_result_for_missing_annotation(): + annotated = get_fields_annotated_by(MultipleAnnotationsExample, NotThere) + assert not annotated, "No annotation should be found."