This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

pieces for multitask learning #2369

Merged 6 commits on Jan 17, 2019
1 change: 1 addition & 0 deletions allennlp/data/dataset_readers/__init__.py
@@ -14,6 +14,7 @@
from allennlp.data.dataset_readers.coreference_resolution import ConllCorefReader, WinobiasReader
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.event2mind import Event2MindDatasetReader
from allennlp.data.dataset_readers.interleaving_dataset_reader import InterleavingDatasetReader
from allennlp.data.dataset_readers.language_modeling import LanguageModelingReader
from allennlp.data.dataset_readers.multiprocess_dataset_reader import MultiprocessDatasetReader
from allennlp.data.dataset_readers.penn_tree_bank import PennTreeBankConstituencySpanDatasetReader
96 changes: 96 additions & 0 deletions allennlp/data/dataset_readers/interleaving_dataset_reader.py
@@ -0,0 +1,96 @@
from typing import Dict, Iterable
import json

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import MetadataField
from allennlp.data.instance import Instance

_VALID_SCHEMES = {"round_robin", "all_at_once"}


@DatasetReader.register("interleaving")
class InterleavingDatasetReader(DatasetReader):
"""
A ``DatasetReader`` that wraps multiple other dataset readers,
and interleaves their instances, adding a ``MetadataField`` to
indicate the provenance of each instance.

Unlike most of our other dataset readers, here the ``file_path`` passed into
``read()`` should be a JSON-serialized dictionary with one file_path
per wrapped dataset reader (and with corresponding keys).

Parameters
----------
readers : ``Dict[str, DatasetReader]``
The dataset readers to wrap. The keys of this dictionary will be used
as the values in the MetadataField indicating provenance.
dataset_field_name : str, optional (default = "dataset")
The name of the MetadataField indicating which dataset an instance came from.
scheme : str, optional (default = "round_robin")
Indicates how to interleave instances. Currently the two options are "round_robin",
which repeatedly cycles through the datasets grabbing one instance from each;
and "all_at_once", which yields all the instances from the first dataset,
then all the instances from the second dataset, and so on. You could also imagine
implementing some sort of over- or under-sampling, although that hasn't been done.
lazy : bool, optional (default = False)
If this is true, ``instances()`` will return an object whose ``__iter__`` method
reloads the dataset each time it's called. Otherwise, ``instances()`` returns a list.
"""
def __init__(self,
readers: Dict[str, DatasetReader],
dataset_field_name: str = "dataset",
scheme: str = "round_robin",
lazy: bool = False) -> None:
super().__init__(lazy)
self._readers = readers
self._dataset_field_name = dataset_field_name

if scheme not in _VALID_SCHEMES:
raise ConfigurationError(f"invalid scheme: {scheme}")
self._scheme = scheme

def _read_round_robin(self, datasets: Dict[str, Iterable[Instance]]) -> Iterable[Instance]:
remaining = set(datasets)
dataset_iterators = {key: iter(dataset) for key, dataset in datasets.items()}

while remaining:
for key, dataset in dataset_iterators.items():
if key in remaining:
try:
instance = next(dataset)
instance.fields[self._dataset_field_name] = MetadataField(key)
yield instance
except StopIteration:
remaining.remove(key)

def _read_all_at_once(self, datasets: Dict[str, Iterable[Instance]]) -> Iterable[Instance]:
for key, dataset in datasets.items():
for instance in dataset:
instance.fields[self._dataset_field_name] = MetadataField(key)
yield instance


def _read(self, file_path: str) -> Iterable[Instance]:
try:
file_paths = json.loads(file_path)
except json.JSONDecodeError:
raise ConfigurationError("the file_path for the InterleavingDatasetReader "
"needs to be a JSON-serialized dictionary {reader_name -> file_path}")

if file_paths.keys() != self._readers.keys():
raise ConfigurationError("mismatched keys")

# Load datasets
datasets = {key: reader.read(file_paths[key]) for key, reader in self._readers.items()}

if self._scheme == "round_robin":
yield from self._read_round_robin(datasets)
elif self._scheme == "all_at_once":
yield from self._read_all_at_once(datasets)
else:
raise RuntimeError("impossible to get here")

def text_to_instance(self) -> Instance: # type: ignore
# pylint: disable=arguments-differ
raise RuntimeError("text_to_instance doesn't make sense here")
1 change: 1 addition & 0 deletions allennlp/data/iterators/__init__.py
@@ -6,4 +6,5 @@
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.data.iterators.basic_iterator import BasicIterator
from allennlp.data.iterators.bucket_iterator import BucketIterator
from allennlp.data.iterators.homogeneous_batch_iterator import HomogeneousBatchIterator
from allennlp.data.iterators.multiprocess_iterator import MultiprocessIterator
74 changes: 74 additions & 0 deletions allennlp/data/iterators/homogeneous_batch_iterator.py
@@ -0,0 +1,74 @@
from typing import Iterable, Dict, List
import random
from collections import defaultdict

from allennlp.common.util import lazy_groups_of
from allennlp.data.dataset import Batch
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator

@DataIterator.register("homogeneous_batch")
class HomogeneousBatchIterator(DataIterator):
"""
This iterator takes a dataset of potentially heterogeneous instances
and yields back homogeneous batches. It assumes that each instance has
some ``MetadataField`` indicating what "type" of instance it is
and bases its notion of homogeneity on that (and, in particular, not on
inspecting the "field signature" of the instance).

Parameters
----------
batch_size : ``int``, optional, (default = 32)
The size of each batch of instances yielded when calling the iterator.
instances_per_epoch : ``int``, optional, (default = None)
If specified, each epoch will consist of precisely this many instances.
If not specified, each epoch will consist of a single pass through the dataset.
max_instances_in_memory : ``int``, optional, (default = None)
If specified, the iterator will load this many instances at a time into an
in-memory list and then produce batches from one such list at a time. This
could be useful if your instances are read lazily from disk.
cache_instances : ``bool``, optional, (default = False)
If true, the iterator will cache the tensorized instances in memory.
If false, it will do the tensorization anew each iteration.
track_epoch : ``bool``, optional, (default = False)
If true, each instance will get a ``MetadataField`` containing the epoch number.
partition_key : ``str``, optional, (default = "dataset")
The key of the ``MetadataField`` indicating what "type" of instance this is.
"""

Reviewer comment (Contributor): This is a bit gross; I wonder if it's better to allow
setting an "origin" attribute on an Instance or something. Maybe not for this PR.

Reply (Contributor Author): It needs to make it into the model, so if you did that you'd
have to change all the "batch to tensor" logic to account for it, and then make sure it
doesn't somehow collide with other tensors, and so on.
def __init__(self,
batch_size: int = 32,
instances_per_epoch: int = None,
max_instances_in_memory: int = None,
cache_instances: bool = False,
track_epoch: bool = False,
partition_key: str = "dataset") -> None:
super().__init__(batch_size, instances_per_epoch, max_instances_in_memory,
cache_instances, track_epoch)
self._partition_key = partition_key

def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
# First break the dataset into memory-sized lists:
for instance_list in self._memory_sized_lists(instances):
if shuffle:
random.shuffle(instance_list)

# Divvy up the instances based on their value of the "partition_key" field.
hoppers: Dict[str, List[Instance]] = defaultdict(list)
for instance in instance_list:
partition = instance.fields[self._partition_key].metadata # type: ignore
hoppers[partition].append(instance)

# Get a `lazy_groups_of` iterator over each set of homogeneous instances.
batches = {key: lazy_groups_of(iter(hopper), self._batch_size) for key, hopper in hoppers.items()}

remaining = set(batches)

# Yield batches in a round-robin fashion until none are left.
while remaining:
for key, lazy_batches in batches.items():
if key in remaining:
try:
batch = next(lazy_batches)
yield Batch(batch)
except StopIteration:
remaining.remove(key)
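
One consequence of these homogeneous batches (not shown in this PR) is that a multitask model can dispatch on the partition key: when a batch is tensorized, the ``MetadataField`` values surface as a plain list under the partition key, and the iterator guarantees that list contains a single distinct value (the test further below checks exactly that via ``set(batch["dataset"])``). The helper here is a hypothetical sketch of such a dispatch; ``heads`` is an assumed mapping from dataset name to a task-specific module that returns an output dict.

from typing import Dict, List

import torch


def dispatch_to_task_head(heads: Dict[str, torch.nn.Module],
                          dataset: List[str],
                          **inputs) -> Dict[str, torch.Tensor]:
    # HomogeneousBatchIterator guarantees every instance in the batch shares
    # the same partition value, so the first entry identifies the task.
    task = dataset[0]
    # Assumes each head is a module whose forward returns an output dict,
    # as is conventional for AllenNLP models.
    return heads[task](**inputs)

In a real multitask model, ``heads`` would map each wrapped dataset's key to its task-specific head, with any shared encoding passed along in ``inputs``.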
80 changes: 80 additions & 0 deletions allennlp/tests/data/dataset_readers/interleaving_dataset_reader_test.py
@@ -0,0 +1,80 @@
from typing import Iterable

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.dataset_readers import DatasetReader, InterleavingDatasetReader
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WordTokenizer


class PlainTextReader(DatasetReader):
def __init__(self):
super().__init__()
self._token_indexers = {"tokens": SingleIdTokenIndexer()}
self._tokenizer = WordTokenizer()

def _read(self, file_path: str) -> Iterable[Instance]:
with open(file_path) as input_file:
for line in input_file:
yield self.text_to_instance(line)

def text_to_instance(self, line: str) -> Instance: # type: ignore
# pylint: disable=arguments-differ
tokens = self._tokenizer.tokenize(line)
return Instance({"line": TextField(tokens, self._token_indexers)})


class TestInterleavingDatasetReader(AllenNlpTestCase):
def test_round_robin(self):
readers = {
"a": PlainTextReader(),
"b": PlainTextReader(),
"c": PlainTextReader()
}

reader = InterleavingDatasetReader(readers)
data_dir = self.FIXTURES_ROOT / "data"

file_path = f"""{{
"a": "{data_dir / 'babi.txt'}",
"b": "{data_dir / 'conll2000.txt'}",
"c": "{data_dir / 'conll2003.txt'}"
}}"""

instances = list(reader.read(file_path))
first_three_keys = {instance.fields["dataset"].metadata for instance in instances[:3]}
assert first_three_keys == {"a", "b", "c"}

next_three_keys = {instance.fields["dataset"].metadata for instance in instances[3:6]}
assert next_three_keys == {"a", "b", "c"}

def test_all_at_once(self):
readers = {
"f": PlainTextReader(),
"g": PlainTextReader(),
"h": PlainTextReader()
}

reader = InterleavingDatasetReader(readers, dataset_field_name="source", scheme="all_at_once")
data_dir = self.FIXTURES_ROOT / "data"

file_path = f"""{{
"f": "{data_dir / 'babi.txt'}",
"g": "{data_dir / 'conll2000.txt'}",
"h": "{data_dir / 'conll2003.txt'}"
}}"""

buckets = []
last_source = None

# Fill up a bucket until the source changes, then start a new one
for instance in reader.read(file_path):
source = instance.fields["source"].metadata
if source != last_source:
buckets.append([])
last_source = source
buckets[-1].append(instance)

# should be in 3 buckets
assert len(buckets) == 3
45 changes: 45 additions & 0 deletions allennlp/tests/data/iterators/homogeneous_iterator_test.py
@@ -0,0 +1,45 @@
from collections import Counter

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.dataset_readers import InterleavingDatasetReader
from allennlp.data.iterators import HomogeneousBatchIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.tests.data.dataset_readers.interleaving_dataset_reader_test import PlainTextReader


class TestHomogeneousBatchIterator(AllenNlpTestCase):
def test_batches(self):
readers = {
"a": PlainTextReader(),
"b": PlainTextReader(),
"c": PlainTextReader()
}

reader = InterleavingDatasetReader(readers)
data_dir = self.FIXTURES_ROOT / "data"

file_path = f"""{{
"a": "{data_dir / 'babi.txt'}",
"b": "{data_dir / 'conll2000.txt'}",
"c": "{data_dir / 'conll2003.txt'}"
}}"""

instances = list(reader.read(file_path))
vocab = Vocabulary.from_instances(instances)

actual_instance_type_counts = Counter(instance.fields["dataset"].metadata
for instance in instances)

iterator = HomogeneousBatchIterator(batch_size=3)
iterator.index_with(vocab)

observed_instance_type_counts = Counter()

for batch in iterator(instances, num_epochs=1, shuffle=True):
# batch should be homogeneous
instance_types = set(batch["dataset"])
assert len(instance_types) == 1

observed_instance_type_counts.update(batch["dataset"])

assert observed_instance_type_counts == actual_instance_type_counts
7 changes: 7 additions & 0 deletions doc/api/allennlp.data.dataset_readers.interleaving_dataset_reader.rst
@@ -0,0 +1,7 @@
allennlp.data.dataset_readers.interleaving_dataset_reader
=========================================================

.. automodule:: allennlp.data.dataset_readers.interleaving_dataset_reader
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions doc/api/allennlp.data.dataset_readers.rst
@@ -16,6 +16,7 @@ allennlp.data.dataset_readers
allennlp.data.dataset_readers.conll2003
allennlp.data.dataset_readers.coreference_resolution
allennlp.data.dataset_readers.event2mind
allennlp.data.dataset_readers.interleaving_dataset_reader
allennlp.data.dataset_readers.language_modeling
allennlp.data.dataset_readers.multiprocess_dataset_reader
allennlp.data.dataset_readers.ontonotes_ner
8 changes: 8 additions & 0 deletions doc/api/allennlp.data.iterators.rst
@@ -10,6 +10,7 @@ allennlp.data.iterators
* :ref:`BasicIterator<basic-iterator>`
* :ref:`BucketIterator<bucket-iterator>`
* :ref:`MultiprocessIterator<multiprocess-iterator>`
* :ref:`HomogeneousBatchIterator<homogeneous-batch-iterator>`

.. _data-iterator:
.. automodule:: allennlp.data.iterators.data_iterator
@@ -34,3 +35,10 @@ allennlp.data.iterators
:members:
:undoc-members:
:show-inheritance:

.. _homogeneous-batch-iterator:
.. automodule:: allennlp.data.iterators.homogeneous_batch_iterator
:members:
:undoc-members:
:show-inheritance: