diff --git a/msgspec/_core.c b/msgspec/_core.c index e6b68fd4..40515c20 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -379,6 +379,7 @@ typedef struct { PyObject *typing_classvar; PyObject *typing_typevar; PyObject *typing_final; + PyObject *typing_generic; PyObject *typing_generic_alias; PyObject *typing_annotated_alias; PyObject *concrete_types; @@ -4963,7 +4964,7 @@ structmeta_get_module_ns(MsgspecState *mod, StructMetaInfo *info) { } static int -structmeta_collect_base(StructMetaInfo *info, PyObject *base) { +structmeta_collect_base(StructMetaInfo *info, MsgspecState *mod, PyObject *base) { if ((PyTypeObject *)base == &StructMixinType) return 0; if (((PyTypeObject *)base)->tp_weaklistoffset) { @@ -4983,17 +4984,21 @@ structmeta_collect_base(StructMetaInfo *info, PyObject *base) { } if (Py_TYPE(base) != &StructMetaType) { - PyTypeObject *cls = (PyTypeObject *)base; + info->has_non_struct_bases = true; + /* XXX: in Python 3.8 Generic defines __new__, but we can ignore it. + * This can be removed when Python 3.8 support is dropped */ + if (base == mod->typing_generic) return 0; + static const char *attrs[] = {"__init__", "__new__"}; Py_ssize_t nattrs = 2; - for (Py_ssize_t i = 0; i < nattrs; i++) { - if (PyDict_GetItemString(cls->tp_dict, attrs[i]) != NULL) { + if (PyDict_GetItemString( + ((PyTypeObject *)base)->tp_dict, attrs[i]) != NULL + ) { PyErr_Format(PyExc_TypeError, "Struct base classes cannot define %s", attrs[i]); return -1; } } - info->has_non_struct_bases = true; return 0; } @@ -5643,7 +5648,7 @@ StructMeta_new_inner( /* Extract info from base classes in reverse MRO order */ for (Py_ssize_t i = PyTuple_GET_SIZE(bases) - 1; i >= 0; i--) { PyObject *base = PyTuple_GET_ITEM(bases, i); - if (structmeta_collect_base(&info, base) < 0) goto cleanup; + if (structmeta_collect_base(&info, mod, base) < 0) goto cleanup; } /* Process configuration options */ @@ -6098,9 +6103,7 @@ StructInfo_Convert(PyObject *obj) { class->struct_info = info; } else { - if (PyObject_SetAttr(obj, mod->str___msgspec_cache__, (PyObject *)info) < 0) { - goto error; - } + if (PyObject_SetAttr(obj, mod->str___msgspec_cache__, (PyObject *)info) < 0) goto error; } cache_set = true; @@ -18528,6 +18531,7 @@ msgspec_clear(PyObject *m) Py_CLEAR(st->typing_classvar); Py_CLEAR(st->typing_typevar); Py_CLEAR(st->typing_final); + Py_CLEAR(st->typing_generic); Py_CLEAR(st->typing_generic_alias); Py_CLEAR(st->typing_annotated_alias); Py_CLEAR(st->concrete_types); @@ -18593,6 +18597,7 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg) Py_VISIT(st->typing_classvar); Py_VISIT(st->typing_typevar); Py_VISIT(st->typing_final); + Py_VISIT(st->typing_generic); Py_VISIT(st->typing_generic_alias); Py_VISIT(st->typing_annotated_alias); Py_VISIT(st->concrete_types); @@ -18805,6 +18810,7 @@ PyInit__core(void) SET_REF(typing_classvar, "ClassVar"); SET_REF(typing_typevar, "TypeVar"); SET_REF(typing_final, "Final"); + SET_REF(typing_generic, "Generic"); SET_REF(typing_generic_alias, "_GenericAlias"); Py_DECREF(temp_module); diff --git a/msgspec/_utils.py b/msgspec/_utils.py index 9b17e55a..c97dfeb7 100644 --- a/msgspec/_utils.py +++ b/msgspec/_utils.py @@ -1,7 +1,6 @@ # type: ignore import collections import sys -import types import typing try: @@ -35,6 +34,32 @@ def get_type_hints(obj): return _get_type_hints(obj, include_extras=True) +# The `is_class` argument was new in 3.11, but was backported to 3.9 and 3.10. +# It's _likely_ to be available for 3.9/3.10, but may not be. Easiest way to +# check is to try it and see. This check can be removed when we drop support +# for Python 3.10. +try: + typing.ForwardRef("Foo", is_class=True) +except TypeError: + + def _forward_ref(value): + return typing.ForwardRef(value, is_argument=False) + +else: + + def _forward_ref(value): + return typing.ForwardRef(value, is_argument=False, is_class=True) + + +def _apply_params(obj, mapping): + if params := getattr(obj, "__parameters__", None): + args = tuple(mapping.get(p, p) for p in params) + return obj[args] + elif isinstance(obj, typing.TypeVar): + return mapping.get(obj, obj) + return obj + + def _get_class_mro_and_typevar_mappings(obj): mapping = {} @@ -43,22 +68,26 @@ def _get_class_mro_and_typevar_mappings(obj): else: cls = obj.__origin__ - def inner(c): + def inner(c, scope): if isinstance(c, type): cls = c + new_scope = {} else: cls = c.__origin__ if cls in (object, typing.Generic): return if cls not in mapping: - mapping[cls] = dict(zip(cls.__parameters__, c.__args__)) + params = cls.__parameters__ + args = tuple(_apply_params(a, scope) for a in c.__args__) + assert len(params) == len(args) + mapping[cls] = new_scope = dict(zip(params, args)) if issubclass(cls, typing.Generic): bases = getattr(cls, "__orig_bases__", cls.__bases__) for b in bases: - inner(b) + inner(b, new_scope) - inner(obj) + inner(obj, {}) return cls.__mro__, mapping @@ -92,23 +121,16 @@ def get_class_annotations(obj): cls_globals = getattr(sys.modules.get(cls.__module__, None), "__dict__", {}) ann = cls.__dict__.get("__annotations__", {}) - if isinstance(ann, types.GetSetDescriptorType): - ann = {} - for name, value in ann.items(): if name in hints: continue if value is None: value = type(None) - if isinstance(value, str): - value = typing.ForwardRef(value, is_argument=False, is_class=True) + elif isinstance(value, str): + value = _forward_ref(value) value = typing._eval_type(value, cls_locals, cls_globals) if mapping is not None: - if params := getattr(value, "__parameters__", None): - args = tuple(mapping.get(p, p) for p in params) - value = value[args] - elif isinstance(value, typing.TypeVar): - value = mapping.get(value, value) + value = _apply_params(value, mapping) hints[name] = value return hints diff --git a/tests/test_common.py b/tests/test_common.py index e35f7aa7..160ca95f 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -16,6 +16,7 @@ Deque, Dict, Final, + Generic, List, Literal, NamedTuple, @@ -23,6 +24,7 @@ Optional, Tuple, TypedDict, + TypeVar, Union, ) @@ -45,6 +47,8 @@ py310_plus = pytest.mark.skipif(not PY310, reason="3.10+ only") py311_plus = pytest.mark.skipif(not PY311, reason="3.11+ only") +T = TypeVar("T") + @pytest.fixture(params=["json", "msgpack"]) def proto(request): @@ -1065,6 +1069,163 @@ class Test2(msgspec.Struct, tag=True): assert frozenset(new) in cache +class TestGenericStruct: + def test_generic_struct_info_cached(self, proto): + class Ex(msgspec.Struct, Generic[T]): + x: T + + typ = Ex[int] + assert Ex[int] is typ + + dec = proto.Decoder(typ) + info = typ.__msgspec_cache__ + assert info is not None + assert sys.getrefcount(info) == 4 # info + attr + decoder + func call + dec2 = proto.Decoder(typ) + assert typ.__msgspec_cache__ is info + assert sys.getrefcount(info) == 5 + + del dec + del dec2 + assert sys.getrefcount(info) == 3 + + def test_generic_struct_invalid_types_not_cached(self, proto): + class Ex(msgspec.Struct, Generic[T]): + x: Union[List[T], Tuple[float]] + + for typ in [Ex, Ex[int]]: + for _ in range(2): + with pytest.raises(TypeError, match="not supported"): + proto.Decoder(typ) + + assert not hasattr(typ, "__msgspec_cache__") + + def test_msgspec_cache_overwritten(self, proto): + class Ex(msgspec.Struct, Generic[T]): + x: T + + typ = Ex[int] + typ.__msgspec_cache__ = 1 + + with pytest.raises(RuntimeError, match="__msgspec_cache__"): + proto.Decoder(typ) + + @pytest.mark.parametrize("array_like", [False, True]) + def test_generic_struct(self, proto, array_like): + class Ex(msgspec.Struct, Generic[T], array_like=array_like): + x: T + y: List[T] + + sol = Ex(1, [1, 2]) + msg = proto.encode(sol) + + res = proto.decode(msg, type=Ex) + assert res == sol + + res = proto.decode(msg, type=Ex[int]) + assert res == sol + + res = proto.decode(msg, type=Ex[Union[int, str]]) + assert res == sol + + res = proto.decode(msg, type=Ex[float]) + assert type(res.x) is float + + with pytest.raises(msgspec.ValidationError, match="Expected `str`, got `int`"): + proto.decode(msg, type=Ex[str]) + + @pytest.mark.parametrize("array_like", [False, True]) + def test_recursive_generic_struct(self, proto, array_like): + source = f""" + from __future__ import annotations + from typing import Union, Generic, TypeVar + from msgspec import Struct + + T = TypeVar("T") + + class Ex(Struct, Generic[T], array_like={array_like}): + a: T + b: Union[Ex[T], None] + """ + + with temp_module(source) as mod: + msg = mod.Ex(a=1, b=mod.Ex(a=2, b=None)) + msg2 = mod.Ex(a=1, b=mod.Ex(a="bad", b=None)) + assert proto.decode(proto.encode(msg), type=mod.Ex) == msg + assert proto.decode(proto.encode(msg2), type=mod.Ex) == msg2 + assert proto.decode(proto.encode(msg), type=mod.Ex[int]) == msg + + with pytest.raises(msgspec.ValidationError) as rec: + proto.decode(proto.encode(msg2), type=mod.Ex[int]) + if array_like: + assert "`$[1][0]`" in str(rec.value) + else: + assert "`$.b.a`" in str(rec.value) + assert "Expected `int`, got `str`" in str(rec.value) + + @pytest.mark.parametrize("array_like", [False, True]) + def test_generic_struct_union(self, proto, array_like): + class Test1(msgspec.Struct, Generic[T], tag=True, array_like=array_like): + a: Union[T, None] + b: int + + class Test2(msgspec.Struct, Generic[T], tag=True, array_like=array_like): + x: T + y: int + + typ = Union[Test1[T], Test2[T]] + + msg1 = Test1(1, 2) + s1 = proto.encode(msg1) + msg2 = Test2("three", 4) + s2 = proto.encode(msg2) + msg3 = Test1(None, 4) + s3 = proto.encode(msg3) + + assert proto.decode(s1, type=typ) == msg1 + assert proto.decode(s2, type=typ) == msg2 + assert proto.decode(s3, type=typ) == msg3 + + assert proto.decode(s1, type=typ[int]) == msg1 + assert proto.decode(s3, type=typ[int]) == msg3 + assert proto.decode(s2, type=typ[str]) == msg2 + assert proto.decode(s3, type=typ[str]) == msg3 + + with pytest.raises(msgspec.ValidationError) as rec: + proto.decode(s1, type=typ[str]) + assert "Expected `str | null`, got `int`" in str(rec.value) + loc = "$[1]" if array_like else "$.a" + assert loc in str(rec.value) + + with pytest.raises(msgspec.ValidationError) as rec: + proto.decode(s2, type=typ[int]) + assert "Expected `int`, got `str`" in str(rec.value) + loc = "$[1]" if array_like else "$.x" + assert loc in str(rec.value) + + def test_unbound_typevars_use_bound_if_set(self, proto): + T = TypeVar("T", bound=Union[int, str]) + + dec = proto.Decoder(List[T]) + sol = [1, "two", 3, "four"] + msg = proto.encode(sol) + assert dec.decode(msg) == sol + + bad = proto.encode([1, {}]) + with pytest.raises( + msgspec.ValidationError, + match=r"Expected `int \| str`, got `object` - at `\$\[1\]`", + ): + dec.decode(bad) + + def test_unbound_typevars_with_constraints_unsupported(self, proto): + T = TypeVar("T", int, str) + with pytest.raises(TypeError) as rec: + proto.Decoder(List[T]) + + assert "Unbound TypeVar `~T` has constraints" in str(rec.value) + + class TestStructOmitDefaults: def test_omit_defaults(self, proto): class Test(msgspec.Struct, omit_defaults=True): diff --git a/tests/test_from_builtins.py b/tests/test_from_builtins.py index 58bdf41a..63628660 100644 --- a/tests/test_from_builtins.py +++ b/tests/test_from_builtins.py @@ -12,11 +12,13 @@ Any, Dict, FrozenSet, + Generic, List, Literal, NamedTuple, Set, Tuple, + TypeVar, Union, ) @@ -53,6 +55,8 @@ UTC = datetime.timezone.utc +T = TypeVar("T") + def assert_eq(x, y): assert x == y @@ -1560,6 +1564,72 @@ class Test3(Struct, tag=tag3, array_like=True): assert f"Invalid value {unknown!r} - at `$[0]`" == str(rec.value) +class TestGenericStruct: + @pytest.mark.parametrize("array_like", [False, True]) + def test_generic_struct(self, array_like): + class Ex(Struct, Generic[T], array_like=array_like): + x: T + y: List[T] + + sol = Ex(1, [1, 2]) + msg = to_builtins(sol) + + res = from_builtins(msg, Ex) + assert res == sol + + res = from_builtins(msg, Ex[int]) + assert res == sol + + res = from_builtins(msg, Ex[Union[int, str]]) + assert res == sol + + res = from_builtins(msg, Ex[float]) + assert type(res.x) is float + + with pytest.raises(ValidationError, match="Expected `str`, got `int`"): + from_builtins(msg, Ex[str]) + + @pytest.mark.parametrize("array_like", [False, True]) + def test_generic_struct_union(self, array_like): + class Test1(Struct, Generic[T], tag=True, array_like=array_like): + a: Union[T, None] + b: int + + class Test2(Struct, Generic[T], tag=True, array_like=array_like): + x: T + y: int + + typ = Union[Test1[T], Test2[T]] + + msg1 = Test1(1, 2) + s1 = to_builtins(msg1) + msg2 = Test2("three", 4) + s2 = to_builtins(msg2) + msg3 = Test1(None, 4) + s3 = to_builtins(msg3) + + assert from_builtins(s1, typ) == msg1 + assert from_builtins(s2, typ) == msg2 + assert from_builtins(s3, typ) == msg3 + + assert from_builtins(s1, typ[int]) == msg1 + assert from_builtins(s3, typ[int]) == msg3 + assert from_builtins(s2, typ[str]) == msg2 + assert from_builtins(s3, typ[str]) == msg3 + + with pytest.raises(ValidationError) as rec: + from_builtins(s1, typ[str]) + assert "Expected `str | null`, got `int`" in str(rec.value) + loc = "$[1]" if array_like else "$.a" + assert loc in str(rec.value) + + with pytest.raises(ValidationError) as rec: + from_builtins(s2, typ[int]) + assert "Expected `int`, got `str`" in str(rec.value) + loc = "$[1]" if array_like else "$.x" + assert loc in str(rec.value) + + class TestStrValues: def test_str_values_none(self): for x in ["null", "Null", "nUll", "nuLl", "nulL"]: diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..dc24581e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from typing import Generic, List, Set, TypeVar + +import pytest +from utils import temp_module + +from msgspec._utils import get_class_annotations + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") + + +class Base(Generic[T]): + x: T + + +class Base2(Generic[T, S]): + a: T + b: S + + +class TestGetClassAnnotations: + @pytest.mark.parametrize("future_annotations", [False, True]) + def test_eval_scopes(self, future_annotations): + header = "from __future__ import annotations" if future_annotations else "" + source = f""" + {header} + STR = str + + class Ex: + LOCAL = float + x: int + y: LOCAL + z: STR + """ + with temp_module(source) as mod: + assert get_class_annotations(mod.Ex) == {"x": int, "y": float, "z": str} + + def test_none_to_nonetype(self): + class Ex: + x: None + + assert get_class_annotations(Ex) == {"x": type(None)} + + def test_subclass(self): + class Base: + x: int + y: str + + class Sub(Base): + x: float + z: list + + class Base2: + a: int + + class Sub2(Sub, Base2): + b: float + y: list + + assert get_class_annotations(Base) == {"x": int, "y": str} + assert get_class_annotations(Sub) == {"x": float, "y": str, "z": list} + assert get_class_annotations(Sub2) == { + "x": float, + "y": list, + "z": list, + "a": int, + "b": float, + } + + def test_simple_generic(self): + class Test(Generic[T]): + x: T + y: List[T] + z: int + + assert get_class_annotations(Test) == {"x": T, "y": List[T], "z": int} + assert get_class_annotations(Test[int]) == {"x": int, "y": List[int], "z": int} + assert get_class_annotations(Test[Set[T]]) == { + "x": Set[T], + "y": List[Set[T]], + "z": int, + } + + def test_generic_sub1(self): + class Sub(Base): + y: int + + assert get_class_annotations(Sub) == {"x": T, "y": int} + + def test_generic_sub2(self): + class Sub(Base, Generic[T]): + y: List[T] + + assert get_class_annotations(Sub) == {"x": T, "y": List[T]} + assert get_class_annotations(Sub[int]) == {"x": T, "y": List[int]} + + def test_generic_sub3(self): + class Sub(Base[int], Generic[T]): + y: List[T] + + assert get_class_annotations(Sub) == {"x": int, "y": List[T]} + assert get_class_annotations(Sub[float]) == {"x": int, "y": List[float]} + + def test_generic_sub4(self): + class Sub(Base[T]): + y: List[T] + + assert get_class_annotations(Sub) == {"x": T, "y": List[T]} + assert get_class_annotations(Sub[int]) == {"x": int, "y": List[int]} + + def test_generic_sub5(self): + class Sub(Base[T], Generic[T]): + y: List[T] + + assert get_class_annotations(Sub) == {"x": T, "y": List[T]} + assert get_class_annotations(Sub[int]) == {"x": int, "y": List[int]} + + def test_generic_sub6(self): + class Sub(Base[S]): + y: List[S] + + assert get_class_annotations(Sub) == {"x": S, "y": List[S]} + assert get_class_annotations(Sub[int]) == {"x": int, "y": List[int]} + + def test_generic_sub7(self): + class Sub(Base[List[T]]): + y: Set[T] + + assert get_class_annotations(Sub) == {"x": List[T], "y": Set[T]} + assert get_class_annotations(Sub[int]) == {"x": List[int], "y": Set[int]} + + def test_generic_sub8(self): + class Sub(Base[int], Base2[float, str]): + pass + + assert get_class_annotations(Sub) == {"x": int, "a": float, "b": str} + + def test_generic_sub9(self): + class Sub(Base[U], Base2[List[U], U]): + y: str + + assert get_class_annotations(Sub) == {"y": str, "x": U, "a": List[U], "b": U} + assert get_class_annotations(Sub[int]) == { + "y": str, + "x": int, + "a": List[int], + "b": int, + } + + class Sub2(Sub[int]): + x: list + + assert get_class_annotations(Sub2) == { + "x": list, + "y": str, + "a": List[int], + "b": int, + } + + def test_generic_sub10(self): + class Sub(Base[U], Base2[List[U], U]): + y: str + + class Sub3(Sub[List[T]]): + c: T + + assert get_class_annotations(Sub3) == { + "c": T, + "y": str, + "x": List[T], + "a": List[List[T]], + "b": List[T], + } + assert get_class_annotations(Sub3[int]) == { + "c": int, + "y": str, + "x": List[int], + "a": List[List[int]], + "b": List[int], + }