Skip to content

Commit

Permalink
Add tests for handling Generic structs
Browse files Browse the repository at this point in the history
Also fixes a few bugs
  • Loading branch information
jcrist committed Apr 22, 2023
1 parent 42d66aa commit e0d1001
Show file tree
Hide file tree
Showing 5 changed files with 466 additions and 24 deletions.
24 changes: 15 additions & 9 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ typedef struct {
PyObject *typing_classvar;
PyObject *typing_typevar;
PyObject *typing_final;
PyObject *typing_generic;
PyObject *typing_generic_alias;
PyObject *typing_annotated_alias;
PyObject *concrete_types;
Expand Down Expand Up @@ -4963,7 +4964,7 @@ structmeta_get_module_ns(MsgspecState *mod, StructMetaInfo *info) {
}

static int
structmeta_collect_base(StructMetaInfo *info, PyObject *base) {
structmeta_collect_base(StructMetaInfo *info, MsgspecState *mod, PyObject *base) {
if ((PyTypeObject *)base == &StructMixinType) return 0;

if (((PyTypeObject *)base)->tp_weaklistoffset) {
Expand All @@ -4983,17 +4984,21 @@ structmeta_collect_base(StructMetaInfo *info, PyObject *base) {
}

if (Py_TYPE(base) != &StructMetaType) {
PyTypeObject *cls = (PyTypeObject *)base;
info->has_non_struct_bases = true;
/* XXX: in Python 3.8 Generic defines __new__, but we can ignore it.
* This can be removed when Python 3.8 support is dropped */
if (base == mod->typing_generic) return 0;

static const char *attrs[] = {"__init__", "__new__"};
Py_ssize_t nattrs = 2;

for (Py_ssize_t i = 0; i < nattrs; i++) {
if (PyDict_GetItemString(cls->tp_dict, attrs[i]) != NULL) {
if (PyDict_GetItemString(
((PyTypeObject *)base)->tp_dict, attrs[i]) != NULL
) {
PyErr_Format(PyExc_TypeError, "Struct base classes cannot define %s", attrs[i]);
return -1;
}
}
info->has_non_struct_bases = true;
return 0;
}

Expand Down Expand Up @@ -5643,7 +5648,7 @@ StructMeta_new_inner(
/* Extract info from base classes in reverse MRO order */
for (Py_ssize_t i = PyTuple_GET_SIZE(bases) - 1; i >= 0; i--) {
PyObject *base = PyTuple_GET_ITEM(bases, i);
if (structmeta_collect_base(&info, base) < 0) goto cleanup;
if (structmeta_collect_base(&info, mod, base) < 0) goto cleanup;
}

/* Process configuration options */
Expand Down Expand Up @@ -6098,9 +6103,7 @@ StructInfo_Convert(PyObject *obj) {
class->struct_info = info;
}
else {
if (PyObject_SetAttr(obj, mod->str___msgspec_cache__, (PyObject *)info) < 0) {
goto error;
}
if (PyObject_SetAttr(obj, mod->str___msgspec_cache__, (PyObject *)info) < 0) goto error;
}
cache_set = true;

Expand Down Expand Up @@ -18528,6 +18531,7 @@ msgspec_clear(PyObject *m)
Py_CLEAR(st->typing_classvar);
Py_CLEAR(st->typing_typevar);
Py_CLEAR(st->typing_final);
Py_CLEAR(st->typing_generic);
Py_CLEAR(st->typing_generic_alias);
Py_CLEAR(st->typing_annotated_alias);
Py_CLEAR(st->concrete_types);
Expand Down Expand Up @@ -18593,6 +18597,7 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg)
Py_VISIT(st->typing_classvar);
Py_VISIT(st->typing_typevar);
Py_VISIT(st->typing_final);
Py_VISIT(st->typing_generic);
Py_VISIT(st->typing_generic_alias);
Py_VISIT(st->typing_annotated_alias);
Py_VISIT(st->concrete_types);
Expand Down Expand Up @@ -18805,6 +18810,7 @@ PyInit__core(void)
SET_REF(typing_classvar, "ClassVar");
SET_REF(typing_typevar, "TypeVar");
SET_REF(typing_final, "Final");
SET_REF(typing_generic, "Generic");
SET_REF(typing_generic_alias, "_GenericAlias");
Py_DECREF(temp_module);

Expand Down
52 changes: 37 additions & 15 deletions msgspec/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# type: ignore
import collections
import sys
import types
import typing

try:
Expand Down Expand Up @@ -35,6 +34,32 @@ def get_type_hints(obj):
return _get_type_hints(obj, include_extras=True)


# The `is_class` argument was new in 3.11, but was backported to 3.9 and 3.10.
# It's _likely_ to be available for 3.9/3.10, but may not be. Easiest way to
# check is to try it and see. This check can be removed when we drop support
# for Python 3.10.
try:
typing.ForwardRef("Foo", is_class=True)
except TypeError:

def _forward_ref(value):
return typing.ForwardRef(value, is_argument=False)

else:

def _forward_ref(value):
return typing.ForwardRef(value, is_argument=False, is_class=True)


def _apply_params(obj, mapping):
if params := getattr(obj, "__parameters__", None):
args = tuple(mapping.get(p, p) for p in params)
return obj[args]
elif isinstance(obj, typing.TypeVar):
return mapping.get(obj, obj)
return obj


def _get_class_mro_and_typevar_mappings(obj):
mapping = {}

Expand All @@ -43,22 +68,26 @@ def _get_class_mro_and_typevar_mappings(obj):
else:
cls = obj.__origin__

def inner(c):
def inner(c, scope):
if isinstance(c, type):
cls = c
new_scope = {}
else:
cls = c.__origin__
if cls in (object, typing.Generic):
return
if cls not in mapping:
mapping[cls] = dict(zip(cls.__parameters__, c.__args__))
params = cls.__parameters__
args = tuple(_apply_params(a, scope) for a in c.__args__)
assert len(params) == len(args)
mapping[cls] = new_scope = dict(zip(params, args))

if issubclass(cls, typing.Generic):
bases = getattr(cls, "__orig_bases__", cls.__bases__)
for b in bases:
inner(b)
inner(b, new_scope)

inner(obj)
inner(obj, {})
return cls.__mro__, mapping


Expand Down Expand Up @@ -92,23 +121,16 @@ def get_class_annotations(obj):
cls_globals = getattr(sys.modules.get(cls.__module__, None), "__dict__", {})

ann = cls.__dict__.get("__annotations__", {})
if isinstance(ann, types.GetSetDescriptorType):
ann = {}

for name, value in ann.items():
if name in hints:
continue
if value is None:
value = type(None)
if isinstance(value, str):
value = typing.ForwardRef(value, is_argument=False, is_class=True)
elif isinstance(value, str):
value = _forward_ref(value)
value = typing._eval_type(value, cls_locals, cls_globals)
if mapping is not None:
if params := getattr(value, "__parameters__", None):
args = tuple(mapping.get(p, p) for p in params)
value = value[args]
elif isinstance(value, typing.TypeVar):
value = mapping.get(value, value)
value = _apply_params(value, mapping)
hints[name] = value
return hints

Expand Down
161 changes: 161 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
Deque,
Dict,
Final,
Generic,
List,
Literal,
NamedTuple,
NewType,
Optional,
Tuple,
TypedDict,
TypeVar,
Union,
)

Expand All @@ -45,6 +47,8 @@
py310_plus = pytest.mark.skipif(not PY310, reason="3.10+ only")
py311_plus = pytest.mark.skipif(not PY311, reason="3.11+ only")

T = TypeVar("T")


@pytest.fixture(params=["json", "msgpack"])
def proto(request):
Expand Down Expand Up @@ -1065,6 +1069,163 @@ class Test2(msgspec.Struct, tag=True):
assert frozenset(new) in cache


class TestGenericStruct:
    """Tests for ``msgspec.Struct`` subclasses that also inherit ``Generic``.

    Every test takes the ``proto`` fixture and exercises its ``encode``,
    ``decode`` and ``Decoder`` attributes.
    """

    def test_generic_struct_info_cached(self, proto):
        """A parametrized struct stores its info on ``__msgspec_cache__`` and reuses it."""

        class Ex(msgspec.Struct, Generic[T]):
            x: T

        typ = Ex[int]
        # Subscripting twice must yield the same alias object, otherwise an
        # attribute set on one alias would not be visible on the next.
        assert Ex[int] is typ

        dec = proto.Decoder(typ)
        info = typ.__msgspec_cache__
        assert info is not None
        # NOTE: the exact refcounts below depend on which names are alive at
        # each point; do not introduce extra references to `info` here.
        assert sys.getrefcount(info) == 4  # info + attr + decoder + func call
        dec2 = proto.Decoder(typ)
        # The second decoder shares the same cached object (one more reference).
        assert typ.__msgspec_cache__ is info
        assert sys.getrefcount(info) == 5

        del dec
        del dec2
        # With both decoders gone only info + attr + func call remain.
        assert sys.getrefcount(info) == 3

    def test_generic_struct_invalid_types_not_cached(self, proto):
        """A struct whose field type is rejected must not leave a cache attribute behind."""

        class Ex(msgspec.Struct, Generic[T]):
            # This annotation is expected to be rejected with a "not supported"
            # TypeError (presumably because the union mixes two array-like
            # types -- confirm against the msgspec type rules).
            x: Union[List[T], Tuple[float]]

        for typ in [Ex, Ex[int]]:
            # Repeat to check that a failed first attempt does not change the
            # outcome of a second attempt.
            for _ in range(2):
                with pytest.raises(TypeError, match="not supported"):
                    proto.Decoder(typ)

            assert not hasattr(typ, "__msgspec_cache__")

    def test_msgspec_cache_overwritten(self, proto):
        """A foreign value stored in ``__msgspec_cache__`` raises ``RuntimeError``."""

        class Ex(msgspec.Struct, Generic[T]):
            x: T

        typ = Ex[int]
        # Simulate user code clobbering the attribute used for caching.
        typ.__msgspec_cache__ = 1

        with pytest.raises(RuntimeError, match="__msgspec_cache__"):
            proto.Decoder(typ)

    @pytest.mark.parametrize("array_like", [False, True])
    def test_generic_struct(self, proto, array_like):
        """Decoding honours the type parameter applied to a generic struct."""

        class Ex(msgspec.Struct, Generic[T], array_like=array_like):
            x: T
            y: List[T]

        sol = Ex(1, [1, 2])
        msg = proto.encode(sol)

        # Unparametrized: the int payload decodes back to an equal struct.
        res = proto.decode(msg, type=Ex)
        assert res == sol

        res = proto.decode(msg, type=Ex[int])
        assert res == sol

        res = proto.decode(msg, type=Ex[Union[int, str]])
        assert res == sol

        # With T=float the integer payload comes back as a float.
        res = proto.decode(msg, type=Ex[float])
        assert type(res.x) is float

        with pytest.raises(msgspec.ValidationError, match="Expected `str`, got `int`"):
            proto.decode(msg, type=Ex[str])

    @pytest.mark.parametrize("array_like", [False, True])
    def test_recursive_generic_struct(self, proto, array_like):
        """A generic struct that refers to itself propagates ``T`` to nested instances."""
        # The struct is defined from source text in a temporary module so that
        # `from __future__ import annotations` applies and `Ex` can be named in
        # its own field annotation.
        source = f"""
        from __future__ import annotations
        from typing import Union, Generic, TypeVar
        from msgspec import Struct
        T = TypeVar("T")
        class Ex(Struct, Generic[T], array_like={array_like}):
            a: T
            b: Union[Ex[T], None]
        """

        with temp_module(source) as mod:
            msg = mod.Ex(a=1, b=mod.Ex(a=2, b=None))
            msg2 = mod.Ex(a=1, b=mod.Ex(a="bad", b=None))
            # Unparametrized decoding accepts both payloads, including the one
            # whose nested `a` is a str.
            assert proto.decode(proto.encode(msg), type=mod.Ex) == msg
            assert proto.decode(proto.encode(msg2), type=mod.Ex) == msg2
            assert proto.decode(proto.encode(msg), type=mod.Ex[int]) == msg

            # With T=int the nested str must be rejected, and the error path
            # must point at the nested field.
            with pytest.raises(msgspec.ValidationError) as rec:
                proto.decode(proto.encode(msg2), type=mod.Ex[int])
            if array_like:
                assert "`$[1][0]`" in str(rec.value)
            else:
                assert "`$.b.a`" in str(rec.value)
            assert "Expected `int`, got `str`" in str(rec.value)

    @pytest.mark.parametrize("array_like", [False, True])
    def test_generic_struct_union(self, proto, array_like):
        """A union of tagged generic structs can itself be parametrized."""

        class Test1(msgspec.Struct, Generic[T], tag=True, array_like=array_like):
            a: Union[T, None]
            b: int

        class Test2(msgspec.Struct, Generic[T], tag=True, array_like=array_like):
            x: T
            y: int

        # Still generic over T; subscripted below as typ[int] / typ[str].
        typ = Union[Test1[T], Test2[T]]

        msg1 = Test1(1, 2)
        s1 = proto.encode(msg1)
        msg2 = Test2("three", 4)
        s2 = proto.encode(msg2)
        msg3 = Test1(None, 4)
        s3 = proto.encode(msg3)

        # Unparametrized union: all three messages round-trip.
        assert proto.decode(s1, type=typ) == msg1
        assert proto.decode(s2, type=typ) == msg2
        assert proto.decode(s3, type=typ) == msg3

        # Parametrized union: messages compatible with the chosen T round-trip.
        assert proto.decode(s1, type=typ[int]) == msg1
        assert proto.decode(s3, type=typ[int]) == msg3
        assert proto.decode(s2, type=typ[str]) == msg2
        assert proto.decode(s3, type=typ[str]) == msg3

        with pytest.raises(msgspec.ValidationError) as rec:
            proto.decode(s1, type=typ[str])
        assert "Expected `str | null`, got `int`" in str(rec.value)
        # In array_like mode the asserted path is positional ($[1]) rather
        # than the field name.
        loc = "$[1]" if array_like else "$.a"
        assert loc in str(rec.value)

        with pytest.raises(msgspec.ValidationError) as rec:
            proto.decode(s2, type=typ[int])
        assert "Expected `int`, got `str`" in str(rec.value)
        loc = "$[1]" if array_like else "$.x"
        assert loc in str(rec.value)

    def test_unbound_typevars_use_bound_if_set(self, proto):
        """An unparametrized ``TypeVar`` with ``bound=`` is validated against its bound."""
        # Local T intentionally shadows the module-level TypeVar.
        T = TypeVar("T", bound=Union[int, str])

        dec = proto.Decoder(List[T])
        sol = [1, "two", 3, "four"]
        msg = proto.encode(sol)
        assert dec.decode(msg) == sol

        # An element outside the bound is rejected, reported at its index.
        bad = proto.encode([1, {}])
        with pytest.raises(
            msgspec.ValidationError,
            match=r"Expected `int \| str`, got `object` - at `\$\[1\]`",
        ):
            dec.decode(bad)

    def test_unbound_typevars_with_constraints_unsupported(self, proto):
        """An unparametrized ``TypeVar`` declared with constraints raises ``TypeError``."""
        # Local T intentionally shadows the module-level TypeVar.
        T = TypeVar("T", int, str)
        with pytest.raises(TypeError) as rec:
            proto.Decoder(List[T])

        assert "Unbound TypeVar `~T` has constraints" in str(rec.value)


class TestStructOmitDefaults:
def test_omit_defaults(self, proto):
class Test(msgspec.Struct, omit_defaults=True):
Expand Down
Loading

0 comments on commit e0d1001

Please sign in to comment.