Skip to content

Commit

Permalink
refactor: Metadata handling (#3721)
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut authored Sep 15, 2024
1 parent d4e01f9 commit 0e2ad9a
Show file tree
Hide file tree
Showing 25 changed files with 1,105 additions and 639 deletions.
31 changes: 18 additions & 13 deletions litestar/_openapi/schema_generation/plugins/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@

import msgspec
from msgspec import Struct
from msgspec.structs import fields

from litestar.plugins import OpenAPISchemaPlugin
from litestar.plugins.core._msgspec import kwarg_definition_from_field
from litestar.types.empty import Empty
from litestar.typing import FieldDefinition
from litestar.utils.predicates import is_optional_union

if TYPE_CHECKING:
from msgspec.structs import FieldInfo

from litestar._openapi.schema_generation import SchemaCreator
from litestar.openapi.spec import Schema

Expand All @@ -23,11 +21,25 @@ def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool:
return not field_definition.is_union and field_definition.is_subclass_of(Struct)

def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema:
def is_field_required(field: FieldInfo) -> bool:
def is_field_required(field: msgspec.inspect.Field) -> bool:
return field.required or field.default_factory is Empty

type_hints = field_definition.get_type_hints(include_extras=True, resolve_generics=True)
struct_fields = fields(field_definition.type_)
struct_info: msgspec.inspect.StructType = msgspec.inspect.type_info(field_definition.type_) # type: ignore[assignment]
struct_fields = struct_info.fields

property_fields = {}
for field in struct_fields:
field_definition_kwargs = {}
if kwarg_definition := kwarg_definition_from_field(field)[0]:
field_definition_kwargs["kwarg_definition"] = kwarg_definition

property_fields[field.encode_name] = FieldDefinition.from_annotation(
annotation=type_hints[field.name],
name=field.encode_name,
default=field.default if field.default not in {msgspec.NODEFAULT, msgspec.UNSET} else Empty,
**field_definition_kwargs,
)

return schema_creator.create_component_schema(
field_definition,
Expand All @@ -38,12 +50,5 @@ def is_field_required(field: FieldInfo) -> bool:
if is_field_required(field=field) and not is_optional_union(type_hints[field.name])
]
),
property_fields={
field.encode_name: FieldDefinition.from_kwarg(
type_hints[field.name],
field.encode_name,
default=field.default if field.default not in {msgspec.NODEFAULT, msgspec.UNSET} else Empty,
)
for field in struct_fields
},
property_fields=property_fields,
)
13 changes: 5 additions & 8 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,12 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re
result = self.for_union_field(field_definition)
elif field_definition.is_type_var:
result = self.for_typevar()
elif field_definition.inner_types and not field_definition.is_generic:
result = self.for_object_type(field_definition)
elif self.is_constrained_field(field_definition):
result = self.for_constrained_field(field_definition)
elif field_definition.inner_types and not field_definition.is_generic:
# this case does not recurse for all base cases, so it needs to happen
# after all non-concrete cases
result = self.for_object_type(field_definition)
elif field_definition.is_subclass_of(UploadFile):
result = self.for_upload_file(field_definition)
else:
Expand Down Expand Up @@ -564,12 +566,7 @@ def for_collection_constrained_field(self, field_definition: FieldDefinition) ->
if field_definition.inner_types:
items = list(map(item_creator.for_field_definition, field_definition.inner_types))
schema.items = Schema(one_of=items) if len(items) > 1 else items[0]
else:
schema.items = item_creator.for_field_definition(
FieldDefinition.from_kwarg(
field_definition.annotation.item_type, f"{field_definition.annotation.__name__}Field"
)
)
# INFO: Removed because it was only for pydantic constrained collections
return schema

def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference:
Expand Down
76 changes: 41 additions & 35 deletions litestar/contrib/pydantic/pydantic_dto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from typing_extensions import Annotated, TypeAlias, override

from litestar.contrib.pydantic.utils import is_pydantic_2_model, is_pydantic_undefined, is_pydantic_v2
from litestar.contrib.pydantic.utils import get_model_info, is_pydantic_2_model, is_pydantic_undefined, is_pydantic_v2
from litestar.dto.base_dto import AbstractDTO
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field
Expand Down Expand Up @@ -109,50 +109,56 @@ def decode_bytes(self, value: bytes) -> Any:
def generate_field_definitions(
cls, model_type: type[pydantic_v1.BaseModel | pydantic_v2.BaseModel]
) -> Generator[DTOFieldDefinition, None, None]:
model_field_definitions = cls.get_model_type_hints(model_type)
model_info = get_model_info(model_type)
model_fields = model_info.model_fields
model_field_definitions = model_info.field_definitions

