This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added PassThroughIterator * Added test for PassThroughIterator * Add @OVERRIDES and appease mypy. * Appease pylint and mypy. * Added new iterator to docs (I think...) * Opted for simplified implementation * Appease pylint * Typo * Added back in overrides decorator
- Loading branch information
1 parent
70fa4aa
commit 15a9cbe
Showing
4 changed files
with
90 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from typing import Iterable, Iterator | ||
import itertools | ||
import logging | ||
|
||
from overrides import overrides | ||
|
||
from allennlp.data.dataset import Batch | ||
from allennlp.data.instance import Instance | ||
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict | ||
|
||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
|
||
|
||
@DataIterator.register("pass_through") | ||
class PassThroughIterator(DataIterator): | ||
""" | ||
An iterator which performs no batching or shuffling of instances, only tensorization. E.g, | ||
instances are effectively passed 'straight through' the iterator. | ||
This is essentially the same as a BasicIterator with shuffling disabled, the batch size set | ||
to 1, and maximum samples per batch disabled. The only difference is that this iterator | ||
removes the batch dimension. This can be useful for rare situations where batching is best | ||
performed within the dataset reader (e.g. for contiguous language modeling, or for other | ||
problems where state is shared across batches). | ||
""" | ||
def __init__(self): | ||
super().__init__(batch_size=1) | ||
|
||
@overrides | ||
def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]: | ||
raise RuntimeError("PassThroughIterator doesn't use create_batches") | ||
|
||
def __call__(self, | ||
instances: Iterable[Instance], | ||
num_epochs: int = None, | ||
shuffle: bool = False) -> Iterator[TensorDict]: | ||
# Warn users that this iterator does not do anything for you. | ||
if shuffle: | ||
logger.warning("PassThroughIterator does not shuffle instances. If shuffling is " | ||
"required, please implement in your DatasetReader.") | ||
|
||
if num_epochs is None: | ||
epochs: Iterable[int] = itertools.count() | ||
else: | ||
epochs = range(num_epochs) | ||
|
||
for _ in epochs: | ||
for instance in instances: | ||
if self.vocab is not None: | ||
instance.index_fields(self.vocab) | ||
yield instance.as_tensor_dict() |
30 changes: 30 additions & 0 deletions
30
allennlp/tests/data/iterators/pass_through_iterator_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# pylint: disable=no-self-use,invalid-name | ||
from allennlp.data.iterators.pass_through_iterator import PassThroughIterator, logger | ||
from allennlp.tests.data.iterators.basic_iterator_test import IteratorTest | ||
|
||
|
||
class TestPassThroughIterator(IteratorTest): | ||
def test_get_num_batches(self): | ||
# Since batching is assumed to be performed in the DatasetReader, the number of batches | ||
# (according to the iterator) should always equal the number of instances. | ||
self.assertEqual(PassThroughIterator().get_num_batches(self.instances), | ||
len(self.instances)) | ||
|
||
def test_enabling_shuffling_raises_warning(self): | ||
iterator = PassThroughIterator() | ||
iterator.index_with(self.vocab) | ||
generator = iterator(self.instances, shuffle=True) | ||
with self.assertLogs(logger, level='INFO') as context_manager: | ||
next(generator) | ||
self.assertIn('WARNING', context_manager.output[0]) | ||
|
||
def test_batch_dim_is_removed(self): | ||
# Ensure that PassThroughIterator does not add a batch dimension to tensors. | ||
|
||
# First instance is a sequence of four tokens. Thus the expected output is a dict | ||
# containing a single tensor with shape (4,). | ||
iterator = PassThroughIterator() | ||
iterator.index_with(self.vocab) | ||
generator = iterator(self.instances) | ||
tensor_dict = next(generator) | ||
self.assertEqual(tensor_dict['text']['tokens'].size(), (4,)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters