Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type hints pickling in python 3.6 #818

Merged
merged 1 commit into from
Nov 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion src/datasets/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import functools
import itertools
import os
import pickle
import sys
import types
from io import BytesIO as StringIO
from multiprocessing import Pool, RLock
from shutil import disk_usage
from types import CodeType
from typing import Optional
from typing import Callable, ClassVar, Generic, Optional, Tuple, Union

import dill
import numpy as np
Expand All @@ -36,6 +38,13 @@
from .logging import INFO, WARNING, get_logger, get_verbosity, set_verbosity_warning


try: # pragma: no branch
import typing_extensions as _typing_extensions
from typing_extensions import Final, Literal
except ImportError:
_typing_extensions = Literal = Final = None


# NOTE: When used on an instance method, the cache is shared across all
# instances and IS NOT per-instance.
# See
Expand Down Expand Up @@ -333,6 +342,19 @@ class Pickler(dill.Pickler):

dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())

def save_global(self, obj, name=None):
if sys.version_info[:2] < (3, 7) and _CloudPickleTypeHintFix._is_parametrized_type_hint(
obj
): # noqa # pragma: no branch
# Parametrized typing constructs in Python < 3.7 are not compatible
# with type checks and ``isinstance`` semantics. For this reason,
# it is easier to detect them using a duck-typing-based check
# (``_is_parametrized_type_hint``) than to populate the Pickler's
# dispatch with type-specific savers.
_CloudPickleTypeHintFix._save_parametrized_type_hint(self, obj)
else:
dill.Pickler.save_global(self, obj, name=name)


def dump(obj, file):
"""pickle an object to a file"""
Expand Down Expand Up @@ -376,6 +398,47 @@ def proxy(func):
return proxy


class _CloudPickleTypeHintFix:
"""
Type hints can't be properly pickled in python < 3.7
CloudPickle provided a way to make it work in older versions.
This class provide utilities to fix pickling of type hints in older versions.
from /~https://github.com/cloudpipe/cloudpickle/pull/318/files
"""

def _is_parametrized_type_hint(obj):
# This is very cheap but might generate false positives.
origin = getattr(obj, "__origin__", None) # typing Constructs
values = getattr(obj, "__values__", None) # typing_extensions.Literal
type_ = getattr(obj, "__type__", None) # typing_extensions.Final
return origin is not None or values is not None or type_ is not None

def _create_parametrized_type_hint(origin, args):
return origin[args]

def _save_parametrized_type_hint(pickler, obj):
# The distorted type check sematic for typing construct becomes:
# ``type(obj) is type(TypeHint)``, which means "obj is a
# parametrized TypeHint"
if type(obj) is type(Literal): # pragma: no branch
initargs = (Literal, obj.__values__)
elif type(obj) is type(Final): # pragma: no branch
initargs = (Final, obj.__type__)
elif type(obj) is type(ClassVar):
initargs = (ClassVar, obj.__type__)
elif type(obj) in [type(Union), type(Tuple), type(Generic)]:
initargs = (obj.__origin__, obj.__args__)
elif type(obj) is type(Callable):
args = obj.__args__
if args[0] is Ellipsis:
initargs = (obj.__origin__, args)
else:
initargs = (obj.__origin__, (list(args[:-1]), args[-1]))
else: # pragma: no cover
raise pickle.PicklingError("Datasets pickle Error: Unknown type {}".format(type(obj)))
pickler.save_reduce(_CloudPickleTypeHintFix._create_parametrized_type_hint, initargs, obj=obj)


@pklregister(CodeType)
def _save_code(pickler, obj):
"""
Expand Down
15 changes: 15 additions & 0 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,18 @@ def func():
hash3 = md5(datasets.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)


class TypeHintDumpTest(TestCase):
def test_dump_type_hint(self):
from typing import Union

t1 = Union[str, None] # this type is not picklable in python 3.6
# let's check that we can pickle it anyway using our pickler, even in 3.6
hash1 = md5(datasets.utils.dumps(t1)).hexdigest()
t2 = Union[str] # this type is picklable in python 3.6
hash2 = md5(datasets.utils.dumps(t2)).hexdigest()
t3 = Union[str, None]
hash3 = md5(datasets.utils.dumps(t3)).hexdigest()
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)