PassThroughIterator #3015
Conversation
Largely looks good. You still need to add this to the docs; let me know if you don't know how to do that.
    return {key: _remove_batch_dim(value) for key, value in singleton.items()}  # type: ignore
elif isinstance(singleton, torch.Tensor):
    return singleton.squeeze(0)
# TODO(rloganiv): Not sure if this is appropriate for Fields whose as_tensor and batch_tensor
I think the easiest thing to do about this would be to have a test that checks that a reasonable thing happens for a MetadataField (don't worry too much about the ProductionRuleField).
instances are effectively passed 'straight through' the iterator.

This is essentially the same as a BasicIterator with shuffling disabled, the batch size set
to 1, and maximum sampled per batch disabled. The only difference is that this iterator
s/maximum sampled/maximum samples/?
This is essentially the same as a BasicIterator with shuffling disabled, the batch size set
to 1, and maximum sampled per batch disabled. The only difference is that this iterator
removes the batch dimension. This can be useful for rare situations where batching is best
performed within the dataset reader (e.g. for continguous language modeling, or for other
s/continguous/continuous/
instances_per_epoch : ``int``, optional, (default = None)
    If specified, each epoch will consist of precisely this many instances.
    If not specified, each epoch will consist of a single pass through the dataset.
max_instances_in_memory : ``int``, optional, (default = None)
This is mainly for bucket iterators, so we can get more instances into memory before we sort them by size, so you're more likely to have consistently-sized batches. I don't think you need this parameter here, because you're handling this in the DatasetReader, not the base DataIterator.
logger.warning("PassThroughIterator does not shuffle instances. If shuffling is "
               "required, please implement in your DatasetReader.")
shuffle = False
for tensor_dict in super().__call__(instances, num_epochs, shuffle):
You could make this even simpler and remove the need for _remove_batch_dim by just doing:
def __call__(self, instances, num_epochs, shuffle):
for epoch in num_epochs: # handle num_epochs == None here
for instance in instances:
yield instance.as_tensor_dict()
This means you don't get caching or epoch tracking, but it simplifies a lot of other things. I'm not sure whether we should do it this way or not, just something to think about.
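Spelling out the suggestion above as a self-contained sketch (the `pass_through` name and the treatment of `num_epochs=None` as "loop forever" are assumptions for illustration, not part of the PR):

```python
from itertools import count
from typing import Iterable, Iterator, Optional


def pass_through(instances: Iterable, num_epochs: Optional[int] = None) -> Iterator[dict]:
    # Yield each instance's tensor dict directly, in reader order, with no
    # batching and hence no batch dimension to remove afterwards.
    # Assumption: num_epochs=None means iterate indefinitely.
    epochs = range(num_epochs) if num_epochs is not None else count()
    for _ in epochs:
        # Note: `instances` must be re-iterable (e.g. a list or a lazy
        # dataset), not a one-shot generator, or later epochs will be empty.
        for instance in instances:
            yield instance.as_tensor_dict()
```

As the comment notes, this trades away caching and epoch tracking for a much smaller surface area.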
I like this approach. It handles the issue _remove_batch_dim has with non-tensor inputs. Also caching will probably not be needed in most use cases since it is expected that the dataset reader will perform actions like shuffling, perturbing sequence lengths, etc.
LGTM!
def __init__(self):
    super().__init__(batch_size=1)

def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
Keeping an @overrides tag here would be nice. The point is that it makes it obvious to the reader why this method exists.
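To illustrate the point, here is a minimal stand-in for the third-party `overrides` decorator (an assumption for this sketch: the real package, which AllenNLP uses, additionally verifies at decoration time that some superclass actually defines the method):

```python
def overrides(method):
    # Minimal stand-in: only marks intent. The real `overrides` package also
    # raises an error if no superclass defines a method with this name.
    method.__overrides__ = True
    return method


class DataIterator:
    def _create_batches(self, instances, shuffle):
        raise NotImplementedError


class PassThroughIterator(DataIterator):
    @overrides
    def _create_batches(self, instances, shuffle):
        # The tag signals to readers that this method exists to satisfy the
        # base-class contract rather than being new public API.
        yield from ()
```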
* Added PassThroughIterator
* Added test for PassThroughIterator
* Add @OVERRIDES and appease mypy.
* Appease pylint and mypy.
* Added new iterator to docs (I think...)
* Opted for simplified implementation
* Appease pylint
* Typo
* Added back in overrides decorator
This PR adds a new PassThroughIterator which tensorizes Instances one at a time, and returns them in the exact order that they are created in the DatasetReader. It generalizes the LanguageModelIterator written by @nelson-liu in #2414, which is specifically designed for the task of contiguous language modeling. Since it seems like this approach is generally useful for problems which apply stateful models to encode long sequences (see discussion #2828), and #2414 is currently blocked by the tangential (and rather thorny) issue #2373, I think it is worth adding this separately.

The only non-trivial aspect of this iterator is that it needs to remove the batch dimension introduced when calling Batch.as_tensor_dict() (since this iterator is intended to be used in situations where batching is performed ahead of time within the DatasetReader). To do this, I've written a function which recursively squeezes the first dimension of tensors in a TensorDict (see here). While I think the function behaves sensibly on tensors and dictionaries, I am not so sure about non-tensor fields like MetadataField or ProductionRuleField. Right now I am just returning them exactly as is, but maybe if they are singleton lists it would be better to return the single element?
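The recursive squeeze described above can be sketched without torch by using nested lists as stand-ins for tensors (the `remove_batch_dim` name and the list-indexing stand-in for `Tensor.squeeze(0)` are illustrative assumptions, not the PR's actual code):

```python
def remove_batch_dim(singleton):
    # Recursive batch-dimension removal: dicts recurse into their values,
    # "tensors" (nested lists here, standing in for torch.Tensor, where [0]
    # plays the role of squeeze(0)) drop their leading size-1 dimension, and
    # anything else (e.g. a MetadataField's raw output) passes through
    # unchanged -- which is exactly the open question about non-tensor fields.
    if isinstance(singleton, dict):
        return {key: remove_batch_dim(value) for key, value in singleton.items()}
    elif isinstance(singleton, list):
        return singleton[0]
    return singleton
```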