From d9b7608980aa04672f5430e8fe90f08d947c0423 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Piku=C5=82a?= Date: Wed, 25 Oct 2023 23:20:23 +0200 Subject: [PATCH] Support `Message.from_dict()` as a class and an instance method (#476) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Make Message.from_dict() a class method Signed-off-by: Marek Pikuła * Sync 1/2 of review comments * Sync other half * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Update src/betterproto/__init__.py * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Fix CI again * Fix failing formatting --------- Signed-off-by: Marek Pikuła Co-authored-by: James Hilton-Balfe --- src/betterproto/__init__.py | 184 +++++++++++++++++++++--------------- src/betterproto/utils.py | 56 +++++++++++ 2 files changed, 164 insertions(+), 76 deletions(-) create mode 100644 src/betterproto/utils.py diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index f52edaa1f..b2a63d836 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import enum as builtin_enum import json @@ -22,8 +24,8 @@ from typing import ( TYPE_CHECKING, Any, - BinaryIO, Callable, + ClassVar, Dict, Generator, Iterable, @@ -37,6 +39,7 @@ ) from dateutil.parser import isoparse +from typing_extensions import Self from ._types import T from ._version import __version__ @@ -47,6 +50,10 @@ ) from .enum import Enum as Enum from .grpc.grpclib_client import ServiceStub as ServiceStub +from .utils import ( + classproperty, + hybridmethod, +) if TYPE_CHECKING: @@ -729,6 +736,7 @@ class Message(ABC): _serialized_on_wire: bool _unknown_fields: bytes _group_current: Dict[str, str] + _betterproto_meta: ClassVar[ProtoClassMetadata] def __post_init__(self) -> None: # Keep track of whether every field was default @@ -882,18 +890,18 @@ def __copy__(self: T, _: Any = 
{}) -> T: kwargs[name] = value return self.__class__(**kwargs) # type: ignore - @property - def _betterproto(self) -> ProtoClassMetadata: + @classproperty + def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore """ Lazy initialize metadata for each protobuf class. It may be initialized multiple times in a multi-threaded environment, but that won't affect the correctness. """ - meta = getattr(self.__class__, "_betterproto_meta", None) - if not meta: - meta = ProtoClassMetadata(self.__class__) - self.__class__._betterproto_meta = meta # type: ignore - return meta + try: + return cls._betterproto_meta + except AttributeError: + cls._betterproto_meta = meta = ProtoClassMetadata(cls) + return meta def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None: """ @@ -1512,10 +1520,74 @@ def to_dict( output[cased_name] = value return output - def from_dict(self: T, value: Mapping[str, Any]) -> T: + @classmethod + def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]: + init_kwargs: Dict[str, Any] = {} + for key, value in mapping.items(): + field_name = safe_snake_case(key) + try: + meta = cls._betterproto.meta_by_field_name[field_name] + except KeyError: + continue + if value is None: + continue + + if meta.proto_type == TYPE_MESSAGE: + sub_cls = cls._betterproto.cls_by_field[field_name] + if sub_cls == datetime: + value = ( + [isoparse(item) for item in value] + if isinstance(value, list) + else isoparse(value) + ) + elif sub_cls == timedelta: + value = ( + [timedelta(seconds=float(item[:-1])) for item in value] + if isinstance(value, list) + else timedelta(seconds=float(value[:-1])) + ) + elif not meta.wraps: + value = ( + [sub_cls.from_dict(item) for item in value] + if isinstance(value, list) + else sub_cls.from_dict(value) + ) + elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: + sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"] + value = {k: sub_cls.from_dict(v) for k, v in value.items()} + 
else: + if meta.proto_type in INT_64_TYPES: + value = ( + [int(n) for n in value] + if isinstance(value, list) + else int(value) + ) + elif meta.proto_type == TYPE_BYTES: + value = ( + [b64decode(n) for n in value] + if isinstance(value, list) + else b64decode(value) + ) + elif meta.proto_type == TYPE_ENUM: + enum_cls = cls._betterproto.cls_by_field[field_name] + if isinstance(value, list): + value = [enum_cls.from_string(e) for e in value] + elif isinstance(value, str): + value = enum_cls.from_string(value) + elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): + value = ( + [_parse_float(n) for n in value] + if isinstance(value, list) + else _parse_float(value) + ) + + init_kwargs[field_name] = value + return init_kwargs + + @hybridmethod + def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore """ - Parse the key/value pairs into the current message instance. This returns the - instance itself and is therefore assignable and chainable. + Parse the key/value pairs into a new message instance. Parameters ----------- @@ -1527,72 +1599,29 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T: :class:`Message` The initialized message. 
""" + self = cls(**cls._from_dict_init(value)) self._serialized_on_wire = True - for key in value: - field_name = safe_snake_case(key) - meta = self._betterproto.meta_by_field_name.get(field_name) - if not meta: - continue + return self - if value[key] is not None: - if meta.proto_type == TYPE_MESSAGE: - v = self._get_field_default(field_name) - cls = self._betterproto.cls_by_field[field_name] - if isinstance(v, list): - if cls == datetime: - v = [isoparse(item) for item in value[key]] - elif cls == timedelta: - v = [ - timedelta(seconds=float(item[:-1])) - for item in value[key] - ] - else: - v = [cls().from_dict(item) for item in value[key]] - elif cls == datetime: - v = isoparse(value[key]) - setattr(self, field_name, v) - elif cls == timedelta: - v = timedelta(seconds=float(value[key][:-1])) - setattr(self, field_name, v) - elif meta.wraps: - setattr(self, field_name, value[key]) - elif v is None: - setattr(self, field_name, cls().from_dict(value[key])) - else: - # NOTE: `from_dict` mutates the underlying message, so no - # assignment here is necessary. 
- v.from_dict(value[key]) - elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: - v = getattr(self, field_name) - cls = self._betterproto.cls_by_field[f"{field_name}.value"] - for k in value[key]: - v[k] = cls().from_dict(value[key][k]) - else: - v = value[key] - if meta.proto_type in INT_64_TYPES: - if isinstance(value[key], list): - v = [int(n) for n in value[key]] - else: - v = int(value[key]) - elif meta.proto_type == TYPE_BYTES: - if isinstance(value[key], list): - v = [b64decode(n) for n in value[key]] - else: - v = b64decode(value[key]) - elif meta.proto_type == TYPE_ENUM: - enum_cls = self._betterproto.cls_by_field[field_name] - if isinstance(v, list): - v = [enum_cls.from_string(e) for e in v] - elif isinstance(v, str): - v = enum_cls.from_string(v) - elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): - if isinstance(value[key], list): - v = [_parse_float(n) for n in value[key]] - else: - v = _parse_float(value[key]) + @from_dict.instancemethod + def from_dict(self, value: Mapping[str, Any]) -> Self: + """ + Parse the key/value pairs into the current message instance. This returns the + instance itself and is therefore assignable and chainable. - if v is not None: - setattr(self, field_name, v) + Parameters + ----------- + value: Dict[:class:`str`, Any] + The dictionary to parse from. + + Returns + -------- + :class:`Message` + The initialized message. 
+ """ + self._serialized_on_wire = True + for field, value in self._from_dict_init(value).items(): + setattr(self, field, value) return self def to_json( @@ -1809,8 +1838,8 @@ def is_set(self, name: str) -> bool: @classmethod def _validate_field_groups(cls, values): - group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore - field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore + group_to_one_ofs = cls._betterproto.oneof_field_by_group + field_name_to_meta = cls._betterproto.meta_by_field_name for group, field_set in group_to_one_ofs.items(): if len(field_set) == 1: @@ -1837,6 +1866,9 @@ def _validate_field_groups(cls, values): return values +Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :) + + def serialized_on_wire(message: Message) -> bool: """ If this message was or should be serialized on the wire. This can be used to detect diff --git a/src/betterproto/utils.py b/src/betterproto/utils.py new file mode 100644 index 000000000..b977fc713 --- /dev/null +++ b/src/betterproto/utils.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import ( + Any, + Callable, + Generic, + Optional, + Type, + TypeVar, +) + +from typing_extensions import ( + Concatenate, + ParamSpec, + Self, +) + + +SelfT = TypeVar("SelfT") +P = ParamSpec("P") +HybridT = TypeVar("HybridT", covariant=True) + + +class hybridmethod(Generic[SelfT, P, HybridT]): + def __init__( + self, + func: Callable[ + Concatenate[type[SelfT], P], HybridT + ], # Must be the classmethod version + ): + self.cls_func = func + self.__doc__ = func.__doc__ + + def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self: + self.instance_func = func + return self + + def __get__( + self, instance: Optional[SelfT], owner: Type[SelfT] + ) -> Callable[P, HybridT]: + if instance is None or self.instance_func is None: + # either bound to the class, or no instance method available + return self.cls_func.__get__(owner, 
None) + return self.instance_func.__get__(instance, owner) + + +T_co = TypeVar("T_co") +TT_co = TypeVar("TT_co", bound="type[Any]") + + +class classproperty(Generic[TT_co, T_co]): + def __init__(self, func: Callable[[TT_co], T_co]): + self.__func__ = func + + def __get__(self, instance: Any, type: TT_co) -> T_co: + return self.__func__(type)