This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

PassThroughIterator #3015

Merged · 9 commits · Jun 28, 2019
Changes from 4 commits
1 change: 1 addition & 0 deletions allennlp/data/iterators/__init__.py
@@ -8,4 +8,5 @@
from allennlp.data.iterators.bucket_iterator import BucketIterator
from allennlp.data.iterators.homogeneous_batch_iterator import HomogeneousBatchIterator
from allennlp.data.iterators.multiprocess_iterator import MultiprocessIterator
from allennlp.data.iterators.pass_through_iterator import PassThroughIterator
from allennlp.data.iterators.same_language_iterator import SameLanguageIterator
84 changes: 84 additions & 0 deletions allennlp/data/iterators/pass_through_iterator.py
@@ -0,0 +1,84 @@
from typing import Iterable, Iterator, Union
import logging

from overrides import overrides
import torch

from allennlp.data.dataset import Batch
from allennlp.data.instance import Instance
from allennlp.data.iterators.basic_iterator import BasicIterator
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def _remove_batch_dim(singleton: Union[TensorDict, torch.Tensor]) -> Union[TensorDict, torch.Tensor]:
    """Recursively removes the batch dimension from tensors in a TensorDict."""
    if isinstance(singleton, dict):
        return {key: _remove_batch_dim(value) for key, value in singleton.items()}  # type: ignore
    elif isinstance(singleton, torch.Tensor):
        return singleton.squeeze(0)
    # TODO(rloganiv): Not sure if this is appropriate for Fields whose as_tensor and batch_tensors
Contributor:
I think the easiest thing to do about this would be to have a test that checks that a reasonable thing happens for a MetadataField (don't worry too much about the ProductionRuleField).

    # methods do not return DataArrays (e.g. MetadataField and ProductionRuleField).
    else:
        return singleton


@DataIterator.register("pass_through")
class PassThroughIterator(BasicIterator):
"""
An iterator which performs no batching or shuffling of instances, only tensorization. E.g,
instances are effectively passed 'straight through' the iterator.

This is essentially the same as a BasicIterator with shuffling disabled, the batch size set
to 1, and maximum sampled per batch disabled. The only difference is that this iterator
Contributor:
s/maximum sampled/maximum samples/?

    removes the batch dimension. This can be useful for rare situations where batching is best
    performed within the dataset reader (e.g. for continguous language modeling, or for other
Contributor:
s/continguous/continuous/

    problems where state is shared across batches).

    Parameters
    ----------
    instances_per_epoch : ``int``, optional, (default = None)
        If specified, each epoch will consist of precisely this many instances.
        If not specified, each epoch will consist of a single pass through the dataset.
    max_instances_in_memory : ``int``, optional, (default = None)
Contributor:
This is mainly for bucket iterators: it lets us get more instances into memory before sorting them by size, making consistently-sized batches more likely. I don't think you need this parameter here, because you're handling this in the DatasetReader, not the base DataIterator.

        If specified, the iterator will load this many instances at a time into an
        in-memory list and then produce batches from one such list at a time. This
        could be useful if your instances are read lazily from disk.
    cache_instances : ``bool``, optional, (default = False)
        If true, the iterator will cache the tensorized instances in memory.
        If false, it will do the tensorization anew each iteration.
    track_epoch : ``bool``, optional, (default = False)
        If true, each instance will get a ``MetadataField`` containing the epoch number.
    """
    def __init__(self,
                 instances_per_epoch: int = None,
                 max_instances_in_memory: int = None,
                 cache_instances: bool = False,
                 track_epoch: bool = False) -> None:
        super().__init__(batch_size=1,
                         instances_per_epoch=instances_per_epoch,
                         max_instances_in_memory=max_instances_in_memory,
                         cache_instances=cache_instances,
                         track_epoch=track_epoch,
                         maximum_samples_per_batch=None)

    @overrides
    def __call__(self,
                 instances: Iterable[Instance],
                 num_epochs: int = None,
                 shuffle: bool = False) -> Iterator[TensorDict]:
        if shuffle:
            logger.warning("PassThroughIterator does not shuffle instances. If shuffling is "
                           "required, please implement in your DatasetReader.")
            shuffle = False
        for tensor_dict in super().__call__(instances, num_epochs, shuffle):
Contributor:
You could make this even simpler and remove the need for _remove_batch_dim by just doing:

def __call__(self, instances, num_epochs, shuffle):
    for _ in range(num_epochs):  # handle num_epochs == None here
        for instance in instances:
            yield instance.as_tensor_dict()

This means you don't get caching or epoch tracking, but it simplifies a lot of other things. I'm not sure whether we should do it this way or not, just something to think about.

Contributor Author:
I like this approach. It handles the issue _remove_batch_dim has with non-tensor inputs. Also caching will probably not be needed in most use cases since it is expected that the dataset reader will perform actions like shuffling, perturbing sequence lengths, etc.

            yield _remove_batch_dim(tensor_dict)

    @overrides
    def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
        if shuffle:
            logger.warning("PassThroughIterator does not shuffle instances. If shuffling is "
                           "required, please implement in your DatasetReader.")
            shuffle = False
        yield from super()._create_batches(instances, shuffle)
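
For reference, the @DataIterator.register("pass_through") decorator above means the iterator can also be built from configuration by name. A minimal sketch, assuming the standard from_params machinery (not part of this diff):

from allennlp.common.params import Params
from allennlp.data.iterators import DataIterator

# "pass_through" is the name registered by the decorator above; any of the
# constructor arguments documented in the class docstring could be added here.
iterator = DataIterator.from_params(Params({"type": "pass_through"}))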
55 changes: 55 additions & 0 deletions allennlp/tests/data/iterators/pass_through_iterator_test.py
@@ -0,0 +1,55 @@
# pylint: disable=no-self-use,invalid-name
import numpy as np
import torch

from allennlp.data.iterators import PassThroughIterator
from allennlp.data.iterators.pass_through_iterator import _remove_batch_dim, logger
from allennlp.tests.data.iterators.basic_iterator_test import IteratorTest


def test_remove_batch_dim():
    # Check that the first dimension of a tensor is removed
    tensor_with_extra_batch_dim = torch.LongTensor([[1, 2, 3, 4]])
    observed_output = _remove_batch_dim(tensor_with_extra_batch_dim).data.numpy()
    expected_output = np.array([1, 2, 3, 4])
    np.testing.assert_almost_equal(observed_output, expected_output)

    # Check that the first dimension of a tensor in a dictionary is removed
    tensor_dict_with_extra_batch_dim = {'tensor': tensor_with_extra_batch_dim}
    observed_output = _remove_batch_dim(tensor_dict_with_extra_batch_dim)
    np.testing.assert_almost_equal(observed_output['tensor'].data.numpy(),
                                   expected_output)

    # Check that other input types are unaffected
    non_tensor = 'should be ignored'
    assert _remove_batch_dim(non_tensor) == non_tensor

    dict_with_non_tensor = {'non_tensor': non_tensor}
    assert _remove_batch_dim(dict_with_non_tensor) == dict_with_non_tensor
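
A sketch of the MetadataField check suggested in the review comment above; the dict layout here is illustrative, not taken from the PR:

def test_remove_batch_dim_ignores_non_tensor_values():
    # Values that are not tensors, such as the list a MetadataField produces
    # when batched, should pass through _remove_batch_dim unchanged.
    tensor_dict = {'tokens': torch.LongTensor([[1, 2, 3, 4]]),
                   'metadata': ['some raw string']}
    observed_output = _remove_batch_dim(tensor_dict)
    assert observed_output['metadata'] == ['some raw string']
    assert observed_output['tokens'].size() == (4,)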


class TestPassThroughIterator(IteratorTest):
    def test_get_num_batches(self):
        # Since batching is assumed to be performed in the DatasetReader, the number of batches
        # (according to the iterator) should always equal the number of instances.
        self.assertEqual(PassThroughIterator().get_num_batches(self.instances),
                         len(self.instances))

    def test_enabling_shuffling_raises_warning(self):
        iterator = PassThroughIterator()
        iterator.index_with(self.vocab)
        generator = iterator(self.instances, shuffle=True)
        with self.assertLogs(logger, level='INFO') as context_manager:
            next(generator)
        self.assertIn('WARNING', context_manager.output[0])

    def test_batch_dim_is_removed(self):
        # Ensure that PassThroughIterator does not add a batch dimension to tensors.

        # The first instance is a sequence of four tokens. Thus the expected output is a dict
        # containing a single tensor with shape (4,).
        iterator = PassThroughIterator()
        iterator.index_with(self.vocab)
        generator = iterator(self.instances)
        tensor_dict = next(generator)
        self.assertEqual(tensor_dict['text']['tokens'].size(), (4,))
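
To round out the picture, a minimal usage sketch; the names vocab, instances, and model are placeholders rather than anything defined in this PR:

from allennlp.data.iterators import PassThroughIterator

# Assumes `instances` come from a dataset reader that already performs any
# batching it needs, and that `vocab` is a Vocabulary built elsewhere.
iterator = PassThroughIterator()
iterator.index_with(vocab)

# Each yielded TensorDict corresponds to a single instance, and no batch
# dimension is added, matching test_batch_dim_is_removed above.
for tensor_dict in iterator(instances, num_epochs=1):
    output = model(**tensor_dict)  # placeholder model call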