From 0e2ad9a177252b508bd69c8e23cd248cc6b5f8b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sun, 15 Sep 2024 12:31:13 +0200 Subject: [PATCH] refactor: Metadata handling (#3721) --- .../schema_generation/plugins/struct.py | 31 +- litestar/_openapi/schema_generation/schema.py | 13 +- .../contrib/pydantic/pydantic_dto_factory.py | 76 +-- .../contrib/pydantic/pydantic_init_plugin.py | 14 +- .../pydantic/pydantic_schema_plugin.py | 83 +-- litestar/contrib/pydantic/utils.py | 286 +++++++++- litestar/dto/_backend.py | 7 +- litestar/dto/_types.py | 1 + litestar/dto/data_structures.py | 6 + litestar/dto/msgspec_dto.py | 14 +- litestar/plugins/core.py | 31 -- litestar/plugins/core/__init__.py | 3 + litestar/plugins/core/_msgspec.py | 86 +++ litestar/typing.py | 214 +++----- pyproject.toml | 4 + tests/unit/test_contrib/conftest.py | 82 --- tests/unit/test_contrib/test_msgspec.py | 81 ++- .../test_contrib/test_pydantic/conftest.py | 10 - .../test_contrib/test_pydantic/test_dto.py | 3 +- .../test_pydantic/test_integration.py | 54 +- .../test_pydantic/test_openapi.py | 495 +++++++++++------- .../test_pydantic_dto_factory.py | 89 +++- .../test_pydantic/test_schema_plugin.py | 19 + tests/unit/test_openapi/test_schema.py | 25 +- tests/unit/test_typing.py | 17 +- 25 files changed, 1105 insertions(+), 639 deletions(-) delete mode 100644 litestar/plugins/core.py create mode 100644 litestar/plugins/core/__init__.py create mode 100644 litestar/plugins/core/_msgspec.py diff --git a/litestar/_openapi/schema_generation/plugins/struct.py b/litestar/_openapi/schema_generation/plugins/struct.py index da6d8d8c6b..7ac0dd0220 100644 --- a/litestar/_openapi/schema_generation/plugins/struct.py +++ b/litestar/_openapi/schema_generation/plugins/struct.py @@ -4,16 +4,14 @@ import msgspec from msgspec import Struct -from msgspec.structs import fields from litestar.plugins import OpenAPISchemaPlugin +from litestar.plugins.core._msgspec import kwarg_definition_from_field from litestar.types.empty import Empty from litestar.typing import FieldDefinition from litestar.utils.predicates import is_optional_union if TYPE_CHECKING: - from msgspec.structs import FieldInfo - from litestar._openapi.schema_generation import SchemaCreator from litestar.openapi.spec import Schema @@ -23,11 +21,25 @@ def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool: return not field_definition.is_union and field_definition.is_subclass_of(Struct) def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: - def is_field_required(field: FieldInfo) -> bool: + def is_field_required(field: msgspec.inspect.Field) -> bool: return field.required or field.default_factory is Empty type_hints = field_definition.get_type_hints(include_extras=True, resolve_generics=True) - struct_fields = fields(field_definition.type_) + struct_info: msgspec.inspect.StructType = msgspec.inspect.type_info(field_definition.type_) # type: ignore[assignment] + struct_fields = struct_info.fields + + property_fields = {} + for field in struct_fields: + field_definition_kwargs = {} + if kwarg_definition := kwarg_definition_from_field(field)[0]: + field_definition_kwargs["kwarg_definition"] = kwarg_definition + + property_fields[field.encode_name] = FieldDefinition.from_annotation( + annotation=type_hints[field.name], + name=field.encode_name, + default=field.default if field.default not in {msgspec.NODEFAULT, msgspec.UNSET} else Empty, + **field_definition_kwargs, + ) return 
schema_creator.create_component_schema( field_definition, @@ -38,12 +50,5 @@ def is_field_required(field: FieldInfo) -> bool: if is_field_required(field=field) and not is_optional_union(type_hints[field.name]) ] ), - property_fields={ - field.encode_name: FieldDefinition.from_kwarg( - type_hints[field.name], - field.encode_name, - default=field.default if field.default not in {msgspec.NODEFAULT, msgspec.UNSET} else Empty, - ) - for field in struct_fields - }, + property_fields=property_fields, ) diff --git a/litestar/_openapi/schema_generation/schema.py b/litestar/_openapi/schema_generation/schema.py index defa0e0717..1951154006 100644 --- a/litestar/_openapi/schema_generation/schema.py +++ b/litestar/_openapi/schema_generation/schema.py @@ -347,10 +347,12 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re result = self.for_union_field(field_definition) elif field_definition.is_type_var: result = self.for_typevar() - elif field_definition.inner_types and not field_definition.is_generic: - result = self.for_object_type(field_definition) elif self.is_constrained_field(field_definition): result = self.for_constrained_field(field_definition) + elif field_definition.inner_types and not field_definition.is_generic: + # this case does not recurse for all base cases, so it needs to happen + # after all non-concrete cases + result = self.for_object_type(field_definition) elif field_definition.is_subclass_of(UploadFile): result = self.for_upload_file(field_definition) else: @@ -564,12 +566,7 @@ def for_collection_constrained_field(self, field_definition: FieldDefinition) -> if field_definition.inner_types: items = list(map(item_creator.for_field_definition, field_definition.inner_types)) schema.items = Schema(one_of=items) if len(items) > 1 else items[0] - else: - schema.items = item_creator.for_field_definition( - FieldDefinition.from_kwarg( - field_definition.annotation.item_type, f"{field_definition.annotation.__name__}Field" - ) - ) + # INFO: Removed because it was only for pydantic constrained collections return schema def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference: diff --git a/litestar/contrib/pydantic/pydantic_dto_factory.py b/litestar/contrib/pydantic/pydantic_dto_factory.py index 4322aa4c32..c6e5e5641b 100644 --- a/litestar/contrib/pydantic/pydantic_dto_factory.py +++ b/litestar/contrib/pydantic/pydantic_dto_factory.py @@ -7,7 +7,7 @@ from typing_extensions import Annotated, TypeAlias, override -from litestar.contrib.pydantic.utils import is_pydantic_2_model, is_pydantic_undefined, is_pydantic_v2 +from litestar.contrib.pydantic.utils import get_model_info, is_pydantic_2_model, is_pydantic_undefined, is_pydantic_v2 from litestar.dto.base_dto import AbstractDTO from litestar.dto.data_structures import DTOFieldDefinition from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field @@ -109,50 +109,56 @@ def decode_bytes(self, value: bytes) -> Any: def generate_field_definitions( cls, model_type: type[pydantic_v1.BaseModel | pydantic_v2.BaseModel] ) -> Generator[DTOFieldDefinition, None, None]: - model_field_definitions = cls.get_model_type_hints(model_type) + model_info = get_model_info(model_type) + model_fields = model_info.model_fields + model_field_definitions = model_info.field_definitions - model_fields: dict[str, pydantic_v1.fields.FieldInfo | pydantic_v2.fields.FieldInfo] - try: - model_fields = dict(model_type.model_fields) # type: ignore[union-attr] - except AttributeError: - model_fields = { - k: 
model_field.field_info - for k, model_field in model_type.__fields__.items() # type: ignore[union-attr] - } - - for field_name, field_info in model_fields.items(): - field_definition = downtype_for_data_transfer(model_field_definitions[field_name]) + for field_name, field_definition in model_field_definitions.items(): + field_definition = downtype_for_data_transfer(field_definition) dto_field = extract_dto_field(field_definition, field_definition.extra) - try: - extra = field_info.extra # type: ignore[union-attr] - except AttributeError: - extra = field_info.json_schema_extra # type: ignore[union-attr] - - if extra is not None and extra.pop(DTO_FIELD_META_KEY, None): - warn( - message="Declaring 'DTOField' via Pydantic's 'Field.extra' is deprecated. " - "Use 'Annotated', e.g., 'Annotated[str, DTOField(mark='read-only')]' instead. " - "Support for 'DTOField' in 'Field.extra' will be removed in v3.", - category=DeprecationWarning, - stacklevel=2, + default: Any = Empty + default_factory: Any = None + if field_info := model_fields.get(field_name): + # field_info might not exist, since FieldInfo isn't provided by pydantic + # for computed fields, but we still generate a FieldDefinition for them + try: + extra = field_info.extra # type: ignore[union-attr] + except AttributeError: + extra = field_info.json_schema_extra # type: ignore[union-attr] + + if extra is not None and extra.pop(DTO_FIELD_META_KEY, None): + warn( + message="Declaring 'DTOField' via Pydantic's 'Field.extra' is deprecated. " + "Use 'Annotated', e.g., 'Annotated[str, DTOField(mark='read-only')]' instead. " + "Support for 'DTOField' in 'Field.extra' will be removed in v3.", + category=DeprecationWarning, + stacklevel=2, + ) + + if not is_pydantic_undefined(field_info.default): + default = field_info.default + elif field_definition.is_optional: + default = None + else: + default = Empty + + default_factory = ( + field_info.default_factory + if field_info.default_factory and not is_pydantic_undefined(field_info.default_factory) + else None ) - if not is_pydantic_undefined(field_info.default): - default = field_info.default - elif field_definition.is_optional: - default = None - else: - default = Empty - yield replace( DTOFieldDefinition.from_field_definition( field_definition=field_definition, dto_field=dto_field, model_name=model_type.__name__, - default_factory=field_info.default_factory - if field_info.default_factory and not is_pydantic_undefined(field_info.default_factory) - else None, + default_factory=default_factory, + # we don't want the constraints to be set on the DTO struct as + # constraints, but as schema metadata only, so we can let pydantic + # handle all the constraining + passthrough_constraints=False, ), default=default, name=field_name, diff --git a/litestar/contrib/pydantic/pydantic_init_plugin.py b/litestar/contrib/pydantic/pydantic_init_plugin.py index 1d425f3420..ff7cbfff3a 100644 --- a/litestar/contrib/pydantic/pydantic_init_plugin.py +++ b/litestar/contrib/pydantic/pydantic_init_plugin.py @@ -9,10 +9,9 @@ from typing_extensions import Buffer, TypeGuard from litestar._signature.types import ExtendedMsgSpecValidationError -from litestar.contrib.pydantic.utils import is_pydantic_constrained_field, is_pydantic_v2 +from litestar.contrib.pydantic.utils import is_pydantic_v2 from litestar.exceptions import MissingDependencyException from litestar.plugins import InitPluginProtocol -from litestar.typing import _KWARG_META_EXTRACTORS from litestar.utils import is_class_and_subclass try: @@ -114,16 +113,6 @@ def 
is_pydantic_v2_model_class(annotation: Any) -> TypeGuard[type[pydantic_v2.Ba return is_class_and_subclass(annotation, pydantic_v2.BaseModel) -class ConstrainedFieldMetaExtractor: - @staticmethod - def matches(annotation: Any, name: str | None, default: Any) -> bool: - return is_pydantic_constrained_field(annotation) - - @staticmethod - def extract(annotation: Any, default: Any) -> Any: - return [annotation] - - class PydanticInitPlugin(InitPluginProtocol): __slots__ = ( "exclude", @@ -292,5 +281,4 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig: *(app_config.type_decoders or []), ] - _KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor) return app_config diff --git a/litestar/contrib/pydantic/pydantic_schema_plugin.py b/litestar/contrib/pydantic/pydantic_schema_plugin.py index 2eda65f2fc..01fe5f6781 100644 --- a/litestar/contrib/pydantic/pydantic_schema_plugin.py +++ b/litestar/contrib/pydantic/pydantic_schema_plugin.py @@ -1,25 +1,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional - -from typing_extensions import Annotated +from typing import TYPE_CHECKING, Any from litestar.contrib.pydantic.utils import ( - create_field_definitions_for_computed_fields, - is_pydantic_2_model, + get_model_info, is_pydantic_constrained_field, is_pydantic_model_class, is_pydantic_undefined, is_pydantic_v2, - pydantic_get_type_hints_with_generics_resolved, - pydantic_unwrap_and_get_origin, ) from litestar.exceptions import MissingDependencyException from litestar.openapi.spec import OpenAPIFormat, OpenAPIType, Schema from litestar.plugins import OpenAPISchemaPlugin -from litestar.types import Empty -from litestar.typing import FieldDefinition -from litestar.utils import is_class_and_subclass, is_generic +from litestar.utils import is_class_and_subclass try: import pydantic as _ # noqa: F401 @@ -40,6 +33,7 @@ if TYPE_CHECKING: from litestar._openapi.schema_generation.schema import SchemaCreator + from litestar.typing import FieldDefinition PYDANTIC_TYPE_MAP: dict[type[Any] | None | Any, Schema] = { pydantic_v1.ByteSize: Schema(type=OpenAPIType.INTEGER), @@ -253,71 +247,12 @@ def for_pydantic_model(cls, field_definition: FieldDefinition, schema_creator: S A schema instance. """ - annotation = field_definition.annotation - if is_generic(annotation): - is_generic_model = True - model = pydantic_unwrap_and_get_origin(annotation) or annotation - else: - is_generic_model = False - model = annotation - - if is_pydantic_2_model(model): - model_config = model.model_config - model_field_info = model.model_fields - title = model_config.get("title") - example = model_config.get("example") - is_v2_model = True - else: - model_config = annotation.__config__ - model_field_info = model.__fields__ - title = getattr(model_config, "title", None) - example = getattr(model_config, "example", None) - is_v2_model = False - - model_fields: dict[str, pydantic_v1.fields.FieldInfo | pydantic_v2.fields.FieldInfo] = { # pyright: ignore - k: getattr(f, "field_info", f) for k, f in model_field_info.items() - } - - if is_v2_model: - # extract the annotations from the FieldInfo. 
This allows us to skip fields - # which have been marked as private - model_annotations = {k: field_info.annotation for k, field_info in model_fields.items()} # type: ignore[union-attr] - - else: - # pydantic v1 requires some workarounds here - model_annotations = { - k: f.outer_type_ if f.required or f.default else Optional[f.outer_type_] - for k, f in model.__fields__.items() - } - - if is_generic_model: - # if the model is generic, resolve the type variables. We pass in the - # already extracted annotations, to keep the logic of respecting private - # fields consistent with the above - model_annotations = pydantic_get_type_hints_with_generics_resolved( - annotation, model_annotations=model_annotations, include_extras=True - ) - - property_fields = { - field_info.alias if field_info.alias and schema_creator.prefer_alias else k: FieldDefinition.from_kwarg( - annotation=Annotated[model_annotations[k], field_info, field_info.metadata] # type: ignore[union-attr] - if is_v2_model - else Annotated[model_annotations[k], field_info], # pyright: ignore - name=field_info.alias if field_info.alias and schema_creator.prefer_alias else k, - default=Empty if schema_creator.is_undefined(field_info.default) else field_info.default, - ) - for k, field_info in model_fields.items() - } - - computed_field_definitions = create_field_definitions_for_computed_fields( - annotation, schema_creator.prefer_alias - ) - property_fields.update(computed_field_definitions) + model_info = get_model_info(field_definition.annotation, prefer_alias=schema_creator.prefer_alias) return schema_creator.create_component_schema( field_definition, - required=sorted(f.name for f in property_fields.values() if f.is_required), - property_fields=property_fields, - title=title, - examples=None if example is None else [example], + required=sorted(f.name for f in model_info.field_definitions.values() if f.is_required), + property_fields=model_info.field_definitions, + title=model_info.title, + examples=None if model_info.example is None else [model_info.example], ) diff --git a/litestar/contrib/pydantic/utils.py b/litestar/contrib/pydantic/utils.py index 45362c1506..966d9505a6 100644 --- a/litestar/contrib/pydantic/utils.py +++ b/litestar/contrib/pydantic/utils.py @@ -1,18 +1,24 @@ # mypy: strict-equality=False +# pyright: reportGeneralTypeIssues=false from __future__ import annotations -from typing import TYPE_CHECKING, Any +import datetime +import re +from dataclasses import dataclass +from inspect import isclass +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional from typing_extensions import Annotated, get_type_hints -from litestar.params import KwargDefinition +from litestar.openapi.spec import Example +from litestar.params import KwargDefinition, ParameterKwarg from litestar.types import Empty from litestar.typing import FieldDefinition -from litestar.utils import deprecated, is_class_and_subclass -from litestar.utils.predicates import is_generic +from litestar.utils import deprecated, is_class_and_subclass, is_generic, is_undefined_sentinel from litestar.utils.typing import ( _substitute_typevars, get_origin_or_inner_type, + get_safe_generic_origin, get_type_hints_with_generics_resolved, normalize_type_annotation, ) @@ -156,7 +162,7 @@ def pydantic_get_type_hints_with_generics_resolved( @deprecated(version="2.6.2") -def pydantic_get_unwrapped_annotation_and_type_hints(annotation: Any) -> tuple[Any, dict[str, Any]]: # pragma: pver +def pydantic_get_unwrapped_annotation_and_type_hints(annotation: Any) -> tuple[Any, 
dict[str, Any]]: # pragma: no cover """Get the unwrapped annotation and the type hints after resolving generics. Args: @@ -208,7 +214,13 @@ def get_name(k: str, dec: Any) -> str: (name := get_name(k, dec)): FieldDefinition.from_annotation( Annotated[ dec.info.return_type, - KwargDefinition(title=dec.info.title, description=dec.info.description, read_only=True), + KwargDefinition( + title=dec.info.title, + description=dec.info.description, + read_only=True, + examples=[Example(value=v) for v in examples] if (examples := dec.info.examples) else None, + schema_extra=dec.info.json_schema_extra, + ), ], name=name, ) @@ -228,3 +240,265 @@ def is_pydantic_v2(module: ModuleType) -> bool: True if the module is pydantic v2, otherwise False. """ return bool(module.__version__.startswith("2.")) + + +@dataclass(frozen=True) +class PydanticModelInfo: + pydantic_version: Literal["1", "2"] + field_definitions: dict[str, FieldDefinition] + model_fields: dict[str, pydantic_v1.fields.FieldInfo | pydantic_v2.fields.FieldInfo] + title: str | None = None + example: Any | None = None + is_generic: bool = False + + +_CreateFieldDefinition = Callable[..., FieldDefinition] + + +def _create_field_definition_v1( # noqa: C901 + field_annotation: Any, + *, + field_info: pydantic_v1.fields.FieldInfo, + **field_definition_kwargs: Any, +) -> FieldDefinition: + kwargs: dict[str, Any] = {} + examples: list[Any] = [] + if example := field_info.extra.get("example"): + examples.append(example) + if extra_examples := field_info.extra.get("examples"): + examples.extend(extra_examples) + if examples: + kwargs["examples"] = [Example(value=e) for e in examples] + if title := field_info.title: + kwargs["title"] = title + if description := field_info.description: + kwargs["description"] = description + + kwarg_definition: KwargDefinition | None = None + + if isclass(field_annotation): + if issubclass(field_annotation, pydantic_v1.ConstrainedBytes): + kwarg_definition = ParameterKwarg( + min_length=field_annotation.min_length, + max_length=field_annotation.max_length, + lower_case=field_annotation.to_lower, + upper_case=field_annotation.to_upper, + **kwargs, + ) + field_definition_kwargs["raw"] = field_annotation + field_annotation = bytes + elif issubclass(field_annotation, pydantic_v1.ConstrainedStr): + kwarg_definition = ParameterKwarg( + min_length=field_annotation.min_length, + max_length=field_annotation.max_length, + lower_case=field_annotation.to_lower, + upper_case=field_annotation.to_upper, + pattern=field_annotation.regex.pattern + if isinstance(field_annotation.regex, re.Pattern) + else field_annotation.regex, + **kwargs, + ) + field_definition_kwargs["raw"] = field_annotation + field_annotation = str + elif issubclass(field_annotation, pydantic_v1.ConstrainedDate): + # TODO: The typings of ParameterKwarg need fixing. 
Specifically, the + # gt/ge/lt/le fields need to be typed with protocols, such that they may + # accept any type that implements the respective comparisons + + kwarg_definition = ParameterKwarg( + gt=field_annotation.gt, # type: ignore[arg-type] + ge=field_annotation.ge, # type: ignore[arg-type] + lt=field_annotation.lt, # type: ignore[arg-type] + le=field_annotation.le, # type: ignore[arg-type] + **kwargs, + ) + field_definition_kwargs["raw"] = field_annotation + field_annotation = datetime.date + elif issubclass( + field_annotation, + (pydantic_v1.ConstrainedInt, pydantic_v1.ConstrainedFloat, pydantic_v1.ConstrainedDecimal), + ): + kwarg_definition = ParameterKwarg( + gt=field_annotation.gt, # type: ignore[arg-type] + ge=field_annotation.ge, # type: ignore[arg-type] + lt=field_annotation.lt, # type: ignore[arg-type] + le=field_annotation.le, # type: ignore[arg-type] + multiple_of=field_annotation.multiple_of, # type: ignore[arg-type] + **kwargs, + ) + field_definition_kwargs["raw"] = field_annotation + field_annotation = field_annotation.mro()[2] + elif issubclass( + field_annotation, + (pydantic_v1.ConstrainedList, pydantic_v1.ConstrainedSet, pydantic_v1.ConstrainedFrozenSet), + ): + kwarg_definition = ParameterKwarg( + max_items=field_annotation.max_items, min_items=field_annotation.min_items, **kwargs + ) + field_definition_kwargs["raw"] = field_annotation + # on < 3.9, these builtins are not generic + origin = get_safe_generic_origin(None, field_annotation.__origin__) + field_annotation = origin[field_annotation.item_type] + + if kwarg_definition is None and kwargs: + kwarg_definition = ParameterKwarg(**kwargs) + + if kwarg_definition: + field_definition_kwargs["raw"] = field_annotation + field_annotation = Annotated[field_annotation, kwarg_definition] + + return FieldDefinition.from_annotation( + annotation=field_annotation, + **field_definition_kwargs, + ) + + +def _create_field_definition_v2( # noqa: C901 + field_annotation: Any, + *, + field_info: pydantic_v2.fields.FieldInfo, + **field_definition_kwargs: Any, +) -> FieldDefinition: + kwargs: dict[str, Any] = {} + examples: list[Any] = [] + field_meta: list[Any] = [] + + if json_schema_extra := field_info.json_schema_extra: + if callable(json_schema_extra): + raise ValueError("Callables not supported for json_schema_extra") + if json_schema_example := json_schema_extra.get("example"): + del json_schema_extra["example"] + examples.append(json_schema_example) + if json_schema_examples := json_schema_extra.get("examples"): + del json_schema_extra["examples"] + examples.extend(json_schema_examples) # type: ignore[arg-type] + if field_examples := field_info.examples: + examples.extend(field_examples) + + if examples: + if not json_schema_extra: + json_schema_extra = {} + json_schema_extra["examples"] = examples + + if description := field_info.description: + kwargs["description"] = description + + if title := field_info.title: + kwargs["title"] = title + + for meta in field_info.metadata: + if isinstance(meta, pydantic_v2.types.StringConstraints): + kwargs["min_length"] = meta.min_length + kwargs["max_length"] = meta.max_length + kwargs["pattern"] = meta.pattern + kwargs["lower_case"] = meta.to_lower + kwargs["upper_case"] = meta.to_upper + # forward other metadata + else: + field_meta.append(meta) + + if json_schema_extra: + kwargs["schema_extra"] = json_schema_extra + + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + if kwargs: + kwarg_definition = ParameterKwarg(**kwargs) + field_meta.append(kwarg_definition) + + if 
field_meta: + field_definition_kwargs["raw"] = field_annotation + for meta in field_meta: + field_annotation = Annotated[field_annotation, meta] + + return FieldDefinition.from_annotation( + annotation=field_annotation, + **field_definition_kwargs, + ) + + +def get_model_info( + annotation: Any, + prefer_alias: bool = False, +) -> PydanticModelInfo: + model: type[pydantic_v1.BaseModel | pydantic_v2.BaseModel] + + if is_generic(annotation): + is_generic_model = True + model = pydantic_unwrap_and_get_origin(annotation) or annotation + else: + is_generic_model = False + model = annotation + + if is_pydantic_2_model(model): + model_config = model.model_config + model_field_info = model.model_fields + title = model_config.get("title") + example = model_config.get("example") + is_v2_model = True + else: + model_config = model.__config__ # type: ignore[assignment, union-attr] + model_field_info = model.__fields__ # type: ignore[assignment] + title = getattr(model_config, "title", None) + example = getattr(model_config, "example", None) + is_v2_model = False + + model_fields: dict[str, pydantic_v1.fields.FieldInfo | pydantic_v2.fields.FieldInfo] = { # pyright: ignore + k: getattr(f, "field_info", f) for k, f in model_field_info.items() + } + + if is_v2_model: + # extract the annotations from the FieldInfo. This allows us to skip fields + # which have been marked as private + # if there's a default factory, we wrap the field in 'Optional', to signal + # that it is not required + model_annotations = { + k: Optional[field_info.annotation] if field_info.default_factory else field_info.annotation # type: ignore[union-attr] + for k, field_info in model_fields.items() + } + + else: + # pydantic v1 requires some workarounds here + model_annotations = { + k: f.outer_type_ if f.required or f.default else Optional[f.outer_type_] + for k, f in model.__fields__.items() # type: ignore[union-attr] + } + + if is_generic_model: + # if the model is generic, resolve the type variables. 
We pass in the + # already extracted annotations, to keep the logic of respecting private + # fields consistent with the above + model_annotations = pydantic_get_type_hints_with_generics_resolved( + annotation, model_annotations=model_annotations, include_extras=True + ) + + create_field_definition: _CreateFieldDefinition = ( + _create_field_definition_v2 if is_v2_model else _create_field_definition_v1 # type: ignore[assignment] + ) + + property_fields = { + field_info.alias if field_info.alias and prefer_alias else k: create_field_definition( + field_annotation=model_annotations[k], + name=field_info.alias if field_info.alias and prefer_alias else k, + default=Empty + if is_undefined_sentinel(field_info.default) or is_pydantic_undefined(field_info.default) + else field_info.default, + field_info=field_info, + ) + for k, field_info in model_fields.items() + } + + computed_field_definitions = create_field_definitions_for_computed_fields( + model, + prefer_alias=prefer_alias, + ) + property_fields.update(computed_field_definitions) + + return PydanticModelInfo( + pydantic_version="2" if is_v2_model else "1", + title=title, + example=example, + field_definitions=property_fields, + is_generic=is_generic_model, + model_fields=model_fields, + ) diff --git a/litestar/dto/_backend.py b/litestar/dto/_backend.py index 167a2a9740..4d9651a19e 100644 --- a/litestar/dto/_backend.py +++ b/litestar/dto/_backend.py @@ -805,8 +805,11 @@ def _create_struct_for_field_definitions( if field_definition.is_partial: field_type = Union[field_type, UnsetType] - if (field_meta := _create_struct_field_meta_for_field_definition(field_definition)) is not None: - field_type = Annotated[field_type, field_meta] + if field_definition.passthrough_constraints: + if (field_meta := _create_struct_field_meta_for_field_definition(field_definition)) is not None: + field_type = Annotated[field_type, field_meta] + elif field_definition.kwarg_definition: + field_type = Annotated[field_type, field_definition.kwarg_definition] struct_fields.append( ( diff --git a/litestar/dto/_types.py b/litestar/dto/_types.py index b0863b2593..01f0191ddd 100644 --- a/litestar/dto/_types.py +++ b/litestar/dto/_types.py @@ -142,4 +142,5 @@ def from_dto_field_definition( transfer_type=transfer_type, type_wrappers=field_definition.type_wrappers, model_name=field_definition.model_name, + passthrough_constraints=field_definition.passthrough_constraints, ) diff --git a/litestar/dto/data_structures.py b/litestar/dto/data_structures.py index a5c3386f1c..b660cddd72 100644 --- a/litestar/dto/data_structures.py +++ b/litestar/dto/data_structures.py @@ -68,6 +68,7 @@ class DTOFieldDefinition(FieldDefinition): "default_factory", "dto_field", "model_name", + "passthrough_constraints", ) model_name: str @@ -76,6 +77,8 @@ class DTOFieldDefinition(FieldDefinition): """Default factory of the field.""" dto_field: DTOField """DTO field configuration.""" + passthrough_constraints: bool + """Pass constraints of the source annotation to be validated by the DTO backend""" @classmethod def from_field_definition( @@ -84,6 +87,7 @@ def from_field_definition( model_name: str, default_factory: Callable[[], Any] | None, dto_field: DTOField, + passthrough_constraints: bool = True, ) -> DTOFieldDefinition: """Create a :class:`FieldDefinition` from a :class:`FieldDefinition`. @@ -92,6 +96,7 @@ def from_field_definition( model_name: The name of the model. default_factory: Default factory function, if any. dto_field: DTOField instance. 
+ passthrough_constraints: Pass constraints of the source annotation to be validated by the DTO backend Returns: A :class:`FieldDefinition` instance. @@ -113,4 +118,5 @@ def from_field_definition( raw=field_definition.raw, safe_generic_origin=field_definition.safe_generic_origin, type_wrappers=field_definition.type_wrappers, + passthrough_constraints=passthrough_constraints, ) diff --git a/litestar/dto/msgspec_dto.py b/litestar/dto/msgspec_dto.py index 9996319747..c0a2d4b633 100644 --- a/litestar/dto/msgspec_dto.py +++ b/litestar/dto/msgspec_dto.py @@ -1,13 +1,16 @@ from __future__ import annotations +import dataclasses from dataclasses import replace from typing import TYPE_CHECKING, Generic, TypeVar +import msgspec.inspect from msgspec import NODEFAULT, Struct, structs from litestar.dto.base_dto import AbstractDTO from litestar.dto.data_structures import DTOFieldDefinition from litestar.dto.field import DTO_FIELD_META_KEY, extract_dto_field +from litestar.plugins.core._msgspec import kwarg_definition_from_field from litestar.types.empty import Empty if TYPE_CHECKING: @@ -28,16 +31,25 @@ class MsgspecDTO(AbstractDTO[T], Generic[T]): def generate_field_definitions(cls, model_type: type[Struct]) -> Generator[DTOFieldDefinition, None, None]: msgspec_fields = {f.name: f for f in structs.fields(model_type)} + # TODO: Move out of here def default_or_empty(value: Any) -> Any: return Empty if value is NODEFAULT else value def default_or_none(value: Any) -> Any: return None if value is NODEFAULT else value + inspect_fields: dict[str, msgspec.inspect.Field] = { + field.name: field + for field in msgspec.inspect.type_info(model_type).fields # type: ignore[attr-defined] + } + for key, field_definition in cls.get_model_type_hints(model_type).items(): - msgspec_field = msgspec_fields[key] + kwarg_definition, extra = kwarg_definition_from_field(inspect_fields[key]) + field_definition = dataclasses.replace(field_definition, kwarg_definition=kwarg_definition) + field_definition.extra.update(extra) dto_field = extract_dto_field(field_definition, field_definition.extra) field_definition.extra.pop(DTO_FIELD_META_KEY, None) + msgspec_field = msgspec_fields[key] yield replace( DTOFieldDefinition.from_field_definition( diff --git a/litestar/plugins/core.py b/litestar/plugins/core.py deleted file mode 100644 index 010250e103..0000000000 --- a/litestar/plugins/core.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -import inspect -from inspect import Signature -from typing import Any - -import msgspec - -from litestar.plugins import DIPlugin - -__all__ = ("MsgspecDIPlugin",) - - -class MsgspecDIPlugin(DIPlugin): - def has_typed_init(self, type_: Any) -> bool: - return type(type_) is type(msgspec.Struct) - - def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: - parameters = [] - type_hints = {} - for field_info in msgspec.structs.fields(type_): - type_hints[field_info.name] = field_info.type - parameters.append( - inspect.Parameter( - name=field_info.name, - kind=inspect.Parameter.KEYWORD_ONLY, - annotation=field_info.type, - default=field_info.default, - ) - ) - return inspect.Signature(parameters), type_hints diff --git a/litestar/plugins/core/__init__.py b/litestar/plugins/core/__init__.py new file mode 100644 index 0000000000..802bcf3259 --- /dev/null +++ b/litestar/plugins/core/__init__.py @@ -0,0 +1,3 @@ +from ._msgspec import MsgspecDIPlugin + +__all__ = ("MsgspecDIPlugin",) diff --git a/litestar/plugins/core/_msgspec.py b/litestar/plugins/core/_msgspec.py new 
file mode 100644 index 0000000000..2aa06c7247 --- /dev/null +++ b/litestar/plugins/core/_msgspec.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import dataclasses +import inspect +from inspect import Signature +from typing import Any + +import msgspec + +from litestar.openapi.spec import Example +from litestar.params import ParameterKwarg +from litestar.plugins import DIPlugin + +__all__ = ("MsgspecDIPlugin", "kwarg_definition_from_field") + + +class MsgspecDIPlugin(DIPlugin): + def has_typed_init(self, type_: Any) -> bool: + return type(type_) is type(msgspec.Struct) + + def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]: + parameters = [] + type_hints = {} + for field_info in msgspec.structs.fields(type_): + type_hints[field_info.name] = field_info.type + parameters.append( + inspect.Parameter( + name=field_info.name, + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=field_info.type, + default=field_info.default, + ) + ) + return inspect.Signature(parameters), type_hints + + +def kwarg_definition_from_field(field: msgspec.inspect.Field) -> tuple[ParameterKwarg | None, dict[str, Any]]: + extra: dict[str, Any] = {} + kwargs: dict[str, Any] = {} + if isinstance(field.type, msgspec.inspect.Metadata): + meta = field.type + field_type = meta.type + if extra_json_schema := meta.extra_json_schema: + kwargs["title"] = extra_json_schema.get("title") + kwargs["description"] = extra_json_schema.get("description") + if examples := extra_json_schema.get("examples"): + kwargs["examples"] = [Example(value=e) for e in examples] + kwargs["schema_extra"] = extra_json_schema.get("extra") + extra = meta.extra or {} + else: + field_type = field.type + + if isinstance( + field_type, + ( + msgspec.inspect.IntType, + msgspec.inspect.FloatType, + ), + ): + kwargs["gt"] = field_type.gt + kwargs["ge"] = field_type.ge + kwargs["lt"] = field_type.lt + kwargs["le"] = field_type.le + kwargs["multiple_of"] = field_type.multiple_of + elif isinstance( + field_type, + ( + msgspec.inspect.StrType, + msgspec.inspect.BytesType, + msgspec.inspect.ByteArrayType, + msgspec.inspect.MemoryViewType, + ), + ): + kwargs["min_length"] = field_type.min_length + kwargs["max_length"] = field_type.max_length + if isinstance(field_type, msgspec.inspect.StrType): + kwargs["pattern"] = field_type.pattern + + parameter_defaults = { + f.name: default for f in dataclasses.fields(ParameterKwarg) if (default := f.default) is not dataclasses.MISSING + } + kwargs_without_defaults = {k: v for k, v in kwargs.items() if v != parameter_defaults[k]} + + if kwargs_without_defaults: + return ParameterKwarg(**kwargs_without_defaults), extra + return None, extra diff --git a/litestar/typing.py b/litestar/typing.py index 4eb3bbe9cc..45f2f1a695 100644 --- a/litestar/typing.py +++ b/litestar/typing.py @@ -1,12 +1,14 @@ from __future__ import annotations +import dataclasses import warnings -from collections import abc, deque +from collections import abc from copy import deepcopy from dataclasses import dataclass, is_dataclass, replace from inspect import Parameter, Signature -from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast +from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, TypeVar, cast +import annotated_types from msgspec import UnsetType from typing_extensions import ( NewType, @@ -21,7 +23,6 @@ ) from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning -from litestar.openapi.spec import Example from 
litestar.params import BodyKwarg, DependencyKwarg, KwargDefinition, ParameterKwarg from litestar.types import Empty from litestar.types.builtin_types import NoneType, UnionTypes @@ -46,125 +47,42 @@ T = TypeVar("T", bound=KwargDefinition) -class _KwargMetaExtractor(Protocol): - @staticmethod - def matches(annotation: Any, name: str | None, default: Any) -> bool: ... - - @staticmethod - def extract(annotation: Any, default: Any) -> Any: ... - - -_KWARG_META_EXTRACTORS: set[_KwargMetaExtractor] = set() - - -def _unpack_predicate(value: Any) -> dict[str, Any]: - try: - from annotated_types import Predicate - - if isinstance(value, Predicate): - if value.func == str.islower: - return {"lower_case": True} - if value.func == str.isupper: - return {"upper_case": True} - if value.func == str.isascii: - return {"pattern": "[[:ascii:]]"} - if value.func == str.isdigit: - return {"pattern": "[[:digit:]]"} - except ImportError: - pass - - return {} - - -def _parse_metadata(value: Any, is_sequence_container: bool, extra: dict[str, Any] | None) -> dict[str, Any]: - """Parse metadata from a value. - - Args: - value: A metadata value from annotation, namely anything stored under Annotated[x, metadata...] - is_sequence_container: Whether the type is a sequence container (list, tuple etc...) - extra: Extra key values to parse. - - Returns: - A dictionary of constraints, which fulfill the kwargs of a KwargDefinition class. - """ - extra = { - **cast("dict[str, Any]", extra or getattr(value, "extra", None) or {}), - **(getattr(value, "json_schema_extra", None) or {}), - } - example_list: list[Any] | None - if example := extra.pop("example", None): - example_list = [Example(value=example)] - elif examples := (extra.pop("examples", None) or getattr(value, "examples", None)): - example_list = [Example(value=example) for example in cast("list[str]", examples)] - else: - example_list = None - - return { - k: v - for k, v in { - "gt": getattr(value, "gt", None), - "ge": getattr(value, "ge", None), - "lt": getattr(value, "lt", None), - "le": getattr(value, "le", None), - "multiple_of": getattr(value, "multiple_of", None), - "min_length": None if is_sequence_container else getattr(value, "min_length", None), - "max_length": None if is_sequence_container else getattr(value, "max_length", None), - "description": getattr(value, "description", None), - "examples": example_list, - "title": getattr(value, "title", None), - "lower_case": getattr(value, "to_lower", None), - "upper_case": getattr(value, "to_upper", None), - "pattern": getattr(value, "regex", getattr(value, "pattern", None)), - "min_items": getattr(value, "min_items", getattr(value, "min_length", None)) - if is_sequence_container - else None, - "max_items": getattr(value, "max_items", getattr(value, "max_length", None)) - if is_sequence_container - else None, - "const": getattr(value, "const", None) is not None, - **extra, - }.items() - if v is not None - } - - -def _traverse_metadata( - metadata: Sequence[Any], is_sequence_container: bool, extra: dict[str, Any] | None -) -> dict[str, Any]: - """Recursively traverse metadata from a value. - - Args: - metadata: A list of metadata values from annotation, namely anything stored under Annotated[x, metadata...] - is_sequence_container: Whether the container is a sequence container (list, tuple etc...) - extra: Extra key values to parse. - - Returns: - A dictionary of constraints, which fulfill the kwargs of a KwargDefinition class. 
- """ - constraints: dict[str, Any] = {} - for value in metadata: - if isinstance(value, (list, set, frozenset, deque)): - constraints.update( - _traverse_metadata( - metadata=cast("Sequence[Any]", value), is_sequence_container=is_sequence_container, extra=extra - ) - ) - elif unpacked_predicate := _unpack_predicate(value): - constraints.update(unpacked_predicate) +def _annotated_types_extractor(meta: Any, is_sequence_container: bool) -> dict[str, Any]: # noqa: C901 + kwargs = {} + if isinstance(meta, annotated_types.GroupedMetadata): + for sub_meta in meta: + kwargs.update(_annotated_types_extractor(sub_meta, is_sequence_container=is_sequence_container)) + return kwargs + if isinstance(meta, annotated_types.Gt): + kwargs["gt"] = meta.gt + elif isinstance(meta, annotated_types.Ge): + kwargs["ge"] = meta.ge + elif isinstance(meta, annotated_types.Lt): + kwargs["lt"] = meta.lt + elif isinstance(meta, annotated_types.Le): + kwargs["le"] = meta.le + elif isinstance(meta, annotated_types.MultipleOf): + kwargs["multiple_of"] = meta.multiple_of + elif isinstance(meta, annotated_types.MinLen): + if is_sequence_container: + kwargs["min_items"] = meta.min_length else: - constraints.update(_parse_metadata(value=value, is_sequence_container=is_sequence_container, extra=extra)) - return constraints - - -def _create_metadata_from_type( - metadata: Sequence[Any], model: type[T], annotation: Any, extra: dict[str, Any] | None -) -> tuple[T | None, dict[str, Any]]: - is_sequence_container = is_non_string_sequence(annotation) - result = _traverse_metadata(metadata=metadata, is_sequence_container=is_sequence_container, extra=extra) - - constraints = {k: v for k, v in result.items() if k in dir(model)} - extra = {k: v for k, v in result.items() if k not in constraints} - return model(**constraints) if constraints else None, extra + kwargs["min_length"] = meta.min_length + elif isinstance(meta, annotated_types.MaxLen): + if is_sequence_container: + kwargs["max_items"] = meta.max_length + else: + kwargs["max_length"] = meta.max_length + elif isinstance(meta, annotated_types.Predicate): + if meta.func == str.islower: + kwargs["lower_case"] = True + elif meta.func == str.isupper: + kwargs["upper_case"] = True + elif meta.func == str.isascii: + kwargs["pattern"] = "[[:ascii:]]" + elif meta.func == str.isdigit: # pragma: no cover # coverage quirk: It expects a jump here for branch coverage + kwargs["pattern"] = "[[:digit:]]" + return kwargs @dataclass(frozen=True) @@ -232,29 +150,6 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: return hash((self.name, self.raw, self.annotation, self.origin, self.inner_types)) - @classmethod - def _extract_metadata( - cls, annotation: Any, name: str | None, default: Any, metadata: tuple[Any, ...], extra: dict[str, Any] | None - ) -> tuple[KwargDefinition | None, dict[str, Any]]: - model = BodyKwarg if name == "data" else ParameterKwarg - - for extractor in _KWARG_META_EXTRACTORS: - if extractor.matches(annotation=annotation, name=name, default=default): - return _create_metadata_from_type( - extractor.extract(annotation=annotation, default=default), - model=model, - annotation=annotation, - extra=extra, - ) - - if any(isinstance(arg, KwargDefinition) for arg in get_args(annotation)): - return next(arg for arg in get_args(annotation) if isinstance(arg, KwargDefinition)), extra or {} - - if metadata: - return _create_metadata_from_type(metadata=metadata, model=model, annotation=annotation, extra=extra) - - return None, {} - @property def has_default(self) -> bool: 
"""Check if the field has a default value. @@ -511,7 +406,7 @@ def from_annotation(cls, annotation: Any, **kwargs: Any) -> FieldDefinition: unwrapped, metadata, wrappers = unwrap_annotation(annotation if annotation is not Empty else Any) origin = get_origin(unwrapped) - args = () if origin is abc.Callable else get_args(unwrapped) + annotation_args = () if origin is abc.Callable else get_args(unwrapped) if not kwargs.get("kwarg_definition"): if isinstance(kwargs.get("default"), (KwargDefinition, DependencyKwarg)): @@ -543,20 +438,31 @@ def from_annotation(cls, annotation: Any, **kwargs: Any) -> FieldDefinition: metadata = tuple(v for v in metadata if not isinstance(v, (KwargDefinition, DependencyKwarg))) elif (extra := kwargs.get("extra", {})) and "kwarg_definition" in extra: kwargs["kwarg_definition"] = extra.pop("kwarg_definition") - else: - kwargs["kwarg_definition"], kwargs["extra"] = cls._extract_metadata( - annotation=annotation, - name=kwargs.get("name", ""), - default=kwargs.get("default", Empty), - metadata=metadata, - extra=kwargs.get("extra"), + + # there might be additional metadata + if metadata: + kwarg_definition_merge_args = {} + is_sequence_container = is_non_string_sequence(annotation) + # extract metadata into KwargDefinition attributes + for meta in metadata: + kwarg_definition_merge_args.update( + _annotated_types_extractor(meta, is_sequence_container=is_sequence_container) + ) + # if we already have a KwargDefinition, merge it with the additional metadata + if existing_kwargs_definition := kwargs.get("kwarg_definition"): + kwargs["kwarg_definition"] = dataclasses.replace( + existing_kwargs_definition, **kwarg_definition_merge_args ) + # if not, create a new KwargDefinition + else: + model = BodyKwarg if kwargs.get("name") == "data" else ParameterKwarg + kwargs["kwarg_definition"] = model(**kwarg_definition_merge_args) kwargs.setdefault("annotation", unwrapped) - kwargs.setdefault("args", args) + kwargs.setdefault("args", annotation_args) kwargs.setdefault("default", Empty) kwargs.setdefault("extra", {}) - kwargs.setdefault("inner_types", tuple(FieldDefinition.from_annotation(arg) for arg in args)) + kwargs.setdefault("inner_types", tuple(FieldDefinition.from_annotation(arg) for arg in annotation_args)) kwargs.setdefault("instantiable_origin", get_instantiable_origin(origin, unwrapped)) kwargs.setdefault("kwarg_definition", None) kwargs.setdefault("metadata", metadata) diff --git a/pyproject.toml b/pyproject.toml index 93b11561c1..78a34ef862 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -269,6 +269,10 @@ module = ["tests.unit.test_kwargs.test_reserved_kwargs_injection"] module = ["tests.unit.test_contrib.test_repository"] strict_equality = false +[[tool.mypy.overrides]] +module = ["tests.unit.test_contrib.test_pydantic.test_openapi"] +disable_error_code = "index, union-attr" + [[tool.mypy.overrides]] ignore_missing_imports = true module = [ diff --git a/tests/unit/test_contrib/conftest.py b/tests/unit/test_contrib/conftest.py index 5d849f75d0..28ece35c77 100644 --- a/tests/unit/test_contrib/conftest.py +++ b/tests/unit/test_contrib/conftest.py @@ -1,15 +1,9 @@ from __future__ import annotations -from dataclasses import replace from typing import TYPE_CHECKING -from unittest.mock import ANY import pytest -from litestar.dto import DTOField, Mark -from litestar.dto.data_structures import DTOFieldDefinition -from litestar.typing import FieldDefinition - if TYPE_CHECKING: from typing import Callable @@ -17,79 +11,3 @@ @pytest.fixture def int_factory() -> Callable[[], 
int]: return lambda: 2 - - -@pytest.fixture -def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefinition]: - return [ - DTOFieldDefinition.from_field_definition( - field_definition=FieldDefinition.from_kwarg( - annotation=int, - name="a", - ), - model_name=ANY, - default_factory=None, - dto_field=DTOField(), - ), - replace( - DTOFieldDefinition.from_field_definition( - field_definition=FieldDefinition.from_kwarg( - annotation=int, - name="b", - ), - model_name=ANY, - default_factory=None, - dto_field=DTOField(mark=Mark.READ_ONLY), - ), - metadata=ANY, - type_wrappers=ANY, - raw=ANY, - kwarg_definition=ANY, - ), - replace( - DTOFieldDefinition.from_field_definition( - field_definition=FieldDefinition.from_kwarg( - annotation=int, - name="c", - ), - model_name=ANY, - default_factory=None, - dto_field=DTOField(), - ), - metadata=ANY, - type_wrappers=ANY, - raw=ANY, - kwarg_definition=ANY, - ), - replace( - DTOFieldDefinition.from_field_definition( - field_definition=FieldDefinition.from_kwarg( - annotation=int, - name="d", - default=1, - ), - model_name=ANY, - default_factory=None, - dto_field=DTOField(), - ), - metadata=ANY, - type_wrappers=ANY, - raw=ANY, - kwarg_definition=ANY, - ), - replace( - DTOFieldDefinition.from_field_definition( - field_definition=FieldDefinition.from_kwarg( - annotation=int, - name="e", - ), - model_name=ANY, - default_factory=int_factory, - dto_field=DTOField(), - ), - metadata=ANY, - type_wrappers=ANY, - raw=ANY, - kwarg_definition=ANY, - ), - ] diff --git a/tests/unit/test_contrib/test_msgspec.py b/tests/unit/test_contrib/test_msgspec.py index 9c28ed4ba9..2c0c177ec4 100644 --- a/tests/unit/test_contrib/test_msgspec.py +++ b/tests/unit/test_contrib/test_msgspec.py @@ -1,11 +1,14 @@ from __future__ import annotations +from dataclasses import replace from typing import TYPE_CHECKING +from unittest.mock import ANY +import pytest from msgspec import Meta, Struct, field from typing_extensions import Annotated -from litestar.dto import DTOField, MsgspecDTO, dto_field +from litestar.dto import DTOField, Mark, MsgspecDTO, dto_field from litestar.dto.data_structures import DTOFieldDefinition from litestar.typing import FieldDefinition @@ -13,6 +16,82 @@ from typing import Callable +@pytest.fixture +def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefinition]: + return [ + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=int, + name="a", + ), + model_name=ANY, + default_factory=None, + dto_field=DTOField(), + ), + replace( + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=int, + name="b", + ), + model_name=ANY, + default_factory=None, + dto_field=DTOField(mark=Mark.READ_ONLY), + ), + metadata=ANY, + type_wrappers=ANY, + raw=ANY, + kwarg_definition=ANY, + ), + replace( + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=int, + name="c", + ), + model_name=ANY, + default_factory=None, + dto_field=DTOField(), + ), + metadata=ANY, + type_wrappers=ANY, + raw=ANY, + kwarg_definition=ANY, + ), + replace( + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=int, + name="d", + default=1, + ), + model_name=ANY, + default_factory=None, + dto_field=DTOField(), + ), + metadata=ANY, + type_wrappers=ANY, + raw=ANY, + kwarg_definition=ANY, + ), + replace( + DTOFieldDefinition.from_field_definition( + 
field_definition=FieldDefinition.from_kwarg( + annotation=int, + name="e", + ), + model_name=ANY, + default_factory=int_factory, + dto_field=DTOField(), + ), + metadata=ANY, + type_wrappers=ANY, + raw=ANY, + kwarg_definition=ANY, + ), + ] + + def test_field_definition_generation( int_factory: Callable[[], int], expected_field_defs: list[DTOFieldDefinition] ) -> None: diff --git a/tests/unit/test_contrib/test_pydantic/conftest.py b/tests/unit/test_contrib/test_pydantic/conftest.py index f56f677ab4..63cfd5aab3 100644 --- a/tests/unit/test_contrib/test_pydantic/conftest.py +++ b/tests/unit/test_contrib/test_pydantic/conftest.py @@ -5,19 +5,9 @@ from pydantic import v1 as pydantic_v1 from pytest import FixtureRequest -from litestar.contrib.pydantic.pydantic_init_plugin import ( # type: ignore[attr-defined] - _KWARG_META_EXTRACTORS, - ConstrainedFieldMetaExtractor, -) - from . import PydanticVersion -@pytest.fixture(autouse=True, scope="session") -def ensure_metadata_extractor_is_added() -> None: - _KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor) - - @pytest.fixture(params=["v1", "v2"]) def pydantic_version(request: FixtureRequest) -> PydanticVersion: return request.param # type: ignore[no-any-return] diff --git a/tests/unit/test_contrib/test_pydantic/test_dto.py b/tests/unit/test_contrib/test_pydantic/test_dto.py index 782d915902..2828a5625a 100644 --- a/tests/unit/test_contrib/test_pydantic/test_dto.py +++ b/tests/unit/test_contrib/test_pydantic/test_dto.py @@ -83,7 +83,7 @@ def test_pydantic_field_descriptions(create_module: Callable[[str], ModuleType]) class User(BaseModel): id: Annotated[ int, - Field(description="This is a test (id description)."), + Field(description="This is a test (id description).", gt=1), ] class DataCollectionDTO(PydanticDTO[User]): @@ -102,6 +102,7 @@ def get_user() -> User: component_schema = schema.components.schemas["GetUserUserResponseBody"] assert component_schema.properties is not None assert component_schema.properties["id"].description == "This is a test (id description)." 
+ assert component_schema.properties["id"].exclusive_minimum == 1 # type: ignore[union-attr] @pytest.mark.parametrize( diff --git a/tests/unit/test_contrib/test_pydantic/test_integration.py b/tests/unit/test_contrib/test_pydantic/test_integration.py index 8673cd40de..c638993b7c 100644 --- a/tests/unit/test_contrib/test_pydantic/test_integration.py +++ b/tests/unit/test_contrib/test_pydantic/test_integration.py @@ -6,7 +6,7 @@ from pydantic import v1 as pydantic_v1 from typing_extensions import Annotated -from litestar import post +from litestar import get, post from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin from litestar.contrib.pydantic.pydantic_dto_factory import PydanticDTO from litestar.enums import RequestEncodingType @@ -386,3 +386,55 @@ async def handler(data: Model) -> None: with create_test_client([handler], plugins=plugins) as client: res = client.post("/", json={"test_bool": "YES"}) assert res.status_code == 400 if expect_error else 201 + + +def test_model_defaults(pydantic_version: PydanticVersion) -> None: + lib = pydantic_v1 if pydantic_version == "v1" else pydantic_v2 + + class Model(lib.BaseModel): # type: ignore[misc, name-defined] + a: int + b: int = lib.Field(default=1) + c: int = lib.Field(default_factory=lambda: 3) + + @post("/") + async def handler(data: Model) -> Dict[str, int]: + return {"a": data.a, "b": data.b, "c": data.c} + + with create_test_client([handler]) as client: + schema = client.app.openapi_schema.components.schemas["test_model_defaults.Model"] + res = client.post("/", json={"a": 5}) + assert res.status_code == 201 + assert res.json() == {"a": 5, "b": 1, "c": 3} + assert schema.required == ["a"] + assert schema.properties["b"].default == 1 + assert schema.properties["c"].default is None + + +@pytest.mark.parametrize("with_dto", [True, False]) +def test_v2_computed_fields(with_dto: bool) -> None: + # /~https://github.com/litestar-org/litestar/issues/3656 + + class Model(pydantic_v2.BaseModel): + foo: int = 1 + + @pydantic_v2.computed_field + def bar(self) -> int: + return 2 + + @pydantic_v2.computed_field(examples=[1], json_schema_extra={"title": "this is computed"}) + def baz(self) -> int: + return 3 + + @get("/", return_dto=PydanticDTO[Model] if with_dto else None) + async def handler() -> Model: + return Model() + + component_name = "HandlerModelResponseBody" if with_dto else "test_v2_computed_fields.Model" + + with create_test_client([handler]) as client: + schema = client.app.openapi_schema.components.schemas[component_name] + res = client.get("/") + assert list(schema.properties.keys()) == ["foo", "bar", "baz"] + assert schema.properties["baz"].title == "this is computed" + assert schema.properties["baz"].examples == [1] + assert res.json() == {"foo": 1, "bar": 2, "baz": 3} diff --git a/tests/unit/test_contrib/test_pydantic/test_openapi.py b/tests/unit/test_contrib/test_pydantic/test_openapi.py index b0921f00ac..a176a5b82b 100644 --- a/tests/unit/test_contrib/test_pydantic/test_openapi.py +++ b/tests/unit/test_contrib/test_pydantic/test_openapi.py @@ -1,26 +1,21 @@ -from datetime import date, datetime, timedelta, timezone +# pyright: reportOptionalSubscript=false, reportGeneralTypeIssues=false +from datetime import date, timedelta from decimal import Decimal from types import ModuleType -from typing import Any, Callable, Pattern, Type, Union, cast +from typing import Any, Callable, Dict, Optional, Pattern, Type, Union, cast +import annotated_types import pydantic as pydantic_v2 import pytest from pydantic import v1 as 
pydantic_v1 from typing_extensions import Annotated -from litestar import Litestar, post -from litestar._openapi.schema_generation.constrained_fields import ( - create_date_constrained_field_schema, - create_numerical_constrained_field_schema, - create_string_constrained_field_schema, -) +from litestar import Litestar, get, post from litestar._openapi.schema_generation.schema import SchemaCreator from litestar.contrib.pydantic import PydanticPlugin, PydanticSchemaPlugin from litestar.openapi import OpenAPIConfig from litestar.openapi.spec import Reference, Schema from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType -from litestar.params import KwargDefinition -from litestar.status_codes import HTTP_200_OK from litestar.testing import TestClient, create_test_client from litestar.typing import FieldDefinition from litestar.utils import is_class_and_subclass @@ -60,6 +55,9 @@ pydantic_v2.constr(min_length=1), pydantic_v2.constr(min_length=10), pydantic_v2.constr(min_length=10, max_length=100), +] + +constrained_bytes_v2 = [ pydantic_v2.conbytes(min_length=1), pydantic_v2.conbytes(min_length=10), pydantic_v2.conbytes(min_length=10, max_length=100), @@ -124,23 +122,59 @@ ] +@pytest.fixture() +def schema_creator(plugin: PydanticSchemaPlugin) -> SchemaCreator: + return SchemaCreator(plugins=[plugin]) + + +@pytest.fixture() +def plugin() -> PydanticSchemaPlugin: + return PydanticSchemaPlugin() + + @pytest.mark.parametrize("annotation", constrained_collection_v1) -def test_create_collection_constrained_field_schema_pydantic_v1(annotation: Any) -> None: - schema = SchemaCreator().for_collection_constrained_field(FieldDefinition.from_annotation(annotation)) +def test_create_collection_constrained_field_schema_pydantic_v1( + annotation: Any, + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + class Model(pydantic_v1.BaseModel): + field: annotation + + schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] + assert schema.type == OpenAPIType.ARRAY assert schema.items.type == OpenAPIType.INTEGER # type: ignore[union-attr] assert schema.min_items == annotation.min_items assert schema.max_items == annotation.max_items -@pytest.mark.parametrize("annotation", constrained_collection_v2) -def test_create_collection_constrained_field_schema_pydantic_v2(annotation: Any) -> None: - field_definition = FieldDefinition.from_annotation(annotation) - schema = SchemaCreator().for_collection_constrained_field(field_definition) +@pytest.mark.parametrize("make_constraint", [pydantic_v2.conlist, pydantic_v2.conset, pydantic_v2.confrozenset]) +@pytest.mark.parametrize( + "min_length, max_length", + [ + (None, None), + (1, None), + (1, 1), + (None, 1), + ], +) +def test_create_collection_constrained_field_schema_pydantic_v2( + make_constraint: Callable[..., Any], + min_length: Optional[int], + max_length: Optional[int], + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + class Model(pydantic_v2.BaseModel): + field: make_constraint(int, min_length=min_length, max_length=max_length) # type: ignore[valid-type] + + schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] + assert schema.type == OpenAPIType.ARRAY assert schema.items.type == OpenAPIType.INTEGER # type: ignore[union-attr] - assert any(getattr(m, "min_length", None) == schema.min_items for m in field_definition.metadata if m) - assert any(getattr(m, "max_length", None) == schema.max_items for m in 
field_definition.metadata if m) + assert schema.min_items == min_length + assert schema.max_items == max_length @pytest.fixture() @@ -154,39 +188,59 @@ def conlist(pydantic_version: PydanticVersion) -> Any: def test_create_collection_constrained_field_schema_sub_fields( - pydantic_version: PydanticVersion, conset: Any, conlist: Any + pydantic_version: PydanticVersion, + conset: Any, + conlist: Any, + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, ) -> None: - for pydantic_fn in [conset, conlist]: - if pydantic_version == "v1": - annotation = pydantic_fn(Union[str, int], min_items=1, max_items=10) - else: - annotation = pydantic_fn(Union[str, int], min_length=1, max_length=10) - field_definition = FieldDefinition.from_annotation(annotation) - schema = SchemaCreator().for_collection_constrained_field(field_definition) + if pydantic_version == "v1": + + class Modelv1(pydantic_v1.BaseModel): + set_field: conset(Union[str, int], min_items=1, max_items=10) # type: ignore[valid-type] + list_field: conlist(Union[str, int], min_items=1, max_items=10) # type: ignore[valid-type] + + model_schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Modelv1), plugin) + else: + + class Modelv2(pydantic_v2.BaseModel): + set_field: conset(Union[str, int], min_length=1, max_length=10) # type: ignore[valid-type] + list_field: conlist(Union[str, int], min_length=1, max_length=10) # type: ignore[valid-type] + + model_schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Modelv2), plugin) + + def _get_schema_type(s: Any) -> OpenAPIType: + assert isinstance(s, Schema) + assert isinstance(s.type, OpenAPIType) + return s.type + + for field_name in ["set_field", "list_field"]: + schema = model_schema.properties[field_name] + assert schema.type == OpenAPIType.ARRAY assert schema.max_items == 10 assert schema.min_items == 1 assert isinstance(schema.items, Schema) assert schema.items.one_of is not None - def _get_schema_type(s: Any) -> OpenAPIType: - assert isinstance(s, Schema) - assert isinstance(s.type, OpenAPIType) - return s.type - # /~https://github.com/litestar-org/litestar/pull/2570#issuecomment-1788122570 assert {_get_schema_type(s) for s in schema.items.one_of} == {OpenAPIType.STRING, OpenAPIType.INTEGER} - if pydantic_fn is conset: - # set should have uniqueItems always - assert schema.unique_items + + # set should have uniqueItems always + assert model_schema.properties["set_field"].unique_items @pytest.mark.parametrize("annotation", constrained_string_v1) -def test_create_string_constrained_field_schema_pydantic_v1(annotation: Any) -> None: - field_definition = FieldDefinition.from_annotation(annotation) +def test_create_string_constrained_field_schema_pydantic_v1( + annotation: Any, + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + class Model(pydantic_v1.BaseModel): + field: annotation + + schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] - assert isinstance(field_definition.kwarg_definition, KwargDefinition) - schema = create_string_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) assert schema.type == OpenAPIType.STRING assert schema.min_length == annotation.min_length @@ -200,31 +254,61 @@ def test_create_string_constrained_field_schema_pydantic_v1(annotation: Any) -> @pytest.mark.parametrize("annotation", constrained_string_v2) -def test_create_string_constrained_field_schema_pydantic_v2(annotation: Any) -> None: - field_definition = 
FieldDefinition.from_annotation(annotation) +def test_create_string_constrained_field_schema_pydantic_v2( + annotation: Any, + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + constraint: pydantic_v2.types.StringConstraints = annotation.__metadata__[0] + + class Model(pydantic_v2.BaseModel): + field: annotation + + schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] - assert isinstance(field_definition.kwarg_definition, KwargDefinition) - schema = create_string_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) assert schema.type == OpenAPIType.STRING + assert schema.min_length == constraint.min_length + assert schema.max_length == constraint.max_length + assert schema.pattern == constraint.pattern + if constraint.to_upper: + assert schema.description == "must be in upper case" + if constraint.to_lower: + assert schema.description == "must be in lower case" + + +@pytest.mark.parametrize("annotation", constrained_bytes_v2) +def test_create_byte_constrained_field_schema_pydantic_v2( + annotation: Any, + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + constraint: annotated_types.Len = annotation.__metadata__[1] - assert any(getattr(m, "min_length", None) == schema.min_length for m in field_definition.metadata if m) - assert any(getattr(m, "max_length", None) == schema.max_length for m in field_definition.metadata if m) - if pattern := getattr(annotation, "regex", getattr(annotation, "pattern", None)): - assert schema.pattern == pattern.pattern if isinstance(pattern, Pattern) else pattern - if any(getattr(m, "to_lower", getattr(m, "to_upper", None)) for m in field_definition.metadata if m): - assert schema.description + class Model(pydantic_v2.BaseModel): + field: annotation + + schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] + + assert schema.type == OpenAPIType.STRING + assert schema.min_length == constraint.min_length + assert schema.max_length == constraint.max_length @pytest.mark.parametrize("annotation", constrained_numbers_v1) -def test_create_numerical_constrained_field_schema_pydantic_v1(annotation: Any) -> None: +def test_create_numerical_constrained_field_schema_pydantic_v1( + annotation: Any, + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: from pydantic.v1.types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt annotation = cast(Union[ConstrainedInt, ConstrainedFloat, ConstrainedDecimal], annotation) - field_definition = FieldDefinition.from_annotation(annotation) + class Model(pydantic_v1.BaseModel): + field: annotation + + schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] - assert isinstance(field_definition.kwarg_definition, KwargDefinition) - schema = create_numerical_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) assert ( schema.type == OpenAPIType.INTEGER if is_class_and_subclass(annotation, ConstrainedInt) else OpenAPIType.NUMBER ) @@ -235,102 +319,95 @@ def test_create_numerical_constrained_field_schema_pydantic_v1(annotation: Any) assert schema.multiple_of == annotation.multiple_of -@pytest.mark.parametrize("annotation", constrained_numbers_v2) -def test_create_numerical_constrained_field_schema_pydantic_v2(annotation: Any) -> None: - field_definition = FieldDefinition.from_annotation(annotation) +@pytest.mark.parametrize( + "make_constraint, 
constraint_kwargs", + [ + (pydantic_v2.conint, {"gt": 10, "lt": 100}), + (pydantic_v2.conint, {"ge": 10, "le": 100}), + (pydantic_v2.conint, {"ge": 10, "le": 100, "multiple_of": 7}), + (pydantic_v2.confloat, {"gt": 10, "lt": 100}), + (pydantic_v2.confloat, {"ge": 10, "le": 100}), + (pydantic_v2.confloat, {"ge": 10, "le": 100, "multiple_of": 4.2}), + (pydantic_v2.confloat, {"gt": 10, "lt": 100, "multiple_of": 10}), + (pydantic_v2.condecimal, {"gt": Decimal("10"), "lt": Decimal("100")}), + (pydantic_v2.condecimal, {"ge": Decimal("10"), "le": Decimal("100")}), + (pydantic_v2.condecimal, {"gt": Decimal("10"), "lt": Decimal("100"), "multiple_of": Decimal("5")}), + (pydantic_v2.condecimal, {"ge": Decimal("10"), "le": Decimal("100"), "multiple_of": Decimal("2")}), + ], +) +def test_create_numerical_constrained_field_schema_pydantic_v2( + make_constraint: Any, + constraint_kwargs: Dict[str, Any], + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + annotation = make_constraint(**constraint_kwargs) + + class Model(pydantic_v1.BaseModel): + field: annotation # type: ignore[valid-type] + + schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] - assert isinstance(field_definition.kwarg_definition, KwargDefinition) - schema = create_numerical_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) assert schema.type == OpenAPIType.INTEGER if is_class_and_subclass(annotation, int) else OpenAPIType.NUMBER - assert any(getattr(m, "gt", None) == schema.exclusive_minimum for m in field_definition.metadata if m) - assert any(getattr(m, "ge", None) == schema.minimum for m in field_definition.metadata if m) - assert any(getattr(m, "lt", None) == schema.exclusive_maximum for m in field_definition.metadata if m) - assert any(getattr(m, "le", None) == schema.maximum for m in field_definition.metadata if m) - assert any(getattr(m, "multiple_of", None) == schema.multiple_of for m in field_definition.metadata if m) + assert schema.exclusive_minimum == constraint_kwargs.get("gt") + assert schema.minimum == constraint_kwargs.get("ge") + assert schema.exclusive_maximum == constraint_kwargs.get("lt") + assert schema.maximum == constraint_kwargs.get("le") + assert schema.multiple_of == constraint_kwargs.get("multiple_of") @pytest.mark.parametrize("annotation", constrained_dates_v1) -def test_create_date_constrained_field_schema_pydantic_v1(annotation: Any) -> None: - field_definition = FieldDefinition.from_annotation(annotation) +def test_create_date_constrained_field_schema_pydantic_v1( + annotation: Any, + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + class Model(pydantic_v1.BaseModel): + field: annotation + + schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] - assert isinstance(field_definition.kwarg_definition, KwargDefinition) - schema = create_date_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) assert schema.type == OpenAPIType.STRING assert schema.format == OpenAPIFormat.DATE - assert ( - datetime.fromtimestamp(schema.exclusive_minimum, tz=timezone.utc) if schema.exclusive_minimum else None - ) == ( - datetime.fromordinal(annotation.gt.toordinal()).replace(tzinfo=timezone.utc) - if annotation.gt is not None - else None - ) - assert (datetime.fromtimestamp(schema.minimum, tz=timezone.utc) if schema.minimum else None) == ( - 
datetime.fromordinal(annotation.ge.toordinal()).replace(tzinfo=timezone.utc) - if annotation.ge is not None - else None - ) - assert ( - datetime.fromtimestamp(schema.exclusive_maximum, tz=timezone.utc) if schema.exclusive_maximum else None - ) == ( - datetime.fromordinal(annotation.lt.toordinal()).replace(tzinfo=timezone.utc) - if annotation.lt is not None - else None - ) - assert (datetime.fromtimestamp(schema.maximum, tz=timezone.utc) if schema.maximum else None) == ( - datetime.fromordinal(annotation.le.toordinal()).replace(tzinfo=timezone.utc) - if annotation.le is not None - else None - ) + if gt := annotation.gt: + assert date.fromtimestamp(schema.exclusive_minimum) == gt # type: ignore[arg-type] + if ge := annotation.ge: + assert date.fromtimestamp(schema.minimum) == ge # type: ignore[arg-type] + if lt := annotation.lt: + assert date.fromtimestamp(schema.exclusive_maximum) == lt # type: ignore[arg-type] + if le := annotation.le: + assert date.fromtimestamp(schema.maximum) == le # type: ignore[arg-type] -@pytest.mark.parametrize("annotation", constrained_dates_v2) -def test_create_date_constrained_field_schema_pydantic_v2(annotation: Any) -> None: - field_definition = FieldDefinition.from_annotation(annotation) +@pytest.mark.parametrize( + "constraints", + [ + {"gt": date.today() - timedelta(days=10), "lt": date.today() + timedelta(days=100)}, + {"ge": date.today() - timedelta(days=10), "le": date.today() + timedelta(days=100)}, + {"gt": date.today() - timedelta(days=10), "lt": date.today() + timedelta(days=100)}, + {"ge": date.today() - timedelta(days=10), "le": date.today() + timedelta(days=100)}, + ], +) +def test_create_date_constrained_field_schema_pydantic_v2( + constraints: Dict[str, Any], + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + class Model(pydantic_v2.BaseModel): + field: pydantic_v2.condate(**constraints) # type: ignore[valid-type] - assert isinstance(field_definition.kwarg_definition, KwargDefinition) - schema = create_date_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) + schema = schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] assert schema.type == OpenAPIType.STRING assert schema.format == OpenAPIFormat.DATE - assert any( - ( - datetime.fromordinal(getattr(m, "gt", None).toordinal()).replace(tzinfo=timezone.utc) # type: ignore[union-attr] - if getattr(m, "gt", None) is not None - else None - ) - == (datetime.fromtimestamp(schema.exclusive_minimum, tz=timezone.utc) if schema.exclusive_minimum else None) - for m in field_definition.metadata - if m - ) - assert any( - ( - datetime.fromordinal(getattr(m, "ge", None).toordinal()).replace(tzinfo=timezone.utc) # type: ignore[union-attr] - if getattr(m, "ge", None) is not None - else None - ) - == (datetime.fromtimestamp(schema.minimum, tz=timezone.utc) if schema.minimum else None) - for m in field_definition.metadata - if m - ) - assert any( - ( - datetime.fromordinal(getattr(m, "lt", None).toordinal()).replace(tzinfo=timezone.utc) # type: ignore[union-attr] - if getattr(m, "lt", None) is not None - else None - ) - == (datetime.fromtimestamp(schema.exclusive_maximum, tz=timezone.utc) if schema.exclusive_maximum else None) - for m in field_definition.metadata - if m - ) - assert any( - ( - datetime.fromordinal(getattr(m, "le", None).toordinal()).replace(tzinfo=timezone.utc) # type: ignore[union-attr] - if getattr(m, "le", None) is not None - else None - ) - == (datetime.fromtimestamp(schema.maximum, 
tz=timezone.utc) if schema.maximum else None) - for m in field_definition.metadata - if m - ) + + if gt := constraints.get("gt"): + assert date.fromtimestamp(schema.exclusive_minimum) == gt # type: ignore[arg-type] + if ge := constraints.get("ge"): + assert date.fromtimestamp(schema.minimum) == ge # type: ignore[arg-type] + if lt := constraints.get("lt"): + assert date.fromtimestamp(schema.exclusive_maximum) == lt # type: ignore[arg-type] + if le := constraints.get("le"): + assert date.fromtimestamp(schema.maximum) == le # type: ignore[arg-type] @pytest.mark.parametrize( @@ -340,15 +417,37 @@ def test_create_date_constrained_field_schema_pydantic_v2(annotation: Any) -> No *constrained_collection_v1, *constrained_string_v1, *constrained_dates_v1, + ], +) +def test_create_constrained_field_schema_v1( + annotation: Any, + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + class Model(pydantic_v1.BaseModel): + field: annotation + + assert schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] + + +@pytest.mark.parametrize( + "annotation", + [ *constrained_numbers_v2, *constrained_collection_v2, *constrained_string_v2, *constrained_dates_v2, ], ) -def test_create_constrained_field_schema(annotation: Any) -> None: - schema = SchemaCreator().for_constrained_field(FieldDefinition.from_annotation(annotation)) - assert schema +def test_create_constrained_field_schema_v2( + annotation: Any, + schema_creator: SchemaCreator, + plugin: PydanticSchemaPlugin, +) -> None: + class Model(pydantic_v2.BaseModel): + field: annotation + + assert schema_creator.for_plugin(FieldDefinition.from_annotation(Model), plugin).properties["field"] # type: ignore[index, union-attr] @pytest.mark.parametrize("cls", (PydanticPerson, PydanticDataclassPerson, PydanticV1Person, PydanticV1DataclassPerson)) @@ -391,8 +490,7 @@ def handler(data: cls) -> cls: } -@pytest.mark.parametrize("create_examples", (True, False)) -def test_schema_generation_v1(create_examples: bool) -> None: +def test_schema_generation_v1() -> None: class Lookup(pydantic_v1.BaseModel): id: Annotated[ str, @@ -401,35 +499,30 @@ class Lookup(pydantic_v1.BaseModel): max_length=16, description="A unique identifier", example="e4eaaaf2-d142-11e1-b3e4-080027620cdd", # pyright: ignore + examples=["31", "32"], ), ] + with_title: str = pydantic_v1.Field(title="WITH_title") @post("/example") async def example_route() -> Lookup: - return Lookup(id="1234567812345678") - - with create_test_client( - route_handlers=[example_route], - openapi_config=OpenAPIConfig( - title="Example API", - version="1.0.0", - create_examples=create_examples, - ), - signature_namespace={"Lookup": Lookup}, - ) as client: - response = client.get("/schema/openapi.json") - assert response.status_code == HTTP_200_OK - assert response.json()["components"]["schemas"]["test_schema_generation_v1.Lookup"]["properties"]["id"] == { - "description": "A unique identifier", - "examples": ["e4eaaaf2-d142-11e1-b3e4-080027620cdd"], - "maxLength": 16, - "minLength": 12, - "type": "string", - } + return Lookup(id="1234567812345678", with_title="1") + + app = Litestar([example_route]) + schema = app.openapi_schema.to_schema() + lookup_schema = schema["components"]["schemas"]["test_schema_generation_v1.Lookup"]["properties"] + + assert lookup_schema["id"] == { + "description": "A unique identifier", + "examples": ["e4eaaaf2-d142-11e1-b3e4-080027620cdd", "31", "32"], + "maxLength": 16, + "minLength": 12, + "type": "string", + } + assert 
lookup_schema["with_title"] == {"title": "WITH_title", "type": "string"} -@pytest.mark.parametrize("create_examples", (True, False)) -def test_schema_generation_v2(create_examples: bool) -> None: +def test_schema_generation_v2() -> None: class Lookup(pydantic_v2.BaseModel): id: Annotated[ str, @@ -437,32 +530,72 @@ class Lookup(pydantic_v2.BaseModel): min_length=12, max_length=16, description="A unique identifier", - json_schema_extra={"example": "e4eaaaf2-d142-11e1-b3e4-080027620cdd"}, + # we expect these examples to be merged + json_schema_extra={"example": "e4eaaaf2-d142-11e1-b3e4-080027620cdd", "examples": ["31"]}, + examples=["32"], ), ] + # title should work if given on the field + with_title: str = pydantic_v2.Field(title="WITH_title") + # or as an extra + with_extra_title: str = pydantic_v2.Field(json_schema_extra={"title": "WITH_extra"}) @post("/example") async def example_route() -> Lookup: - return Lookup(id="1234567812345678") + return Lookup(id="1234567812345678", with_title="1", with_extra_title="2") + + app = Litestar([example_route]) + schema = app.openapi_schema.to_schema() + lookup_schema = schema["components"]["schemas"]["test_schema_generation_v2.Lookup"]["properties"] + + assert lookup_schema["id"] == { + "description": "A unique identifier", + "examples": ["e4eaaaf2-d142-11e1-b3e4-080027620cdd", "31", "32"], + "maxLength": 16, + "minLength": 12, + "type": "string", + } + assert lookup_schema["with_title"] == {"title": "WITH_title", "type": "string"} + assert lookup_schema["with_extra_title"] == {"title": "WITH_extra", "type": "string"} - with create_test_client( - route_handlers=[example_route], + +def test_create_examples(pydantic_version: PydanticVersion) -> None: + lib = pydantic_v1 if pydantic_version == "v1" else pydantic_v2 + + class Model(lib.BaseModel): # type: ignore[name-defined, misc] + foo: str = lib.Field(examples=["32"]) + bar: str + + @get("/example") + async def handler() -> Model: + return Model(foo="1", bar="2") + + app = Litestar( + [handler], openapi_config=OpenAPIConfig( - title="Example API", - version="1.0.0", - create_examples=create_examples, + title="Test", + version="0", + create_examples=True, ), - signature_namespace={"Lookup": Lookup}, - ) as client: - response = client.get("/schema/openapi.json") - assert response.status_code == HTTP_200_OK - assert response.json()["components"]["schemas"]["test_schema_generation_v2.Lookup"]["properties"]["id"] == { - "description": "A unique identifier", - "examples": ["e4eaaaf2-d142-11e1-b3e4-080027620cdd"], - "maxLength": 16, - "minLength": 12, - "type": "string", - } + ) + schema = app.openapi_schema.to_schema() + lookup_schema = schema["components"]["schemas"]["test_create_examples.Model"]["properties"] + + assert lookup_schema["foo"]["examples"] == ["32"] + assert lookup_schema["bar"]["examples"] + + +def test_v2_json_schema_extra_callable_raises() -> None: + class Model(pydantic_v2.BaseModel): + field: str = pydantic_v2.Field(json_schema_extra=lambda e: None) + + @get("/example") + def handler() -> Model: + return Model(field="1") + + app = Litestar([handler]) + with pytest.raises(ValueError, match="Callables not supported"): + app.openapi_schema def test_schema_by_alias(base_model: AnyBaseModelType, pydantic_version: PydanticVersion) -> None: @@ -581,7 +714,7 @@ class Model(pydantic_v2.BaseModel): assert value.examples == ["example"] -def test_create_schema_for_field_v2__examples() -> None: +def test_create_schema_for_field_v2_examples() -> None: class Model(pydantic_v2.BaseModel): value: str = 
pydantic_v2.Field( title="title", description="description", max_length=16, json_schema_extra={"examples": ["example"]} diff --git a/tests/unit/test_contrib/test_pydantic/test_pydantic_dto_factory.py b/tests/unit/test_contrib/test_pydantic/test_pydantic_dto_factory.py index ee2ecdc467..e7f840e244 100644 --- a/tests/unit/test_contrib/test_pydantic/test_pydantic_dto_factory.py +++ b/tests/unit/test_contrib/test_pydantic/test_pydantic_dto_factory.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from dataclasses import replace +from typing import TYPE_CHECKING, Callable, Optional +from unittest.mock import ANY import pydantic as pydantic_v2 import pytest @@ -8,8 +10,7 @@ from typing_extensions import Annotated from litestar.contrib.pydantic import PydanticDTO -from litestar.dto import DTOField, dto_field -from litestar.dto.data_structures import DTOFieldDefinition +from litestar.dto import DTOField, DTOFieldDefinition, Mark, dto_field from litestar.typing import FieldDefinition from . import PydanticVersion @@ -18,6 +19,88 @@ from typing import Callable +@pytest.fixture +def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefinition]: + return [ + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=int, + name="a", + ), + model_name=ANY, + default_factory=None, + dto_field=DTOField(), + passthrough_constraints=False, + ), + replace( + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=int, + name="b", + ), + model_name=ANY, + default_factory=None, + dto_field=DTOField(mark=Mark.READ_ONLY), + ), + metadata=ANY, + type_wrappers=ANY, + raw=ANY, + kwarg_definition=ANY, + passthrough_constraints=False, + ), + replace( + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=int, + name="c", + ), + model_name=ANY, + default_factory=None, + dto_field=DTOField(), + ), + metadata=ANY, + type_wrappers=ANY, + raw=ANY, + kwarg_definition=ANY, + passthrough_constraints=False, + ), + replace( + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=int, + name="d", + default=1, + ), + model_name=ANY, + default_factory=None, + dto_field=DTOField(), + ), + metadata=ANY, + type_wrappers=ANY, + raw=ANY, + kwarg_definition=ANY, + passthrough_constraints=False, + ), + replace( + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=Optional[int], + name="e", + ), + model_name=ANY, + default_factory=int_factory, + dto_field=DTOField(), + ), + default=None, + metadata=ANY, + type_wrappers=ANY, + raw=ANY, + kwarg_definition=ANY, + passthrough_constraints=False, + ), + ] + + def test_field_definition_generation_v1( int_factory: Callable[[], int], expected_field_defs: list[DTOFieldDefinition], diff --git a/tests/unit/test_contrib/test_pydantic/test_schema_plugin.py b/tests/unit/test_contrib/test_pydantic/test_schema_plugin.py index b60d01840a..88c463febe 100644 --- a/tests/unit/test_contrib/test_pydantic/test_schema_plugin.py +++ b/tests/unit/test_contrib/test_pydantic/test_schema_plugin.py @@ -8,6 +8,7 @@ from pydantic.v1.generics import GenericModel from typing_extensions import Annotated +from litestar import Litestar, post from litestar._openapi.schema_generation import SchemaCreator from litestar.contrib.pydantic.pydantic_schema_plugin import PydanticSchemaPlugin from litestar.openapi.spec import OpenAPIType @@ 
-127,3 +128,21 @@ def test_exclude_private_fields(model_class: Type[Union[pydantic_v1.BaseModel, p FieldDefinition.from_annotation(model_class), schema_creator=SchemaCreator(plugins=[PydanticSchemaPlugin()]) ) assert not schema.properties + + +def test_v1_constrained_str_with_default_factory_does_not_generate_title() -> None: + # /~https://github.com/litestar-org/litestar/issues/3710 + class Model(pydantic_v1.BaseModel): + test_str: str = pydantic_v1.Field(default_factory=str, max_length=600) + + @post(path="/") + async def test(data: Model) -> str: + return "success" + + schema = Litestar(route_handlers=[test]).openapi_schema.to_schema() + assert ( + "title" + not in schema["components"]["schemas"][ + "test_v1_constrained_str_with_default_factory_does_not_generate_title.Model" + ]["properties"]["test_str"]["oneOf"][1] + ) diff --git a/tests/unit/test_openapi/test_schema.py b/tests/unit/test_openapi/test_schema.py index daa57ad4c0..3e15b9f51d 100644 --- a/tests/unit/test_openapi/test_schema.py +++ b/tests/unit/test_openapi/test_schema.py @@ -296,16 +296,27 @@ class Foo(TypedDict): def test_create_schema_from_msgspec_annotated_type() -> None: class Lookup(msgspec.Struct): - id: Annotated[str, msgspec.Meta(max_length=16, examples=["example"], description="description", title="title")] + str_field: Annotated[ + str, + msgspec.Meta(max_length=16, examples=["example"], description="description", title="title", pattern=r"\w+"), + ] + bytes_field: Annotated[bytes, msgspec.Meta(max_length=2, min_length=1)] + default_field: Annotated[str, msgspec.Meta(min_length=1)] = "a" schema = get_schema_for_field_definition(FieldDefinition.from_kwarg(name="Lookup", annotation=Lookup)) - assert schema.properties["id"].type == OpenAPIType.STRING # type: ignore[index, union-attr] - assert schema.properties["id"].examples == ["example"] # type: ignore[index, union-attr] - assert schema.properties["id"].description == "description" # type: ignore[index] - assert schema.properties["id"].title == "title" # type: ignore[index, union-attr] - assert schema.properties["id"].max_length == 16 # type: ignore[index, union-attr] - assert schema.required == ["id"] + assert schema.properties["str_field"].type == OpenAPIType.STRING # type: ignore[index, union-attr] + assert schema.properties["str_field"].examples == ["example"] # type: ignore[index, union-attr] + assert schema.properties["str_field"].description == "description" # type: ignore[index] + assert schema.properties["str_field"].title == "title" # type: ignore[index, union-attr] + assert schema.properties["str_field"].max_length == 16 # type: ignore[index, union-attr] + assert sorted(schema.required) == sorted(["str_field", "bytes_field"]) # type: ignore[arg-type] + assert schema.properties["bytes_field"].to_schema() == { # type: ignore[index] + "contentEncoding": "utf-8", + "maxLength": 2, + "minLength": 1, + "type": "string", + } def test_annotated_types() -> None: diff --git a/tests/unit/test_typing.py b/tests/unit/test_typing.py index d0ebdff97f..bb2334f913 100644 --- a/tests/unit/test_typing.py +++ b/tests/unit/test_typing.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Any, ForwardRef, Generic, List, Optional, Tuple, TypeVar, Union -import annotated_types import msgspec import pytest from typing_extensions import Annotated, NotRequired, Required, TypeAliasType, TypedDict, get_type_hints @@ -12,7 +11,7 @@ from litestar import get from litestar.exceptions import LitestarWarning from litestar.params import DependencyKwarg, KwargDefinition, 
Parameter, ParameterKwarg -from litestar.typing import FieldDefinition, _unpack_predicate +from litestar.typing import FieldDefinition from .test_utils.test_signature import T, _check_field_definition, field_definition_int, test_type_hints @@ -440,20 +439,6 @@ def test_field_definition_get_type_hints_dont_resolve_generics( ) -@pytest.mark.parametrize( - "predicate, expected_meta", - [ - (annotated_types.LowerCase.__metadata__[0], {"lower_case": True}), # pyright: ignore - (annotated_types.UpperCase.__metadata__[0], {"upper_case": True}), # pyright: ignore - (annotated_types.IsAscii.__metadata__[0], {"pattern": "[[:ascii:]]"}), # pyright: ignore - (annotated_types.IsDigits.__metadata__[0], {"pattern": "[[:digit:]]"}), # pyright: ignore - (object(), {}), - ], -) -def test_unpack_predicate(predicate: Any, expected_meta: dict[str, Any]) -> None: - assert _unpack_predicate(predicate) == expected_meta - - def test_warn_ambiguous_default_values() -> None: with pytest.warns(LitestarWarning, match="Ambiguous default values"): FieldDefinition.from_annotation(Annotated[int, Parameter(default=1)], default=2)