model_fields: dict[str, pydantic_v1.fields.FieldInfo | pydantic_v2.fields.FieldInfo]
try:
model_fields = dict(model_type.model_fields) # type: ignore[union-attr]
except AttributeError:
model_fields = {
k: model_field.field_info
for k, model_field in model_type.__fields__.items() # type: ignore[union-attr]
}

for field_name, field_info in model_fields.items():
field_definition = downtype_for_data_transfer(model_field_definitions[field_name])
for field_name, field_definition in model_field_definitions.items():
field_definition = downtype_for_data_transfer(field_definition)
dto_field = extract_dto_field(field_definition, field_definition.extra)

try:
extra = field_info.extra # type: ignore[union-attr]
except AttributeError:
extra = field_info.json_schema_extra # type: ignore[union-attr]

if extra is not None and extra.pop(DTO_FIELD_META_KEY, None):
warn(
message="Declaring 'DTOField' via Pydantic's 'Field.extra' is deprecated. "
"Use 'Annotated', e.g., 'Annotated[str, DTOField(mark='read-only')]' instead. "
"Support for 'DTOField' in 'Field.extra' will be removed in v3.",
category=DeprecationWarning,
stacklevel=2,
default: Any = Empty
default_factory: Any = None
if field_info := model_fields.get(field_name):
# field_info might not exist, since FieldInfo isn't provided by pydantic
# for computed fields, but we still generate a FieldDefinition for them
try:
extra = field_info.extra # type: ignore[union-attr]
except AttributeError:
extra = field_info.json_schema_extra # type: ignore[union-attr]

if extra is not None and extra.pop(DTO_FIELD_META_KEY, None):
warn(
message="Declaring 'DTOField' via Pydantic's 'Field.extra' is deprecated. "
"Use 'Annotated', e.g., 'Annotated[str, DTOField(mark='read-only')]' instead. "
"Support for 'DTOField' in 'Field.extra' will be removed in v3.",
category=DeprecationWarning,
stacklevel=2,
)

if not is_pydantic_undefined(field_info.default):
default = field_info.default
elif field_definition.is_optional:
default = None
else:
default = Empty

default_factory = (
field_info.default_factory
if field_info.default_factory and not is_pydantic_undefined(field_info.default_factory)
else None
)

if not is_pydantic_undefined(field_info.default):
default = field_info.default
elif field_definition.is_optional:
default = None
else:
default = Empty

yield replace(
DTOFieldDefinition.from_field_definition(
field_definition=field_definition,
dto_field=dto_field,
model_name=model_type.__name__,
default_factory=field_info.default_factory
if field_info.default_factory and not is_pydantic_undefined(field_info.default_factory)
else None,
default_factory=default_factory,
# we don't want the constraints to be set on the DTO struct as
# constraints, but as schema metadata only, so we can let pydantic
# handle all the constraining
passthrough_constraints=False,
),
default=default,
name=field_name,
Expand Down
14 changes: 1 addition & 13 deletions litestar/contrib/pydantic/pydantic_init_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from typing_extensions import Buffer, TypeGuard

from litestar._signature.types import ExtendedMsgSpecValidationError
from litestar.contrib.pydantic.utils import is_pydantic_constrained_field, is_pydantic_v2
from litestar.contrib.pydantic.utils import is_pydantic_v2
from litestar.exceptions import MissingDependencyException
from litestar.plugins import InitPluginProtocol
from litestar.typing import _KWARG_META_EXTRACTORS
from litestar.utils import is_class_and_subclass

try:
Expand Down Expand Up @@ -114,16 +113,6 @@ def is_pydantic_v2_model_class(annotation: Any) -> TypeGuard[type[pydantic_v2.Ba
return is_class_and_subclass(annotation, pydantic_v2.BaseModel)


class ConstrainedFieldMetaExtractor:
@staticmethod
def matches(annotation: Any, name: str | None, default: Any) -> bool:
return is_pydantic_constrained_field(annotation)

@staticmethod
def extract(annotation: Any, default: Any) -> Any:
return [annotation]


class PydanticInitPlugin(InitPluginProtocol):
__slots__ = (
"exclude",
Expand Down Expand Up @@ -292,5 +281,4 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig:
*(app_config.type_decoders or []),
]

_KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor)
return app_config
83 changes: 9 additions & 74 deletions litestar/contrib/pydantic/pydantic_schema_plugin.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from typing_extensions import Annotated
from typing import TYPE_CHECKING, Any

from litestar.contrib.pydantic.utils import (
create_field_definitions_for_computed_fields,
is_pydantic_2_model,
get_model_info,
is_pydantic_constrained_field,
is_pydantic_model_class,
is_pydantic_undefined,
is_pydantic_v2,
pydantic_get_type_hints_with_generics_resolved,
pydantic_unwrap_and_get_origin,
)
from litestar.exceptions import MissingDependencyException
from litestar.openapi.spec import OpenAPIFormat, OpenAPIType, Schema
from litestar.plugins import OpenAPISchemaPlugin
from litestar.types import Empty
from litestar.typing import FieldDefinition
from litestar.utils import is_class_and_subclass, is_generic
from litestar.utils import is_class_and_subclass

try:
import pydantic as _ # noqa: F401
Expand All @@ -40,6 +33,7 @@

if TYPE_CHECKING:
from litestar._openapi.schema_generation.schema import SchemaCreator
from litestar.typing import FieldDefinition

PYDANTIC_TYPE_MAP: dict[type[Any] | None | Any, Schema] = {
pydantic_v1.ByteSize: Schema(type=OpenAPIType.INTEGER),
Expand Down Expand Up @@ -253,71 +247,12 @@ def for_pydantic_model(cls, field_definition: FieldDefinition, schema_creator: S
A schema instance.
"""

annotation = field_definition.annotation
if is_generic(annotation):
is_generic_model = True
model = pydantic_unwrap_and_get_origin(annotation) or annotation
else:
is_generic_model = False
model = annotation

if is_pydantic_2_model(model):
model_config = model.model_config
model_field_info = model.model_fields
title = model_config.get("title")
example = model_config.get("example")
is_v2_model = True
else:
model_config = annotation.__config__
model_field_info = model.__fields__
title = getattr(model_config, "title", None)
example = getattr(model_config, "example", None)
is_v2_model = False

model_fields: dict[str, pydantic_v1.fields.FieldInfo | pydantic_v2.fields.FieldInfo] = { # pyright: ignore
k: getattr(f, "field_info", f) for k, f in model_field_info.items()
}

if is_v2_model:
# extract the annotations from the FieldInfo. This allows us to skip fields
# which have been marked as private
model_annotations = {k: field_info.annotation for k, field_info in model_fields.items()} # type: ignore[union-attr]

else:
# pydantic v1 requires some workarounds here
model_annotations = {
k: f.outer_type_ if f.required or f.default else Optional[f.outer_type_]
for k, f in model.__fields__.items()
}

if is_generic_model:
# if the model is generic, resolve the type variables. We pass in the
# already extracted annotations, to keep the logic of respecting private
# fields consistent with the above
model_annotations = pydantic_get_type_hints_with_generics_resolved(
annotation, model_annotations=model_annotations, include_extras=True
)

property_fields = {
field_info.alias if field_info.alias and schema_creator.prefer_alias else k: FieldDefinition.from_kwarg(
annotation=Annotated[model_annotations[k], field_info, field_info.metadata] # type: ignore[union-attr]
if is_v2_model
else Annotated[model_annotations[k], field_info], # pyright: ignore
name=field_info.alias if field_info.alias and schema_creator.prefer_alias else k,
default=Empty if schema_creator.is_undefined(field_info.default) else field_info.default,
)
for k, field_info in model_fields.items()
}

computed_field_definitions = create_field_definitions_for_computed_fields(
annotation, schema_creator.prefer_alias
)
property_fields.update(computed_field_definitions)
model_info = get_model_info(field_definition.annotation, prefer_alias=schema_creator.prefer_alias)

return schema_creator.create_component_schema(
field_definition,
required=sorted(f.name for f in property_fields.values() if f.is_required),
property_fields=property_fields,
title=title,
examples=None if example is None else [example],
required=sorted(f.name for f in model_info.field_definitions.values() if f.is_required),
property_fields=model_info.field_definitions,
title=model_info.title,
examples=None if model_info.example is None else [model_info.example],
)
Loading

0 comments on commit 0e2ad9a

Please sign in to comment.