This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
pieces for multitask learning #2369
Merged
joelgrus merged 6 commits into allenai:master from joelgrus:interleaving-dataset-reader on Jan 17, 2019
Changes from all commits (6 commits):
2510fa1  pieces for multitask learning (joelgrus)
ac4bfeb  mypy (joelgrus)
0bd7e61  add docs (joelgrus)
82fc479  fix docs (joelgrus)
fa7d86a  dashes to underscores (joelgrus)
85e3f46  Merge branch 'master' into interleaving-dataset-reader (joelgrus)
allennlp/data/dataset_readers/interleaving_dataset_reader.py (96 additions & 0 deletions)
@@ -0,0 +1,96 @@
from typing import Dict, Iterable
import json

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import MetadataField
from allennlp.data.instance import Instance

_VALID_SCHEMES = {"round_robin", "all_at_once"}


@DatasetReader.register("interleaving")
class InterleavingDatasetReader(DatasetReader):
    """
    A ``DatasetReader`` that wraps multiple other dataset readers,
    and interleaves their instances, adding a ``MetadataField`` to
    indicate the provenance of each instance.

    Unlike most of our other dataset readers, here the ``file_path`` passed into
    ``read()`` should be a JSON-serialized dictionary with one file_path
    per wrapped dataset reader (and with corresponding keys).

    Parameters
    ----------
    readers : ``Dict[str, DatasetReader]``
        The dataset readers to wrap. The keys of this dictionary will be used
        as the values in the MetadataField indicating provenance.
    dataset_field_name : str, optional (default = "dataset")
        The name of the MetadataField indicating which dataset an instance came from.
    scheme : str, optional (default = "round_robin")
        Indicates how to interleave instances. Currently the two options are "round_robin",
        which repeatedly cycles through the datasets grabbing one instance from each;
        and "all_at_once", which yields all the instances from the first dataset,
        then all the instances from the second dataset, and so on. You could imagine also
        implementing some sort of over- or under-sampling, although that hasn't been done.
    lazy : bool, optional (default = False)
        If this is true, ``instances()`` will return an object whose ``__iter__`` method
        reloads the dataset each time it's called. Otherwise, ``instances()`` returns a list.
    """
    def __init__(self,
                 readers: Dict[str, DatasetReader],
                 dataset_field_name: str = "dataset",
                 scheme: str = "round_robin",
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self._readers = readers
        self._dataset_field_name = dataset_field_name

        if scheme not in _VALID_SCHEMES:
            raise ConfigurationError(f"invalid scheme: {scheme}")
        self._scheme = scheme

    def _read_round_robin(self, datasets: Dict[str, Iterable[Instance]]) -> Iterable[Instance]:
        remaining = set(datasets)
        dataset_iterators = {key: iter(dataset) for key, dataset in datasets.items()}

        while remaining:
            for key, dataset in dataset_iterators.items():
                if key in remaining:
                    try:
                        instance = next(dataset)
                        instance.fields[self._dataset_field_name] = MetadataField(key)
                        yield instance
                    except StopIteration:
                        remaining.remove(key)

    def _read_all_at_once(self, datasets: Dict[str, Iterable[Instance]]) -> Iterable[Instance]:
        for key, dataset in datasets.items():
            for instance in dataset:
                instance.fields[self._dataset_field_name] = MetadataField(key)
                yield instance


    def _read(self, file_path: str) -> Iterable[Instance]:
        try:
            file_paths = json.loads(file_path)
        except json.JSONDecodeError:
            raise ConfigurationError("the file_path for the InterleavingDatasetReader "
                                     "needs to be a JSON-serialized dictionary {reader_name -> file_path}")

        if file_paths.keys() != self._readers.keys():
            raise ConfigurationError("mismatched keys")

        # Load datasets
        datasets = {key: reader.read(file_paths[key]) for key, reader in self._readers.items()}

        if self._scheme == "round_robin":
            yield from self._read_round_robin(datasets)
        elif self._scheme == "all_at_once":
            yield from self._read_all_at_once(datasets)
        else:
            raise RuntimeError("impossible to get here")

    def text_to_instance(self) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        raise RuntimeError("text_to_instance doesn't make sense here")
allennlp/data/iterators/homogeneous_batch_iterator.py (74 additions & 0 deletions)
@@ -0,0 +1,74 @@
from typing import Iterable, Dict, List
import random
from collections import defaultdict

from allennlp.common.util import lazy_groups_of
from allennlp.data.dataset import Batch
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator

@DataIterator.register("homogeneous_batch")
class HomogeneousBatchIterator(DataIterator):
    """
    This iterator takes a dataset of potentially heterogeneous instances
    and yields back homogeneous batches. It assumes that each instance has
    some ``MetadataField`` indicating what "type" of instance it is
    and bases its notion of homogeneity on that (and, in particular, not on
    inspecting the "field signature" of the instance.)

    Parameters
    ----------
    batch_size : ``int``, optional, (default = 32)
        The size of each batch of instances yielded when calling the iterator.
    instances_per_epoch : ``int``, optional, (default = None)
        If specified, each epoch will consist of precisely this many instances.
        If not specified, each epoch will consist of a single pass through the dataset.
    max_instances_in_memory : ``int``, optional, (default = None)
        If specified, the iterator will load this many instances at a time into an
        in-memory list and then produce batches from one such list at a time. This
        could be useful if your instances are read lazily from disk.
    cache_instances : ``bool``, optional, (default = False)
        If true, the iterator will cache the tensorized instances in memory.
        If false, it will do the tensorization anew each iteration.
    track_epoch : ``bool``, optional, (default = False)
        If true, each instance will get a ``MetadataField`` containing the epoch number.
    partition_key : ``str``, optional, (default = "dataset")
        The key of the ``MetadataField`` indicating what "type" of instance this is.
    """
    def __init__(self,
                 batch_size: int = 32,
                 instances_per_epoch: int = None,
                 max_instances_in_memory: int = None,
                 cache_instances: bool = False,
                 track_epoch: bool = False,
                 partition_key: str = "dataset") -> None:
        super().__init__(batch_size, instances_per_epoch, max_instances_in_memory,
                         cache_instances, track_epoch)
        self._partition_key = partition_key

    def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
        # First break the dataset into memory-sized lists:
        for instance_list in self._memory_sized_lists(instances):
            if shuffle:
                random.shuffle(instance_list)

            # Divvy up the instances based on their value of the "partition_key" field.
            hoppers: Dict[str, List[Instance]] = defaultdict(list)
            for instance in instance_list:
                partition = instance.fields[self._partition_key].metadata  # type: ignore
                hoppers[partition].append(instance)

            # Get a `lazy_groups_of` iterator over each set of homogeneous instances.
            batches = {key: lazy_groups_of(iter(hopper), self._batch_size) for key, hopper in hoppers.items()}

            remaining = set(batches)

            # Yield batches in a round-robin fashion until none are left.
            while remaining:
                for key, lazy_batches in batches.items():
                    if key in remaining:
                        try:
                            batch = next(lazy_batches)
                            yield Batch(batch)
                        except StopIteration:
                            remaining.remove(key)
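The heart of _create_batches is just "bucket instances by their partition key, then round-robin over the buckets in fixed-size chunks". Here is a tiny self-contained sketch of that pattern on plain strings, with no AllenNLP dependencies, in case the hoppers/lazy_groups_of combination is hard to follow (islice stands in for what lazy_groups_of provides):

from collections import defaultdict
from itertools import islice

def homogeneous_batches(items, key_fn, batch_size):
    # Bucket ("hopper") the items by key, preserving order within each bucket.
    hoppers = defaultdict(list)
    for item in items:
        hoppers[key_fn(item)].append(item)

    iterators = {key: iter(hopper) for key, hopper in hoppers.items()}
    remaining = set(iterators)

    # Round-robin: take one fixed-size chunk from each key in turn until all are exhausted.
    while remaining:
        for key, iterator in iterators.items():
            if key in remaining:
                batch = list(islice(iterator, batch_size))
                if batch:
                    yield batch
                else:
                    remaining.remove(key)

mixed = ["a1", "b1", "a2", "b2", "a3"]
print(list(homogeneous_batches(mixed, key_fn=lambda s: s[0], batch_size=2)))
# [['a1', 'a2'], ['b1', 'b2'], ['a3']]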
allennlp/tests/data/dataset_readers/interleaving_dataset_reader_test.py (80 additions & 0 deletions)
@@ -0,0 +1,80 @@
from typing import Iterable

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.dataset_readers import DatasetReader, InterleavingDatasetReader
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WordTokenizer


class PlainTextReader(DatasetReader):
    def __init__(self):
        super().__init__()
        self._token_indexers = {"tokens": SingleIdTokenIndexer()}
        self._tokenizer = WordTokenizer()

    def _read(self, file_path: str) -> Iterable[Instance]:
        with open(file_path) as input_file:
            for line in input_file:
                yield self.text_to_instance(line)

    def text_to_instance(self, line: str) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        tokens = self._tokenizer.tokenize(line)
        return Instance({"line": TextField(tokens, self._token_indexers)})


class TestInterleavingDatasetReader(AllenNlpTestCase):
    def test_round_robin(self):
        readers = {
                "a": PlainTextReader(),
                "b": PlainTextReader(),
                "c": PlainTextReader()
        }

        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"

        file_path = f"""{{
            "a": "{data_dir / 'babi.txt'}",
            "b": "{data_dir / 'conll2000.txt'}",
            "c": "{data_dir / 'conll2003.txt'}"
        }}"""

        instances = list(reader.read(file_path))
        first_three_keys = {instance.fields["dataset"].metadata for instance in instances[:3]}
        assert first_three_keys == {"a", "b", "c"}

        next_three_keys = {instance.fields["dataset"].metadata for instance in instances[3:6]}
        assert next_three_keys == {"a", "b", "c"}

    def test_all_at_once(self):
        readers = {
                "f": PlainTextReader(),
                "g": PlainTextReader(),
                "h": PlainTextReader()
        }

        reader = InterleavingDatasetReader(readers, dataset_field_name="source", scheme="all_at_once")
        data_dir = self.FIXTURES_ROOT / "data"

        file_path = f"""{{
            "f": "{data_dir / 'babi.txt'}",
            "g": "{data_dir / 'conll2000.txt'}",
            "h": "{data_dir / 'conll2003.txt'}"
        }}"""

        buckets = []
        last_source = None

        # Fill up a bucket until the source changes, then start a new one
        for instance in reader.read(file_path):
            source = instance.fields["source"].metadata
            if source != last_source:
                buckets.append([])
                last_source = source
            buckets[-1].append(instance)

        # should be in 3 buckets
        assert len(buckets) == 3
allennlp/tests/data/iterators/homogeneous_iterator_test.py (45 additions & 0 deletions)
@@ -0,0 +1,45 @@
from collections import Counter

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.dataset_readers import InterleavingDatasetReader
from allennlp.data.iterators import HomogeneousBatchIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.tests.data.dataset_readers.interleaving_dataset_reader_test import PlainTextReader


class TestHomogeneousBatchIterator(AllenNlpTestCase):
    def test_batches(self):
        readers = {
                "a": PlainTextReader(),
                "b": PlainTextReader(),
                "c": PlainTextReader()
        }

        reader = InterleavingDatasetReader(readers)
        data_dir = self.FIXTURES_ROOT / "data"

        file_path = f"""{{
            "a": "{data_dir / 'babi.txt'}",
            "b": "{data_dir / 'conll2000.txt'}",
            "c": "{data_dir / 'conll2003.txt'}"
        }}"""

        instances = list(reader.read(file_path))
        vocab = Vocabulary.from_instances(instances)

        actual_instance_type_counts = Counter(instance.fields["dataset"].metadata
                                              for instance in instances)

        iterator = HomogeneousBatchIterator(batch_size=3)
        iterator.index_with(vocab)

        observed_instance_type_counts = Counter()

        for batch in iterator(instances, num_epochs=1, shuffle=True):
            # batch should be homogeneous
            instance_types = set(batch["dataset"])
            assert len(instance_types) == 1

            observed_instance_type_counts.update(batch["dataset"])

        assert observed_instance_type_counts == actual_instance_type_counts
doc/api/allennlp.data.dataset_readers.interleaving_dataset_reader.rst (7 additions & 0 deletions)
@@ -0,0 +1,7 @@
allennlp.data.dataset_readers.interleaving_dataset_reader
=========================================================

.. automodule:: allennlp.data.dataset_readers.interleaving_dataset_reader
   :members:
   :undoc-members:
   :show-inheritance:
Review comment: This is a bit gross. I wonder if it's better to allow setting an "origin" attribute on an Instance or something. Maybe not for this PR.
Reply: It needs to make it into the model, so if you did that you'd have to change all the "batch to tensor" logic to account for it, and then make sure it doesn't somehow collide with other tensors, and so on.
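For readers wondering what "make it into the model" means concretely: a MetadataField shows up in the batch as a plain list of strings (that's what batch["dataset"] is in the iterator test above), and since HomogeneousBatchIterator guarantees every instance in a batch shares the same value, a multitask model can dispatch on it. A rough sketch of that pattern, with the encoded input and the two task heads entirely made up for illustration:

from typing import Dict, List

import torch

class MultiTaskHeadSwitcher(torch.nn.Module):
    """Toy module: a shared representation routed to a per-dataset output head."""
    def __init__(self, hidden_dim: int, output_dims: Dict[str, int]) -> None:
        super().__init__()
        self._heads = torch.nn.ModuleDict({
            name: torch.nn.Linear(hidden_dim, dim) for name, dim in output_dims.items()
        })

    def forward(self, encoded: torch.Tensor, dataset: List[str]) -> torch.Tensor:
        # Homogeneous batches mean every entry of `dataset` is the same key.
        assert len(set(dataset)) == 1
        return self._heads[dataset[0]](encoded)

switcher = MultiTaskHeadSwitcher(hidden_dim=8, output_dims={"ner": 5, "sentiment": 2})
logits = switcher(torch.randn(3, 8), dataset=["ner", "ner", "ner"])
print(logits.shape)  # torch.Size([3, 5])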