Skip to content

Commit

Permalink
Support Message.from_dict() as a class and an instance method (#476)
Browse files Browse the repository at this point in the history
* Make Message.from_dict() a class method

Signed-off-by: Marek Pikuła <marek.pikula@embevity.com>

* Sync 1/2 of review comments

* Sync other half

* Update .pre-commit-config.yaml

* Update __init__.py

* Update utils.py

* Update src/betterproto/__init__.py

* Update .pre-commit-config.yaml

* Update __init__.py

* Update utils.py

* Fix CI again

* Fix failing formatting

---------

Signed-off-by: Marek Pikuła <marek.pikula@embevity.com>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
  • Loading branch information
MarekPikula and Gobot1234 authored Oct 25, 2023
1 parent 02aa4e8 commit d9b7608
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 76 deletions.
184 changes: 108 additions & 76 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dataclasses
import enum as builtin_enum
import json
Expand All @@ -22,8 +24,8 @@
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
ClassVar,
Dict,
Generator,
Iterable,
Expand All @@ -37,6 +39,7 @@
)

from dateutil.parser import isoparse
from typing_extensions import Self

from ._types import T
from ._version import __version__
Expand All @@ -47,6 +50,10 @@
)
from .enum import Enum as Enum
from .grpc.grpclib_client import ServiceStub as ServiceStub
from .utils import (
classproperty,
hybridmethod,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -729,6 +736,7 @@ class Message(ABC):
_serialized_on_wire: bool
_unknown_fields: bytes
_group_current: Dict[str, str]
_betterproto_meta: ClassVar[ProtoClassMetadata]

def __post_init__(self) -> None:
# Keep track of whether every field was default
Expand Down Expand Up @@ -882,18 +890,18 @@ def __copy__(self: T, _: Any = {}) -> T:
kwargs[name] = value
return self.__class__(**kwargs) # type: ignore

@property
def _betterproto(self) -> ProtoClassMetadata:
@classproperty
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
"""
Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment,
but that won't affect the correctness.
"""
meta = getattr(self.__class__, "_betterproto_meta", None)
if not meta:
meta = ProtoClassMetadata(self.__class__)
self.__class__._betterproto_meta = meta # type: ignore
return meta
try:
return cls._betterproto_meta
except AttributeError:
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
return meta

def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
"""
Expand Down Expand Up @@ -1512,10 +1520,74 @@ def to_dict(
output[cased_name] = value
return output

def from_dict(self: T, value: Mapping[str, Any]) -> T:
@classmethod
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
init_kwargs: Dict[str, Any] = {}
for key, value in mapping.items():
field_name = safe_snake_case(key)
try:
meta = cls._betterproto.meta_by_field_name[field_name]
except KeyError:
continue
if value is None:
continue

if meta.proto_type == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[field_name]
if sub_cls == datetime:
value = (
[isoparse(item) for item in value]
if isinstance(value, list)
else isoparse(value)
)
elif sub_cls == timedelta:
value = (
[timedelta(seconds=float(item[:-1])) for item in value]
if isinstance(value, list)
else timedelta(seconds=float(value[:-1]))
)
elif not meta.wraps:
value = (
[sub_cls.from_dict(item) for item in value]
if isinstance(value, list)
else sub_cls.from_dict(value)
)
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
else:
if meta.proto_type in INT_64_TYPES:
value = (
[int(n) for n in value]
if isinstance(value, list)
else int(value)
)
elif meta.proto_type == TYPE_BYTES:
value = (
[b64decode(n) for n in value]
if isinstance(value, list)
else b64decode(value)
)
elif meta.proto_type == TYPE_ENUM:
enum_cls = cls._betterproto.cls_by_field[field_name]
if isinstance(value, list):
value = [enum_cls.from_string(e) for e in value]
elif isinstance(value, str):
value = enum_cls.from_string(value)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
value = (
[_parse_float(n) for n in value]
if isinstance(value, list)
else _parse_float(value)
)

init_kwargs[field_name] = value
return init_kwargs

@hybridmethod
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
"""
Parse the key/value pairs into the current message instance. This returns the
instance itself and is therefore assignable and chainable.
Parse the key/value pairs into the a new message instance.
Parameters
-----------
Expand All @@ -1527,72 +1599,29 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T:
:class:`Message`
The initialized message.
"""
self = cls(**cls._from_dict_init(value))
self._serialized_on_wire = True
for key in value:
field_name = safe_snake_case(key)
meta = self._betterproto.meta_by_field_name.get(field_name)
if not meta:
continue
return self

if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
if cls == datetime:
v = [isoparse(item) for item in value[key]]
elif cls == timedelta:
v = [
timedelta(seconds=float(item[:-1]))
for item in value[key]
]
else:
v = [cls().from_dict(item) for item in value[key]]
elif cls == datetime:
v = isoparse(value[key])
setattr(self, field_name, v)
elif cls == timedelta:
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field_name, v)
elif meta.wraps:
setattr(self, field_name, value[key])
elif v is None:
setattr(self, field_name, cls().from_dict(value[key]))
else:
# NOTE: `from_dict` mutates the underlying message, so no
# assignment here is necessary.
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if isinstance(value[key], list):
v = [_parse_float(n) for n in value[key]]
else:
v = _parse_float(value[key])
@from_dict.instancemethod
def from_dict(self, value: Mapping[str, Any]) -> Self:
"""
Parse the key/value pairs into the current message instance. This returns the
instance itself and is therefore assignable and chainable.
if v is not None:
setattr(self, field_name, v)
Parameters
-----------
value: Dict[:class:`str`, Any]
The dictionary to parse from.
Returns
--------
:class:`Message`
The initialized message.
"""
self._serialized_on_wire = True
for field, value in self._from_dict_init(value).items():
setattr(self, field, value)
return self

def to_json(
Expand Down Expand Up @@ -1809,8 +1838,8 @@ def is_set(self, name: str) -> bool:

@classmethod
def _validate_field_groups(cls, values):
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
group_to_one_ofs = cls._betterproto.oneof_field_by_group
field_name_to_meta = cls._betterproto.meta_by_field_name

for group, field_set in group_to_one_ofs.items():
if len(field_set) == 1:
Expand All @@ -1837,6 +1866,9 @@ def _validate_field_groups(cls, values):
return values


Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)


def serialized_on_wire(message: Message) -> bool:
"""
If this message was or should be serialized on the wire. This can be used to detect
Expand Down
56 changes: 56 additions & 0 deletions src/betterproto/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from typing import (
Any,
Callable,
Generic,
Optional,
Type,
TypeVar,
)

from typing_extensions import (
Concatenate,
ParamSpec,
Self,
)


SelfT = TypeVar("SelfT")
P = ParamSpec("P")
HybridT = TypeVar("HybridT", covariant=True)


class hybridmethod(Generic[SelfT, P, HybridT]):
def __init__(
self,
func: Callable[
Concatenate[type[SelfT], P], HybridT
], # Must be the classmethod version
):
self.cls_func = func
self.__doc__ = func.__doc__

def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
self.instance_func = func
return self

def __get__(
self, instance: Optional[SelfT], owner: Type[SelfT]
) -> Callable[P, HybridT]:
if instance is None or self.instance_func is None:
# either bound to the class, or no instance method available
return self.cls_func.__get__(owner, None)
return self.instance_func.__get__(instance, owner)


T_co = TypeVar("T_co")
TT_co = TypeVar("TT_co", bound="type[Any]")


class classproperty(Generic[TT_co, T_co]):
def __init__(self, func: Callable[[TT_co], T_co]):
self.__func__ = func

def __get__(self, instance: Any, type: TT_co) -> T_co:
return self.__func__(type)

0 comments on commit d9b7608

Please sign in to comment.