Skip to content

Commit

Permalink
feat: Allow customizing schema component keys (#3738)
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut authored Sep 15, 2024
1 parent 1132554 commit 5ed9eb0
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 108 deletions.
115 changes: 90 additions & 25 deletions litestar/_openapi/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,67 @@
from __future__ import annotations

from collections import defaultdict
from typing import TYPE_CHECKING, Iterator, Sequence
from typing import TYPE_CHECKING, Iterator, Sequence, _GenericAlias # type: ignore[attr-defined]

from litestar.exceptions import ImproperlyConfiguredException
from litestar.openapi.spec import Reference, Schema
from litestar.params import KwargDefinition

if TYPE_CHECKING:
from litestar.openapi import OpenAPIConfig
from litestar.plugins import OpenAPISchemaPluginProtocol
from litestar.typing import FieldDefinition


def _longest_common_prefix(tuples_: list[tuple[str, ...]]) -> tuple[str, ...]:
"""Find the longest common prefix of a list of tuples.
Args:
tuples_: A list of tuples to find the longest common prefix of.
Returns:
The longest common prefix of the tuples.
"""
prefix_ = tuples_[0]
for t in tuples_:
# Compare the current prefix with each tuple and shorten it
prefix_ = prefix_[: min(len(prefix_), len(t))]
for i in range(len(prefix_)):
if prefix_[i] != t[i]:
prefix_ = prefix_[:i]
break
return prefix_


def _get_component_key_override(field: FieldDefinition) -> str | None:
if (
(kwarg_definition := field.kwarg_definition)
and isinstance(kwarg_definition, KwargDefinition)
and (schema_key := kwarg_definition.schema_component_key)
):
return schema_key
return None


def _get_normalized_schema_key(field_definition: FieldDefinition) -> tuple[str, ...]:
"""Create a key for a type annotation.
The key should be a tuple such as ``("path", "to", "type", "TypeName")``.
Args:
field_definition: Field definition
Returns:
A tuple of strings.
"""
if override := _get_component_key_override(field_definition):
return (override,)

annotation = field_definition.annotation
module = getattr(annotation, "__module__", "")
name = str(annotation)[len(module) + 1 :] if isinstance(annotation, _GenericAlias) else annotation.__qualname__
name = name.replace(".<locals>.", ".")
return *module.split("."), name


class RegisteredSchema:
Expand Down Expand Up @@ -43,32 +96,63 @@ def __init__(self) -> None:
self._schema_key_map: dict[tuple[str, ...], RegisteredSchema] = {}
self._schema_reference_map: dict[int, RegisteredSchema] = {}
self._model_name_groups: defaultdict[str, list[RegisteredSchema]] = defaultdict(list)
self._component_type_map: dict[tuple[str, ...], FieldDefinition] = {}

def get_schema_for_key(self, key: tuple[str, ...]) -> Schema:
def get_schema_for_field_definition(self, field: FieldDefinition) -> Schema:
"""Get a registered schema by its key.
Args:
key: The key to the schema to get.
field: The field definition to get the schema for
Returns:
A RegisteredSchema object.
"""
key = _get_normalized_schema_key(field)
if key not in self._schema_key_map:
self._schema_key_map[key] = registered_schema = RegisteredSchema(key, Schema(), [])
self._model_name_groups[key[-1]].append(registered_schema)
self._component_type_map[key] = field
else:
if (existing_type := self._component_type_map[key]) != field:
raise ImproperlyConfiguredException(
f"Schema component keys must be unique. Cannot override existing key {'_'.join(key)!r} for type "
f"{existing_type.raw!r} with new type {field.raw!r}"
)
return self._schema_key_map[key].schema

def get_reference_for_key(self, key: tuple[str, ...]) -> Reference | None:
def get_reference_for_field_definition(self, field: FieldDefinition) -> Reference | None:
"""Get a reference to a registered schema by its key.
Args:
key: The key to the schema to get.
field: The field definition to get the reference for
Returns:
A Reference object.
"""
key = _get_normalized_schema_key(field)
if key not in self._schema_key_map:
return None

if (existing_type := self._component_type_map[key]) != field:
# TODO: This should check for strict equality, e.g. changes in type metadata
# However, this is currently not possible to do without breaking things, as
# we allow to define metadata on a type annotation in one place to be used
# for the same type in a different place, where that same type is *not*
# annotated with this metadata. The proper fix for this would be to e.g.
# inline DTO definitions when they are created at the handler level, as
# they won't be reused (they already generate a unique key), and create a
# more strict lookup policy for component schemas
msg = (
f"Schema component keys must be unique. While obtaining a reference for the type '{field.raw!r}', the "
f"generated key {'_'.join(key)!r} was already associated with a different type '{existing_type.raw!r}'. "
)
if key_override := _get_component_key_override(field): # pragma: no branch
# Currently, this can never not be true, however, in the future we might
# decide to do a stricter equality check as lined out above, in which
# case there can be other cases than overrides that cause this error
msg += f"Hint: Both types are defining a 'schema_component_key' with the value of {key_override!r}"
raise ImproperlyConfiguredException(msg)

registered_schema = self._schema_key_map[key]
reference = Reference(f"#/components/schemas/{'_'.join(key)}")
registered_schema.references.append(reference)
Expand Down Expand Up @@ -107,26 +191,7 @@ def remove_common_prefix(tuples: list[tuple[str, ...]]) -> list[tuple[str, ...]]
A list of tuples with the common prefix removed.
"""

def longest_common_prefix(tuples_: list[tuple[str, ...]]) -> tuple[str, ...]:
"""Find the longest common prefix of a list of tuples.
Args:
tuples_: A list of tuples to find the longest common prefix of.
Returns:
The longest common prefix of the tuples.
"""
prefix_ = tuples_[0]
for t in tuples_:
# Compare the current prefix with each tuple and shorten it
prefix_ = prefix_[: min(len(prefix_), len(t))]
for i in range(len(prefix_)):
if prefix_[i] != t[i]:
prefix_ = prefix_[:i]
break
return prefix_

prefix = longest_common_prefix(tuples)
prefix = _longest_common_prefix(tuples)
prefix_length = len(prefix)
return [t[prefix_length:] for t in tuples]

Expand Down
9 changes: 3 additions & 6 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
create_string_constrained_field_schema,
)
from litestar._openapi.schema_generation.utils import (
_get_normalized_schema_key,
_should_create_enum_schema,
_should_create_literal_schema,
_type_or_first_not_none_inner_type,
Expand Down Expand Up @@ -508,8 +507,7 @@ def for_plugin(self, field_definition: FieldDefinition, plugin: OpenAPISchemaPlu
Returns:
A schema instance.
"""
key = _get_normalized_schema_key(field_definition.annotation)
if (ref := self.schema_registry.get_reference_for_key(key)) is not None:
if (ref := self.schema_registry.get_reference_for_field_definition(field_definition)) is not None:
return ref

schema = plugin.to_openapi_schema(field_definition=field_definition, schema_creator=self)
Expand Down Expand Up @@ -612,8 +610,7 @@ def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schem
schema.examples = get_json_schema_formatted_examples(create_examples_for_field(field))

if schema.title and schema.type == OpenAPIType.OBJECT:
key = _get_normalized_schema_key(field.annotation)
return self.schema_registry.get_reference_for_key(key) or schema
return self.schema_registry.get_reference_for_field_definition(field) or schema
return schema

def create_component_schema(
Expand Down Expand Up @@ -644,7 +641,7 @@ def create_component_schema(
Returns:
A schema instance.
"""
schema = self.schema_registry.get_schema_for_key(_get_normalized_schema_key(type_.annotation))
schema = self.schema_registry.get_schema_for_field_definition(type_)
schema.title = title or _get_type_schema_name(type_)
schema.required = required
schema.type = openapi_type
Expand Down
20 changes: 1 addition & 19 deletions litestar/_openapi/schema_generation/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Mapping, _GenericAlias # type: ignore[attr-defined]
from typing import TYPE_CHECKING, Any, Mapping

from litestar.utils.helpers import get_name

Expand All @@ -15,7 +15,6 @@
"_type_or_first_not_none_inner_type",
"_should_create_enum_schema",
"_should_create_literal_schema",
"_get_normalized_schema_key",
)


Expand Down Expand Up @@ -83,23 +82,6 @@ def _should_create_literal_schema(field_definition: FieldDefinition) -> bool:
)


def _get_normalized_schema_key(annotation: Any) -> tuple[str, ...]:
"""Create a key for a type annotation.
The key should be a tuple such as ``("path", "to", "type", "TypeName")``.
Args:
annotation: a type annotation
Returns:
A tuple of strings.
"""
module = getattr(annotation, "__module__", "")
name = str(annotation)[len(module) + 1 :] if isinstance(annotation, _GenericAlias) else annotation.__qualname__
name = name.replace(".<locals>.", ".")
return *module.split("."), name


def get_formatted_examples(field_definition: FieldDefinition, examples: Sequence[Example]) -> Mapping[str, Example]:
"""Format the examples into the OpenAPI schema format."""

Expand Down
13 changes: 13 additions & 0 deletions litestar/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ class KwargDefinition:
.. versionadded:: 2.8.0
"""
schema_component_key: str | None = None
"""
Use as the key for the reference when creating a component for this type
.. versionadded:: 2.12.0
"""

@property
def is_constrained(self) -> bool:
Expand Down Expand Up @@ -195,6 +200,7 @@ def Parameter(
required: bool | None = None,
title: str | None = None,
schema_extra: dict[str, Any] | None = None,
schema_component_key: str | None = None,
) -> Any:
"""Create an extended parameter kwarg definition.
Expand Down Expand Up @@ -239,6 +245,8 @@ def Parameter(
schema.
.. versionadded:: 2.8.0
schema_component_key: Use this as the key for the reference when creating a component for this type
.. versionadded:: 2.12.0
"""
return ParameterKwarg(
annotation=annotation,
Expand All @@ -264,6 +272,7 @@ def Parameter(
max_length=max_length,
pattern=pattern,
schema_extra=schema_extra,
schema_component_key=schema_component_key,
)


Expand Down Expand Up @@ -308,6 +317,7 @@ def Body(
pattern: str | None = None,
title: str | None = None,
schema_extra: dict[str, Any] | None = None,
schema_component_key: str | None = None,
) -> Any:
"""Create an extended request body kwarg definition.
Expand Down Expand Up @@ -349,6 +359,8 @@ def Body(
schema.
.. versionadded:: 2.8.0
schema_component_key: Use this as the key for the reference when creating a component for this type
.. versionadded:: 2.12.0
"""
return BodyKwarg(
media_type=media_type,
Expand All @@ -371,6 +383,7 @@ def Body(
pattern=pattern,
multipart_form_part_limit=multipart_form_part_limit,
schema_extra=schema_extra,
schema_component_key=schema_component_key,
)


Expand Down
Loading

0 comments on commit 5ed9eb0

Please sign in to comment.