Skip to content

Commit

Permalink
fix type hints pickling in python 3.6
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Nov 9, 2020
1 parent 92acf1e commit 7ec6af5
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
65 changes: 64 additions & 1 deletion src/datasets/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import functools
import itertools
import os
import pickle
import sys
import types
from io import BytesIO as StringIO
from multiprocessing import Pool, RLock
from shutil import disk_usage
from types import CodeType
from typing import Optional
from typing import Callable, ClassVar, Generic, Optional, Tuple, Union

import dill
import numpy as np
Expand All @@ -36,6 +38,13 @@
from .logging import INFO, WARNING, get_logger, get_verbosity, set_verbosity_warning


try: # pragma: no branch
import typing_extensions as _typing_extensions
from typing_extensions import Final, Literal
except ImportError:
_typing_extensions = Literal = Final = None


# NOTE: When used on an instance method, the cache is shared across all
# instances and IS NOT per-instance.
# See
Expand Down Expand Up @@ -333,6 +342,19 @@ class Pickler(dill.Pickler):

dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())

def save_global(self, obj, name=None):
if sys.version_info[:2] < (3, 7) and _CloudPickleTypeHintFix._is_parametrized_type_hint(
obj
): # noqa # pragma: no branch
# Parametrized typing constructs in Python < 3.7 are not compatible
# with type checks and ``isinstance`` semantics. For this reason,
# it is easier to detect them using a duck-typing-based check
# (``_is_parametrized_type_hint``) than to populate the Pickler's
# dispatch with type-specific savers.
_CloudPickleTypeHintFix._save_parametrized_type_hint(self, obj)
else:
dill.Pickler.save_global(self, obj, name=name)


def dump(obj, file):
"""pickle an object to a file"""
Expand Down Expand Up @@ -376,6 +398,47 @@ def proxy(func):
return proxy


class _CloudPickleTypeHintFix:
"""
Type hints can't be properly pickled in python < 3.7
CloudPickle provided a way to make it work in older versions.
This class provide utilities to fix pickling of type hints in older versions.
from /~https://github.com/cloudpipe/cloudpickle/pull/318/files
"""

def _is_parametrized_type_hint(obj):
# This is very cheap but might generate false positives.
origin = getattr(obj, "__origin__", None) # typing Constructs
values = getattr(obj, "__values__", None) # typing_extensions.Literal
type_ = getattr(obj, "__type__", None) # typing_extensions.Final
return origin is not None or values is not None or type_ is not None

def _create_parametrized_type_hint(origin, args):
return origin[args]

def _save_parametrized_type_hint(pickler, obj):
# The distorted type check sematic for typing construct becomes:
# ``type(obj) is type(TypeHint)``, which means "obj is a
# parametrized TypeHint"
if type(obj) is type(Literal): # pragma: no branch
initargs = (Literal, obj.__values__)
elif type(obj) is type(Final): # pragma: no branch
initargs = (Final, obj.__type__)
elif type(obj) is type(ClassVar):
initargs = (ClassVar, obj.__type__)
elif type(obj) in [type(Union), type(Tuple), type(Generic)]:
initargs = (obj.__origin__, obj.__args__)
elif type(obj) is type(Callable):
args = obj.__args__
if args[0] is Ellipsis:
initargs = (obj.__origin__, args)
else:
initargs = (obj.__origin__, (list(args[:-1]), args[-1]))
else: # pragma: no cover
raise pickle.PicklingError("Datasets pickle Error: Unknown type {}".format(type(obj)))
pickler.save_reduce(_CloudPickleTypeHintFix._create_parametrized_type_hint, initargs, obj=obj)


