Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Decouple generator and filepath to fix pickle bug (#4026)
Browse files Browse the repository at this point in the history
* Decouple generator and filepath to fix pickle bug

* Update allennlp/data/dataset_readers/dataset_reader.py

Co-authored-by: Matt Gardner <mattg@allenai.org>
  • Loading branch information
JohnGiorgi and matt-gardner authored Apr 5, 2020
1 parent 2ca9b5e commit d66db44
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions allennlp/data/dataset_readers/dataset_reader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -41,14 +41,16 @@ class _LazyInstances(IterableDataset):

def __init__(
self,
- instance_generator: Callable[[], Iterable[Instance]],
+ instance_generator: Callable[[str], Iterable[Instance]],
+ file_path: str,
cache_file: str = None,
deserialize: Callable[[str], Instance] = None,
serialize: Callable[[Instance], str] = None,
vocab: Vocabulary = None,
) -> None:
super().__init__()
self.instance_generator = instance_generator
+ self.file_path = file_path
self.cache_file = cache_file
self.deserialize = deserialize
self.serialize = serialize
Expand All @@ -67,15 +69,15 @@ def __iter__(self) -> Iterator[Instance]:
# Case 2: Need to cache instances
elif self.cache_file is not None:
with open(self.cache_file, "w") as data_file:
- for instance in self.instance_generator():
+ for instance in self.instance_generator(self.file_path):
data_file.write(self.serialize(instance))
data_file.write("\n")
if self.vocab is not None:
instance.index_fields(self.vocab)
yield instance
# Case 3: No cache
else:
- instances = self.instance_generator()
+ instances = self.instance_generator(self.file_path)
if isinstance(instances, list):
raise ConfigurationError(
"For a lazy dataset reader, _read() must return a generator"
Expand Down Expand Up @@ -176,7 +178,8 @@ def read(self, file_path: str) -> Dataset:

if lazy:
instances: Iterable[Instance] = _LazyInstances(
- lambda: self._read(file_path),
+ self._read,
+ file_path,
cache_file,
self.deserialize_instance,
self.serialize_instance,
Expand Down

0 comments on commit d66db44

Please sign in to comment.