Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow field overrides via Annotated #604

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/cattrs/gen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ._consts import AttributeOverride, already_generating, neutral
from ._generics import generate_mapping
from ._lc import generate_unique_filename
from ._shared import find_structure_handler
from ._shared import find_structure_handler, get_fields_annotated_by

if TYPE_CHECKING:
from ..converters import BaseConverter
Expand Down Expand Up @@ -260,6 +260,10 @@ def make_dict_unstructure_fn(

working_set.add(cl)

# Merge overrides provided via Annotated with kwargs
annotated_overrides = get_fields_annotated_by(cl, AttributeOverride)
annotated_overrides.update(kwargs)

try:
return make_dict_unstructure_fn_from_attrs(
attrs,
Expand All @@ -270,7 +274,7 @@ def make_dict_unstructure_fn(
_cattrs_use_linecache=_cattrs_use_linecache,
_cattrs_use_alias=_cattrs_use_alias,
_cattrs_include_init_false=_cattrs_include_init_false,
**kwargs,
**annotated_overrides,
)
finally:
working_set.remove(cl)
Expand Down
35 changes: 34 additions & 1 deletion src/cattrs/gen/_shared.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar, get_type_hints

from attrs import NOTHING, Attribute, Factory

Expand All @@ -12,6 +12,8 @@
if TYPE_CHECKING:
from ..converters import BaseConverter

T = TypeVar("T")


def find_structure_handler(
a: Attribute, type: Any, c: BaseConverter, prefer_attrs_converters: bool = False
Expand Down Expand Up @@ -62,3 +64,34 @@ def handler(v, _, _h=handler):
except RecursionError:
# This means we're dealing with a reference cycle, so use late binding.
return c.structure


def get_fields_annotated_by(cls: type, annotation_type: type[T] | T) -> dict[str, T]:
type_hints = get_type_hints(cls, include_extras=True)
# Support for both AttributeOverride and AttributeOverride()
annotation_type_ = (
annotation_type if isinstance(annotation_type, type) else type(annotation_type)
)

# First pass of filtering to get only fields with annotations
fields_with_annotations = (
(field_name, param_spec.__metadata__)
for field_name, param_spec in type_hints.items()
if hasattr(param_spec, "__metadata__")
)

# Now that we have fields with ANY annotations, we need to remove unwanted annotations.
fields_with_specific_annotation = (
(
field_name,
next((a for a in annotations if isinstance(a, annotation_type_)), None),
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function will only find annotations that were instantiated, which works fine for the purpose of finding annotations created by gen.override()

)
for field_name, annotations in fields_with_annotations
)

# We still might have some `None` values from previous filtering.
return {
field_name: annotation
for field_name, annotation in fields_with_specific_annotation
if annotation
}
77 changes: 77 additions & 0 deletions tests/test_annotated_overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import Annotated, Union

import attrs
import pytest

from cattrs.gen._shared import get_fields_annotated_by


class NotThere: ...


class IgnoreMe:
def __init__(self, why: Union[str, None] = None):
self.why = why


class FindMe:
def __init__(self, taint: str):
self.taint = taint


class EmptyClassExample:
pass


class PureClassExample:
id: Annotated[int, FindMe("red")]
name: Annotated[str, FindMe]


class MultipleAnnotationsExample:
id: Annotated[int, FindMe("red"), IgnoreMe()]
name: Annotated[str, IgnoreMe()]
surface: Annotated[str, IgnoreMe("sorry"), FindMe("shiny")]


@attrs.define
class AttrsClassExample:
id: int = attrs.field(default=0)
color: Annotated[str, FindMe("blue")] = attrs.field(default="red")
config: Annotated[dict, FindMe("required")] = attrs.field(factory=dict)


class PureClassInheritanceExample(PureClassExample):
include: dict
exclude: Annotated[dict, FindMe("boring things")]
extras: Annotated[dict, FindMe]


@pytest.mark.parametrize(
"klass,expected",
[
(EmptyClassExample, {}),
(PureClassExample, {"id": isinstance}),
(AttrsClassExample, {"color": isinstance, "config": isinstance}),
(MultipleAnnotationsExample, {"id": isinstance, "surface": isinstance}),
(PureClassInheritanceExample, {"id": isinstance, "exclude": isinstance}),
],
)
@pytest.mark.parametrize("instantiate", [True, False])
def test_gets_annotated_types(klass, expected, instantiate: bool):
annotated = get_fields_annotated_by(
klass, FindMe("irrelevant") if instantiate else FindMe
)

assert set(annotated.keys()) == set(
expected.keys()
), "Too many or too few annotations"
assert all(
assertion_func(annotated[field_name], FindMe)
for field_name, assertion_func in expected.items()
), "Unexpected type of annotation"


def test_empty_result_for_missing_annotation():
annotated = get_fields_annotated_by(MultipleAnnotationsExample, NotThere)
assert not annotated, "No annotation should be found."
Loading