@pklregister(CodeType)
def _save_code(pickler, obj):
"""
Expand Down
15 changes: 15 additions & 0 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,18 @@ def func():
hash3 = md5(datasets.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)


class TypeHintDumpTest(TestCase):
def test_dump_type_hint(self):
from typing import Union

t1 = Union[str, None] # this type is not picklable in python 3.6
# let's check that we can pickle it anyway using our pickler, even in 3.6
hash1 = md5(datasets.utils.dumps(t1)).hexdigest()
t2 = Union[str] # this type is picklable in python 3.6
hash2 = md5(datasets.utils.dumps(t2)).hexdigest()
t3 = Union[str, None]
hash3 = md5(datasets.utils.dumps(t3)).hexdigest()
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)

1 comment on commit 7ec6af5

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==0.17.1

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.012964 / 0.011353 (0.001611) 0.011070 / 0.011008 (0.000061) 0.042586 / 0.038508 (0.004078) 0.025096 / 0.023109 (0.001987) 0.153759 / 0.275898 (-0.122139) 0.179725 / 0.323480 (-0.143754) 0.007609 / 0.007986 (-0.000376) 0.003044 / 0.004328 (-0.001285) 0.005701 / 0.004250 (0.001450) 0.040897 / 0.037052 (0.003845) 0.154687 / 0.258489 (-0.103802) 0.179993 / 0.293841 (-0.113848) 0.115192 / 0.128546 (-0.013354) 0.089182 / 0.075646 (0.013536) 0.369674 / 0.419271 (-0.049597) 0.431605 / 0.043533 (0.388072) 0.158559 / 0.255139 (-0.096580) 0.174433 / 0.283200 (-0.108767) 0.072603 / 0.141683 (-0.069080) 1.424952 / 1.452155 (-0.027203) 1.420686 / 1.492716 (-0.072030)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.030378 / 0.037411 (-0.007034) 0.015262 / 0.014526 (0.000736) 0.070413 / 0.176557 (-0.106144) 0.078814 / 0.737135 (-0.658321) 0.020772 / 0.296338 (-0.275567)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.140577 / 0.215209 (-0.074632) 1.444021 / 2.077655 (-0.633634) 0.907187 / 1.504120 (-0.596933) 0.853988 / 1.541195 (-0.687207) 0.872900 / 1.468490 (-0.595590) 4.773943 / 4.584777 (0.189166) 3.995913 / 3.745712 (0.250200) 5.929908 / 5.269862 (0.660047) 5.056925 / 4.565676 (0.491248) 0.496123 / 0.424275 (0.071848) 0.010750 / 0.007607 (0.003143) 0.165365 / 0.226044 (-0.060680) 1.737241 / 2.268929 (-0.531687) 1.282524 / 55.444624 (-54.162101) 1.216441 / 6.876477 (-5.660035) 1.253479 / 2.142072 (-0.888593) 5.205943 / 4.805227 (0.400716) 3.619288 / 6.500664 (-2.881376) 4.361920 / 0.075469 (4.286451)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 7.585461 / 1.841788 (5.743673) 11.728997 / 8.074308 (3.654688) 9.431122 / 10.191392 (-0.760270) 0.345328 / 0.680424 (-0.335096) 0.206170 / 0.534201 (-0.328031) 0.587908 / 0.579283 (0.008625) 0.426395 / 0.434364 (-0.007969) 0.553510 / 0.540337 (0.013172) 1.259951 / 1.386936 (-0.126986)
PyArrow==1.0
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.014579 / 0.011353 (0.003226) 0.011559 / 0.011008 (0.000551) 0.041603 / 0.038508 (0.003095) 0.026034 / 0.023109 (0.002924) 0.250851 / 0.275898 (-0.025047) 0.287731 / 0.323480 (-0.035749) 0.007695 / 0.007986 (-0.000291) 0.004399 / 0.004328 (0.000071) 0.008338 / 0.004250 (0.004088) 0.044324 / 0.037052 (0.007272) 0.257018 / 0.258489 (-0.001472) 0.288093 / 0.293841 (-0.005748) 0.112700 / 0.128546 (-0.015847) 0.084876 / 0.075646 (0.009230) 0.384546 / 0.419271 (-0.034725) 0.327054 / 0.043533 (0.283521) 0.246824 / 0.255139 (-0.008315) 0.267401 / 0.283200 (-0.015798) 0.082065 / 0.141683 (-0.059618) 1.402861 / 1.452155 (-0.049294) 1.444276 / 1.492716 (-0.048441)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.031462 / 0.037411 (-0.005949) 0.018522 / 0.014526 (0.003996) 0.021857 / 0.176557 (-0.154699) 0.072322 / 0.737135 (-0.664813) 0.022806 / 0.296338 (-0.273533)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.199325 / 0.215209 (-0.015884) 2.034104 / 2.077655 (-0.043551) 1.548881 / 1.504120 (0.044761) 1.496135 / 1.541195 (-0.045059) 1.532151 / 1.468490 (0.063660) 4.758039 / 4.584777 (0.173262) 3.896860 / 3.745712 (0.151148) 5.981793 / 5.269862 (0.711932) 5.103447 / 4.565676 (0.537771) 0.466463 / 0.424275 (0.042188) 0.009112 / 0.007607 (0.001505) 0.225767 / 0.226044 (-0.000277) 2.215483 / 2.268929 (-0.053445) 1.819499 / 55.444624 (-53.625126) 1.772768 / 6.876477 (-5.103709) 1.845216 / 2.142072 (-0.296856) 4.836211 / 4.805227 (0.030984) 3.293162 / 6.500664 (-3.207502) 5.638515 / 0.075469 (5.563046)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 7.770001 / 1.841788 (5.928214) 12.330611 / 8.074308 (4.256303) 10.102264 / 10.191392 (-0.089128) 0.799090 / 0.680424 (0.118666) 0.436089 / 0.534201 (-0.098112) 0.575563 / 0.579283 (-0.003720) 0.429847 / 0.434364 (-0.004517) 0.537631 / 0.540337 (-0.002707) 1.226957 / 1.386936 (-0.159979)

CML watermark

Please sign in to comment.