Skip to content

Commit

Permalink
fix(DTO): Inconsistent use of strict mode (#3685)
Browse files Browse the repository at this point in the history
Fix strict encoding when parsing raw
  • Loading branch information
provinzkraut authored Aug 22, 2024
1 parent 238d26b commit b058d64
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 22 deletions.
4 changes: 2 additions & 2 deletions litestar/dto/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def parse_raw(self, raw: bytes, asgi_connection: ASGIConnection) -> Struct | Col
type_decoders = asgi_connection.route_handler.resolve_type_decoders()

if request_encoding == RequestEncodingType.MESSAGEPACK:
result = decode_msgpack(value=raw, target_type=self.annotation, type_decoders=type_decoders)
result = decode_msgpack(value=raw, target_type=self.annotation, type_decoders=type_decoders, strict=False)
else:
result = decode_json(value=raw, target_type=self.annotation, type_decoders=type_decoders)
result = decode_json(value=raw, target_type=self.annotation, type_decoders=type_decoders, strict=False)

return cast("Struct | Collection[Struct]", result)

Expand Down
39 changes: 29 additions & 10 deletions litestar/serialization/msgspec_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,32 +165,37 @@ def encode_json(value: Any, serializer: Callable[[Any], Any] | None = None) -> b


@overload
def decode_json(value: str | bytes) -> Any: ...
def decode_json(value: str | bytes, strict: bool = ...) -> Any: ...


@overload
def decode_json(value: str | bytes, type_decoders: TypeDecodersSequence | None) -> Any: ...
def decode_json(value: str | bytes, type_decoders: TypeDecodersSequence | None, strict: bool = ...) -> Any: ...


@overload
def decode_json(value: str | bytes, target_type: type[T]) -> T: ...
def decode_json(value: str | bytes, target_type: type[T], strict: bool = ...) -> T: ...


@overload
def decode_json(value: str | bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None) -> T: ...
def decode_json(
value: str | bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None, strict: bool = ...
) -> T: ...


def decode_json( # type: ignore[misc]
value: str | bytes,
target_type: type[T] | EmptyType = Empty, # pyright: ignore
type_decoders: TypeDecodersSequence | None = None,
strict: bool = True,
) -> Any:
"""Decode a JSON string/bytes into an object.
Args:
value: Value to decode
target_type: An optional type to decode the data into
type_decoders: Optional sequence of type decoders
strict: Whether type coercion rules should be strict. Setting to False enables
a wider set of coercion rules from string to non-string types for all values
Returns:
An object
Expand All @@ -202,7 +207,13 @@ def decode_json( # type: ignore[misc]
if target_type is Empty:
return _msgspec_json_decoder.decode(value)
return msgspec.json.decode(
value, dec_hook=partial(default_deserializer, type_decoders=type_decoders), type=target_type
value,
dec_hook=partial(
default_deserializer,
type_decoders=type_decoders,
),
type=target_type,
strict=strict,
)
except msgspec.DecodeError as msgspec_error:
raise SerializationException(str(msgspec_error)) from msgspec_error
Expand Down Expand Up @@ -230,32 +241,37 @@ def encode_msgpack(value: Any, serializer: Callable[[Any], Any] | None = default


@overload
def decode_msgpack(value: bytes) -> Any: ...
def decode_msgpack(value: bytes, strict: bool = ...) -> Any: ...


@overload
def decode_msgpack(value: bytes, type_decoders: TypeDecodersSequence | None) -> Any: ...
def decode_msgpack(value: bytes, type_decoders: TypeDecodersSequence | None, strict: bool = ...) -> Any: ...


@overload
def decode_msgpack(value: bytes, target_type: type[T]) -> T: ...
def decode_msgpack(value: bytes, target_type: type[T], strict: bool = ...) -> T: ...


@overload
def decode_msgpack(value: bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None) -> T: ...
def decode_msgpack(
value: bytes, target_type: type[T], type_decoders: TypeDecodersSequence | None, strict: bool = ...
) -> T: ...


def decode_msgpack( # type: ignore[misc]
value: bytes,
target_type: type[T] | EmptyType = Empty, # pyright: ignore[reportInvalidTypeVarUse]
type_decoders: TypeDecodersSequence | None = None,
strict: bool = True,
) -> Any:
"""Decode a MessagePack string/bytes into an object.
Args:
value: Value to decode
target_type: An optional type to decode the data into
type_decoders: Optional sequence of type decoders
strict: Whether type coercion rules should be strict. Setting to False enables
a wider set of coercion rules from string to non-string types for all values
Returns:
An object
Expand All @@ -267,7 +283,10 @@ def decode_msgpack( # type: ignore[misc]
if target_type is Empty:
return _msgspec_msgpack_decoder.decode(value)
return msgspec.msgpack.decode(
value, dec_hook=partial(default_deserializer, type_decoders=type_decoders), type=target_type
value,
dec_hook=partial(default_deserializer, type_decoders=type_decoders),
type=target_type,
strict=strict,
)
except msgspec.DecodeError as msgspec_error:
raise SerializationException(str(msgspec_error)) from msgspec_error
Expand Down
18 changes: 8 additions & 10 deletions tests/unit/test_dto/test_factory/test_backends/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class DC:
nested: NestedDC
nested_list: List[NestedDC]
nested_mapping: Dict[str, NestedDC]
integer: int
b: str = field(default="b")
c: List[int] = field(default_factory=list)
optional: Optional[str] = None
Expand All @@ -53,10 +54,12 @@ class DC:
"nested": {"a": 1, "b": "two"},
"nested_list": [{"a": 1, "b": "two"}],
"nested_mapping": {"a": {"a": 1, "b": "two"}},
"integer": 1,
"optional": None,
}
RAW = b'{"a":1,"nested":{"a":1,"b":"two"},"nested_list":[{"a":1,"b":"two"}],"nested_mapping":{"a":{"a":1,"b":"two"}},"b":"b","c":[],"optional":null}'
COLLECTION_RAW = b'[{"a":1,"nested":{"a":1,"b":"two"},"nested_list":[{"a":1,"b":"two"}],"nested_mapping":{"a":{"a":1,"b":"two"}},"b":"b","c":[],"optional":null}]'
RAW = b'{"a":1,"nested":{"a":1,"b":"two"},"nested_list":[{"a":1,"b":"two"}],"nested_mapping":{"a":{"a":1,"b":"two"}},"integer":1,"b":"b","c":[],"optional":null}'
MSGPACK_RAW = b"\x88\xa1a\x01\xa6nested\x82\xa1a\x01\xa1b\xa3two\xabnested_list\x91\x82\xa1a\x01\xa1b\xa3two\xaenested_mapping\x81\xa1a\x82\xa1a\x01\xa1b\xa3two\xa7integer\x01\xa1b\xa1b\xa1c\x90\xa8optional\xc0"
COLLECTION_RAW = b'[{"a":1,"nested":{"a":1,"b":"two"},"nested_list":[{"a":1,"b":"two"}],"nested_mapping":{"a":{"a":1,"b":"two"}},"integer":1,"b":"b","c":[],"optional":null}]'
STRUCTURED = DC(
a=1,
b="b",
Expand All @@ -65,6 +68,7 @@ class DC:
nested_list=[NestedDC(a=1, b="two")],
nested_mapping={"a": NestedDC(a=1, b="two")},
optional=None,
integer=1,
)


Expand Down Expand Up @@ -97,10 +101,7 @@ def test_backend_parse_raw_json(
wrapper_attribute_name=None,
is_data_field=True,
handler_id="test",
).parse_raw(
b'{"a":1,"nested":{"a":1,"b":"two"},"nested_list":[{"a":1,"b":"two"}],"nested_mapping":{"a":{"a":1,"b":"two"}}}',
asgi_connection,
)
).parse_raw(RAW, asgi_connection)
)
== DESTRUCTURED
)
Expand All @@ -122,10 +123,7 @@ def _handler() -> None: ...
wrapper_attribute_name=None,
is_data_field=True,
handler_id="test",
).parse_raw(
b"\x87\xa1a\x01\xa6nested\x82\xa1a\x01\xa1b\xa3two\xabnested_list\x91\x82\xa1a\x01\xa1b\xa3two\xaenested_mapping\x81\xa1a\x82\xa1a\x01\xa1b\xa3two\xa1b\xa1b\xa1c\x90\xa8optional\xc0",
asgi_connection,
)
).parse_raw(MSGPACK_RAW, asgi_connection)
)
== DESTRUCTURED
)
Expand Down

0 comments on commit b058d64

Please sign in to comment.