From d66db44dfe8072f0b8bac2770968d996664e683b Mon Sep 17 00:00:00 2001
From: John Giorgi
Date: Sun, 5 Apr 2020 11:31:11 -0400
Subject: [PATCH] Decouple generator and filepath to fix pickle bug (#4026)

* Decouple generator and filepath to fix pickle bug

* Update allennlp/data/dataset_readers/dataset_reader.py

Co-authored-by: Matt Gardner
---
 allennlp/data/dataset_readers/dataset_reader.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/allennlp/data/dataset_readers/dataset_reader.py b/allennlp/data/dataset_readers/dataset_reader.py
index 17adddeab16..22f2265ea98 100644
--- a/allennlp/data/dataset_readers/dataset_reader.py
+++ b/allennlp/data/dataset_readers/dataset_reader.py
@@ -41,7 +41,8 @@ class _LazyInstances(IterableDataset):
 
     def __init__(
         self,
-        instance_generator: Callable[[], Iterable[Instance]],
+        instance_generator: Callable[[str], Iterable[Instance]],
+        file_path: str,
         cache_file: str = None,
         deserialize: Callable[[str], Instance] = None,
         serialize: Callable[[Instance], str] = None,
@@ -49,6 +50,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.instance_generator = instance_generator
+        self.file_path = file_path
         self.cache_file = cache_file
         self.deserialize = deserialize
         self.serialize = serialize
@@ -67,7 +69,7 @@ def __iter__(self) -> Iterator[Instance]:
         # Case 2: Need to cache instances
         elif self.cache_file is not None:
             with open(self.cache_file, "w") as data_file:
-                for instance in self.instance_generator():
+                for instance in self.instance_generator(self.file_path):
                     data_file.write(self.serialize(instance))
                     data_file.write("\n")
                     if self.vocab is not None:
@@ -75,7 +77,7 @@ def __iter__(self) -> Iterator[Instance]:
                     yield instance
         # Case 3: No cache
         else:
-            instances = self.instance_generator()
+            instances = self.instance_generator(self.file_path)
             if isinstance(instances, list):
                 raise ConfigurationError(
                     "For a lazy dataset reader, _read() must return a generator"
@@ -176,7 +178,8 @@ def read(self, file_path: str) -> Dataset:
 
         if lazy:
             instances: Iterable[Instance] = _LazyInstances(
-                lambda: self._read(file_path),
+                self._read,
+                file_path,
                 cache_file,
                 self.deserialize_instance,
                 self.serialize_instance,