Fix for loading IterableDatasets with undefined length #4028
Conversation
- try:
-     batches_per_epoch = len(data_loader)
- except TypeError:
-     # If the dataset is lazy, it won't have a length.
-     batches_per_epoch = None
+ batches_per_epoch = len(data_loader)  # returns "1" instead of TypeError for _LazyInstances
The TypeError exception here was never actually thrown, since the __len__() method is overridden in AllenNLP's _LazyInstances and always returns 1.

In case someone somehow passes an instance of a plain IterableDataset, it is better to actually let it throw a TypeError than to pass a custom None to the scheduler, which cannot handle it anyway (it accepts an int).
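For reference, here is a minimal runnable sketch of the behavior being described (the classes are illustrative stand-ins, not AllenNLP's actual ones; assumes PyTorch is installed):

from torch.utils.data import DataLoader, IterableDataset

class PlainIterable(IterableDataset):
    # a vanilla IterableDataset: defines no __len__
    def __iter__(self):
        return iter(range(10))

class LazyLike(IterableDataset):
    # mimics AllenNLP's _LazyInstances: __len__ is overridden to return 1
    def __iter__(self):
        return iter(range(10))

    def __len__(self):
        return 1

try:
    len(DataLoader(PlainIterable()))  # DataLoader defers to len(self.dataset)
except TypeError:
    print("plain IterableDataset: len() raises TypeError")

print(len(DataLoader(LazyLike())))  # prints 1, so the except branch never fires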
I agree with your analysis of the behavior here, but does having a length of 1 make sense for the learning rate schedulers?
If you ask me, a length of 1 does not make sense in any case except the one where we really do have one batch per epoch.

Here is our doc for the method (allennlp/allennlp/data/dataset_readers/dataset_reader.py, lines 93 to 100 in d66db44):
def __len__(self):
    """
    We rely in a couple of places that calling len on the dataloader
    (which in turn calls len on the dataset) doesn't raise an error.
    In the case that you have an IterableDataset and you call len, the pytorch dataloader
    actually spits out a warning - but we need actually calling it to not crash.
    """
    return 1
I would suggest actually letting it crash instead and making clients of the method handle the exception in the proper way.

Regarding schedulers, the only one which currently asks for this parameter is SlantedTriangular. I didn't look closely, but it seems to have some proper response to cases like this:

actual_num_steps_per_epoch = max(self.num_steps_per_epoch, self.last_batch_num_total)

but it might be by chance.
That line is not "by chance", but it is a bit of defensive coding that should not be relied on. I believe if actual_num_steps_per_epoch ends up smaller than the real number of steps, your learning rate would go negative. I wanted to avoid that with this line.
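For intuition, here is a rough numeric sketch of that failure mode, using the slanted-triangular formula from the ULMFiT paper rather than AllenNLP's exact implementation (all numbers hypothetical):

def slanted_triangular(t, num_steps, cut_frac=0.1, ratio=32, lr_max=0.01):
    # Howard & Ruder's schedule; num_steps is what the scheduler *believes*
    # the total step count to be.
    cut = int(num_steps * cut_frac)
    if t < cut:
        p = t / cut
    else:
        p = 1 - (t - cut) / (cut * (1 / cut_frac - 1))
    return lr_max * (1 + p * (ratio - 1)) / ratio

believed_total = 100  # e.g. "1 batch per epoch" x 100 epochs
print(slanted_triangular(50, believed_total))   # ~0.0057, still sane
print(slanted_triangular(500, believed_total))  # ~-0.043: a negative learning rate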
Thanks for the note.

> That line is not "by chance", but it is a bit of defensive coding

It turns out, unfortunately, that IterableDatasets with undefined length breach the defense :)

Anyway, that line existed before, and this PR does not change the interaction between _LazyInstances and the scheduler, so we could track it in a separate issue.
@@ -76,19 +76,24 @@ def __init__(
        worker_init_fn=worker_init_fn,
        multiprocessing_context=multiprocessing_context,
    )
    self._batches_per_epoch = batches_per_epoch or super().__len__()
I'm a bit confused about why this is different. The only thing I can come up with is that there's something going on in the super class that makes it so that querying the length here doesn't work. Can you give me some pointers on why this is fixing the issue?
Also, it'd be nice to add a test with lazy datasets that fails before this change and passes after the change, just to be sure we've actually captured the issue, and don't have a regression later.
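Something along these lines could capture it (the helper and counts are hypothetical, since this sketch doesn't reproduce how the suite actually constructs AllenNLP's DataLoader):

def test_lazy_loader_yields_all_batches():
    # build_lazy_loader is a hypothetical stand-in for however the test suite
    # builds AllenNLP's DataLoader over a lazy dataset with no
    # batches_per_epoch given.
    loader = build_lazy_loader(num_instances=100, batch_size=10)
    batches = list(iter(loader))
    # Before this change the epoch was silently truncated to a single batch.
    assert len(batches) == 10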
And @JohnGiorgi, can you confirm that this code solves your issue?
@matt-gardner Looks like it does. Training now takes the expected amount of time.
@matt-gardner, in short, the output of the __len__ method in the case of _LazyInstances is not meaningful, while self._batches_per_epoch has to be meaningful in order for the code to work.

Specifically, the problem with this line was that, with batches_per_epoch being None and a _LazyInstances dataset, self._batches_per_epoch would be equal to 1 (super().__len__() would return its overridden value of 1).

And if self._batches_per_epoch is 1, then the code yields one batch per epoch, which is why @JohnGiorgi had such short epochs.
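In other words, a toy reduction of the offending expression (not the actual class):

from itertools import islice

def endless_batches():
    while True:
        yield "batch"

batches_per_epoch = None  # caller passed nothing
dummy_len = 1             # what super().__len__() returns for _LazyInstances

epoch_length = batches_per_epoch or dummy_len        # None or 1 -> 1
epoch = list(islice(endless_batches(), epoch_length))
print(len(epoch))  # 1: every "epoch" is a single batch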
Ah, got it, thanks for the additional explanation. To paraphrase, the problem is that __iter__ will have the wrong behavior when you combine lazy datasets and no _batches_per_epoch: __len__ will still do the same thing, but __iter__ only gives you one batch per epoch.

If this is true, I think the correct behavior is this:

Did I capture this correctly?
Thanks for the fix @maksym-del! It would still be really nice to have a simple test, to make sure we don't have a regression later.
@dirkgr yes, that is also how I see it regarding SlantedTriangular (it is the only scheduler that seems to use this parameter). I created a separate issue to track this: #4035