# PassThroughIterator #3015

*Changes from 4 commits.*
**allennlp/data/iterators/pass_through_iterator.py** (new file, 84 lines)

```python
from typing import Iterable, Iterator, Union
import logging

from overrides import overrides
import torch

from allennlp.data.dataset import Batch
from allennlp.data.instance import Instance
from allennlp.data.iterators.basic_iterator import BasicIterator
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def _remove_batch_dim(singleton: Union[TensorDict, torch.Tensor]) -> Union[TensorDict, torch.Tensor]:
    """Recursively removes the batch dimension from tensors in a TensorDict."""
    if isinstance(singleton, dict):
        return {key: _remove_batch_dim(value) for key, value in singleton.items()}  # type: ignore
    elif isinstance(singleton, torch.Tensor):
        return singleton.squeeze(0)
    # TODO(rloganiv): Not sure if this is appropriate for Fields whose as_tensor and batch_tensor
    # methods do not return DataArrays (e.g. MetadataField and ProductionRuleField).
    else:
        return singleton


@DataIterator.register("pass_through")
class PassThroughIterator(BasicIterator):
    """
    An iterator which performs no batching or shuffling of instances, only tensorization, i.e.
    instances are effectively passed 'straight through' the iterator.

    This is essentially the same as a BasicIterator with shuffling disabled, the batch size set
    to 1, and maximum samples per batch disabled. The only difference is that this iterator
    removes the batch dimension. This can be useful for rare situations where batching is best
    performed within the dataset reader (e.g. for contiguous language modeling, or for other
    problems where state is shared across batches).

    Parameters
    ----------
    instances_per_epoch : ``int``, optional, (default = None)
        If specified, each epoch will consist of precisely this many instances.
        If not specified, each epoch will consist of a single pass through the dataset.
    max_instances_in_memory : ``int``, optional, (default = None)
```
> **Review comment** (on `max_instances_in_memory`): This is mainly for bucket iterators, so we can get more instances into memory before we sort them by size, so you're more likely to have consistently-sized batches. I don't think you need this parameter here, because you're handling this in the …
```python
        If specified, the iterator will load this many instances at a time into an
        in-memory list and then produce batches from one such list at a time. This
        could be useful if your instances are read lazily from disk.
    cache_instances : ``bool``, optional, (default = False)
        If true, the iterator will cache the tensorized instances in memory.
        If false, it will do the tensorization anew each iteration.
    track_epoch : ``bool``, optional, (default = False)
        If true, each instance will get a ``MetadataField`` containing the epoch number.
    """
    def __init__(self,
                 instances_per_epoch: int = None,
                 max_instances_in_memory: int = None,
                 cache_instances: bool = False,
                 track_epoch: bool = False) -> None:
        super().__init__(batch_size=1,
                         instances_per_epoch=instances_per_epoch,
                         max_instances_in_memory=max_instances_in_memory,
                         cache_instances=cache_instances,
                         track_epoch=track_epoch,
                         maximum_samples_per_batch=None)

    @overrides
    def __call__(self,
                 instances: Iterable[Instance],
                 num_epochs: int = None,
                 shuffle: bool = False) -> Iterator[TensorDict]:
        if shuffle:
            logger.warning("PassThroughIterator does not shuffle instances. If shuffling is "
                           "required, please implement in your DatasetReader.")
            shuffle = False
        for tensor_dict in super().__call__(instances, num_epochs, shuffle):
```
> **Review comment:** You could make this even simpler and remove the need for …
>
> ```python
> def __call__(self, instances, num_epochs, shuffle):
>     for epoch in num_epochs:  # handle num_epochs == None here
>         for instance in instances:
>             yield instance.as_tensor_dict()
> ```
>
> This means you don't get caching or epoch tracking, but it simplifies a lot of other things. I'm not sure whether we should do it this way or not, just something to think about.

> **Reply:** I like this approach. It handles the issue …
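A fleshed-out version of the reviewer's sketch might look like the following. This is a hypothetical illustration, not code from the PR: the `itertools.count()` handling of `num_epochs=None` and the explicit `index_fields` call are assumptions about how the simplified iterator would have to be wired up.

```python
from itertools import count

# Hypothetical sketch of the reviewer's suggestion (not part of this diff),
# written as it might appear on PassThroughIterator itself.
def __call__(self,
             instances: Iterable[Instance],
             num_epochs: int = None,
             shuffle: bool = False) -> Iterator[TensorDict]:
    # num_epochs=None conventionally means "iterate forever" for AllenNLP iterators.
    epochs = range(num_epochs) if num_epochs is not None else count()
    for _ in epochs:
        for instance in instances:
            # Instances must be indexed against the vocabulary (set via
            # index_with) before they can be converted to tensors.
            instance.index_fields(self.vocab)
            yield instance.as_tensor_dict()
```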
```python
            yield _remove_batch_dim(tensor_dict)

    @overrides
    def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
        if shuffle:
            logger.warning("PassThroughIterator does not shuffle instances. If shuffling is "
                           "required, please implement in your DatasetReader.")
            shuffle = False
        yield from super()._create_batches(instances, shuffle)
```
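To make the intended usage concrete, here is a small sketch of driving the iterator by hand. The toy instance and the `'text'` field name are invented for illustration; in a training configuration the iterator would instead be selected with `"iterator": {"type": "pass_through"}`, per the `DataIterator.register` call above.

```python
# Hypothetical usage sketch (the instance and field names are invented).
from allennlp.data import Instance, Vocabulary
from allennlp.data.fields import TextField
from allennlp.data.iterators import PassThroughIterator
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

tokens = [Token(word) for word in "the cat sat down".split()]
instance = Instance({"text": TextField(tokens, {"tokens": SingleIdTokenIndexer()})})

vocab = Vocabulary.from_instances([instance])
iterator = PassThroughIterator()
iterator.index_with(vocab)

# Unlike BasicIterator, each yielded tensor dict has no leading batch
# dimension: the tensor below has shape (4,), not (1, 4).
tensor_dict = next(iterator([instance]))
print(tensor_dict["text"]["tokens"].shape)  # torch.Size([4])
```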
**New test file** (`allennlp/tests/data/iterators/`, 55 lines)
```python
# pylint: disable=no-self-use,invalid-name
import numpy as np
import torch

from allennlp.data.iterators import PassThroughIterator
from allennlp.data.iterators.pass_through_iterator import _remove_batch_dim, logger
from allennlp.tests.data.iterators.basic_iterator_test import IteratorTest


def test_remove_batch_dim():
    # Check that the first dimension of a tensor is removed
    tensor_with_extra_batch_dim = torch.LongTensor([[1, 2, 3, 4]])
    observed_output = _remove_batch_dim(tensor_with_extra_batch_dim).data.numpy()
    expected_output = np.array([1, 2, 3, 4])
    np.testing.assert_almost_equal(observed_output, expected_output)

    # Check that the first dimension of a tensor in a dictionary is removed
    tensor_dict_with_extra_batch_dim = {'tensor': tensor_with_extra_batch_dim}
    observed_output = _remove_batch_dim(tensor_dict_with_extra_batch_dim)
    np.testing.assert_almost_equal(observed_output['tensor'].data.numpy(),
                                   expected_output)

    # Check that other input types are unaffected
    non_tensor = 'should be ignored'
    assert _remove_batch_dim(non_tensor) == non_tensor

    dict_with_non_tensor = {'non_tensor': non_tensor}
    assert _remove_batch_dim(dict_with_non_tensor) == dict_with_non_tensor


class TestPassThroughIterator(IteratorTest):
    def test_get_num_batches(self):
        # Since batching is assumed to be performed in the DatasetReader, the number of batches
        # (according to the iterator) should always equal the number of instances.
        self.assertEqual(PassThroughIterator().get_num_batches(self.instances),
                         len(self.instances))

    def test_enabling_shuffling_raises_warning(self):
        iterator = PassThroughIterator()
        iterator.index_with(self.vocab)
        generator = iterator(self.instances, shuffle=True)
        with self.assertLogs(logger, level='INFO') as context_manager:
            next(generator)
        self.assertIn('WARNING', context_manager.output[0])

    def test_batch_dim_is_removed(self):
        # Ensure that PassThroughIterator does not add a batch dimension to tensors.

        # The first instance is a sequence of four tokens, so the expected output is a dict
        # containing a single tensor with shape (4,).
        iterator = PassThroughIterator()
        iterator.index_with(self.vocab)
        generator = iterator(self.instances)
        tensor_dict = next(generator)
        self.assertEqual(tensor_dict['text']['tokens'].size(), (4,))
```
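One case the tests above don't exercise is nesting. Because `_remove_batch_dim` recurses through dictionaries, it should also unbatch tensors inside nested dictionaries, such as the `{'text': {'tokens': ...}}` structure a `TextField` produces. A quick illustrative check (not part of the diff):

```python
import torch

from allennlp.data.iterators.pass_through_iterator import _remove_batch_dim

# Illustrative check (not part of the diff): the recursion also reaches
# tensors inside nested dictionaries.
nested = {'text': {'tokens': torch.LongTensor([[1, 2, 3, 4]])}}
unbatched = _remove_batch_dim(nested)
assert unbatched['text']['tokens'].shape == (4,)
```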
> **Review comment:** I think the easiest thing to do about this would be to have a test that checks that a reasonable thing happens for a `MetadataField` (don't worry too much about the `ProductionRuleField`).
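A sketch of what such a test might look like, assuming (as the TODO in `_remove_batch_dim` notes) that `MetadataField` values arrive as plain Python objects rather than tensors; the test name and values are hypothetical:

```python
def test_metadata_is_not_squeezed():
    # Hypothetical sketch of the suggested test (not part of this diff).
    # MetadataField values are arbitrary Python objects rather than
    # torch.Tensors, so _remove_batch_dim should return them unchanged.
    metadata = ['some', 'arbitrary', 'metadata']
    assert _remove_batch_dim(metadata) == metadata
    assert _remove_batch_dim({'metadata': metadata}) == {'metadata': metadata}
```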