This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Fix for loading IterableDatasets with undefined length #4028

Merged: 7 commits merged into allenai:master on Apr 7, 2020

Conversation

MaksymDel (Contributor)

No description provided.

-        try:
-            batches_per_epoch = len(data_loader)
-        except TypeError:
-            # If the dataset is lazy, it won't have a length.
-            batches_per_epoch = None
+        batches_per_epoch = len(data_loader)  # returns "1" instead of TypeError for _LazyInstances
MaksymDel (Contributor, Author) commented Apr 6, 2020:

The TypeError exception here was never actually thrown, since the __len__() method is overridden in AllenNLP's _LazyInstances and always returns 1.
In case someone somehow passes an instance of a plain IterableDataset, it is better to let it actually throw a TypeError than to pass a custom None to the scheduler, which cannot handle it anyway (it accepts an int).
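To illustrate the point: calling len() on a plain torch IterableDataset raises a TypeError, while a subclass that overrides __len__ (as AllenNLP's _LazyInstances does) silently returns whatever the override says. A minimal sketch, with FakeLazyInstances as a hypothetical stand-in for _LazyInstances:

    from torch.utils.data import IterableDataset

    class PlainStream(IterableDataset):
        # A plain IterableDataset: no __len__, so len() raises TypeError.
        def __iter__(self):
            return iter(range(10))

    class FakeLazyInstances(PlainStream):
        # Stand-in for _LazyInstances: __len__ is overridden and always returns 1.
        def __len__(self):
            return 1

    try:
        len(PlainStream())
    except TypeError:
        print("plain IterableDataset: len() raises TypeError")

    print(len(FakeLazyInstances()))  # prints 1, no error, regardless of the real size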

Contributor:
I agree with your analysis of the behavior here, but does having a length of 1 make sense for the learning rate schedulers?

MaksymDel (Contributor, Author) commented Apr 6, 2020:

If you ask me, having a length of 1 does not make sense in any case other than the one where we actually do have one batch per epoch.

Here is our doc for the method:

    def __len__(self):
        """
        We rely in a couple of places that calling len on the dataloader
        (which in turn calls len on the dataset) doesn't raise an error.
        In the case that you have an IterableDataset and you call len, the pytorch dataloader
        actually spits out a warning - but we need actually calling it to not crash.
        """
        return 1

I would suggest that we actually crash instead, and make clients of the method handle the exception in the proper way.

Regarding schedulers, the only one that currently asks for this parameter is SlantedTriangular. I didn't look closely, but it seems to have a reasonable response to cases like this:

actual_num_steps_per_epoch = max(self.num_steps_per_epoch, self.last_batch_num_total)

but it might be by chance.

dirkgr (Member) commented Apr 6, 2020:

That line is not "by chance", but it is a bit of defensive coding that should not be relied on. I believe if actual_num_steps_per_epoch ends up smaller than the real number of steps, your learning rate would go negative. I wanted to avoid that with this line.
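To make the failure mode concrete, here is a small sketch of a generic slanted-triangular schedule (an illustrative re-derivation of the published formula, not AllenNLP's SlantedTriangular class). If num_steps_per_epoch is underestimated, for example taken to be 1 because a lazy dataset reported a length of 1, the actual step count quickly exceeds the assumed total and the computed learning rate eventually goes negative:

    def slanted_triangular_lr(step, total_steps, lr_max=0.01, cut_frac=0.1, ratio=32):
        # Slanted-triangular schedule in the style of Howard & Ruder (2018).
        cut = max(1, int(total_steps * cut_frac))
        if step < cut:
            p = step / cut
        else:
            p = 1 - (step - cut) / (cut * (1 / cut_frac - 1))
        return lr_max * (1 + p * (ratio - 1)) / ratio

    # The schedule believes there is 1 step per epoch over 10 epochs
    # (lazy dataset, len() == 1), but training actually runs far more steps.
    assumed_total = 1 * 10
    for step in (0, 5, 10, 50, 200):
        print(step, slanted_triangular_lr(step, assumed_total))
    # Once step exceeds assumed_total, p drops below 0 and the LR soon turns negative,
    # which is what the max(...) guard above is defending against.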

MaksymDel (Contributor, Author):
Thanks for the note.

That line is not "by chance", but it is a bit of defensive coding

It turns out that, unfortunately, IterableDatasets with undefined length breach that defense :)

Anyway, this behavior existed before, and this PR does not change the interaction between _LazyInstances and the scheduler, so we could track it in a separate issue.


@@ -76,19 +76,24 @@ def __init__(
            worker_init_fn=worker_init_fn,
            multiprocessing_context=multiprocessing_context,
        )
        self._batches_per_epoch = batches_per_epoch or super().__len__()
matt-gardner (Contributor):

I'm a bit confused about why this is different. The only thing I can come up with is that there's something going on in the super class that makes it so that querying the length here doesn't work. Can you give me some pointers on why this is fixing the issue?

Also, it'd be nice to add a test with lazy datasets that fails before this change and passes after the change, just to be sure we've actually captured the issue, and don't have a regression later.
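For reference, a rough sketch of the kind of regression test being asked for here, written against a toy loader that mimics the batches_per_epoch logic rather than AllenNLP's real DataLoader (ToyLazyDataset and ToyLoader are hypothetical stand-ins):

    import itertools

    class ToyLazyDataset:
        # Mimics a lazy dataset: __len__ is overridden to 1, but iteration yields 100 items.
        def __len__(self):
            return 1
        def __iter__(self):
            return iter(range(100))

    class ToyLoader:
        # Mimics the fixed behavior: only cap the epoch when batches_per_epoch is given.
        def __init__(self, dataset, batch_size, batches_per_epoch=None):
            self.dataset = dataset
            self.batch_size = batch_size
            self.batches_per_epoch = batches_per_epoch
        def __iter__(self):
            items = iter(self.dataset)
            batches = iter(lambda: list(itertools.islice(items, self.batch_size)), [])
            if self.batches_per_epoch is not None:
                batches = itertools.islice(batches, self.batches_per_epoch)
            yield from batches

    def test_lazy_dataset_epoch_is_not_truncated_to_one_batch():
        loader = ToyLoader(ToyLazyDataset(), batch_size=10)
        assert sum(1 for _ in loader) == 10  # a full pass over the data, not just 1 batch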

matt-gardner (Contributor):

And @JohnGiorgi, can you confirm that this code solves your issue?

JohnGiorgi (Contributor) commented Apr 6, 2020:
@matt-gardner Looks like it does. Training now takes the expected amount of time.

MaksymDel (Contributor, Author) commented Apr 6, 2020:

@matt-gardner, in short, the output of the __len__ method is not meaningful for _LazyInstances, while self._batches_per_epoch has to be meaningful for the code to work.

Specifically, the problem with this line was that when batches_per_epoch is None and the dataset is a _LazyInstances, self._batches_per_epoch would end up equal to 1 (super().__len__() returns its overridden value of 1).
And if self._batches_per_epoch is 1, the code yields only one batch per epoch, which is why @JohnGiorgi's epochs were so short.
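In other words, the failure reduces to Python's or expression falling through to the overridden length. A tiny sketch of that arithmetic, with ToyLazy standing in for _LazyInstances:

    class ToyLazy:
        def __len__(self):
            return 1  # mirrors _LazyInstances.__len__

    batches_per_epoch = None  # the user did not specify it
    dataset = ToyLazy()
    # The problematic fallback: None is falsy, so we fall through to len(dataset),
    # which is the meaningless 1.
    _batches_per_epoch = batches_per_epoch or len(dataset)
    print(_batches_per_epoch)  # 1, so __iter__ stops after a single batch per epoch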

Contributor:

Ah, got it, thanks for the additional explanation. To paraphrase, the problem is that __iter__ will have the wrong behavior when you combine lazy datasets and no _batches_per_epoch. __len__ will still do the same thing, but __iter__ only gives you one batch per epoch.

matt-gardner (Contributor) left a review comment:

Thanks for the fix @maksym-del! It would still be really nice to have a simple test, to make sure we don't have a regression later.


MaksymDel requested a review from matt-gardner on April 6, 2020.
dirkgr (Member) commented Apr 7, 2020:

batches_per_epoch is only used for SlantedTriangular, right? There is no other place that needs that number?

If this is true, I think the correct behavior is this:

  • Eager dataset, in all cases: No problem anyway.
  • Lazy dataset, not using SlantedTriangular: Runs as normal. Does not print an accurate ETA during training, unless you specify batches_per_epoch in the config.
  • Lazy dataset, using SlantedTriangular: Only works if you specify batches_per_epoch in the config. Crashes otherwise.

Did I capture this correctly?

MaksymDel (Contributor, Author) commented Apr 7, 2020:

@dirkgr yes, that is also how I see it regarding the SlantedTriangular scheduler (it is the only scheduler that seems to use this parameter).
However, every time we call len(dataset) or len(dataloader) we should expect that the result could be 1, because the dataset could be lazy with an unspecified epoch size. So I believe we should inspect all existing code in allennlp that asks for the length of these objects, and new data-related code should account for this as well (which is why, IMO, it is better to throw an exception when __len__ is called but not properly overridden).

I created a separate issue to track this: #4035
This PR's goal was to solve a different problem (which it seems to have accomplished), while #4035 existed before (we just had not run into it), so I propose merging this PR and following up later.

dirkgr merged commit 0018ff8 into allenai:master on Apr 7, 2020.