This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Fix bug with lazy data loading, un-implement __len__ on AllennlpLazyDataset #4328

Merged on Jun 5, 2020 (6 commits)
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `get_text_field_mask()` now supports padding indices that are not `0`.
- A bug where `predictor.get_gradients()` would return an empty dictionary if an embedding layer had trainable set to false
- Fixes `PretrainedTransformerMismatchedIndexer` in the case where a token consists of zero word pieces.
+ - Fixes a bug when using a lazy dataset reader that results in a `UserWarning` from PyTorch being printed at
+   every iteration during training.

### Added

@@ -38,6 +40,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- SimpleTagger will no longer calculate span-based F1 metric when `calculate_span_f1` is `False`.
- CPU memory for every worker is now reported in the logs and the metrics. Previously this was only reporting the CPU memory of the master process, and so it was only
correct in the non-distributed setting.
+ - To be consistent with PyTorch `IterableDataset`, `AllennlpLazyDataset` no longer implements `__len__()`.
+   Previously it would always return 1.

## [v1.0.0rc5](/~https://github.com/allenai/allennlp/releases/tag/v1.0.0rc5) - 2020-05-26

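The last changelog entry above (`AllennlpLazyDataset` no longer implements `__len__()`) describes the user-visible change: calling `len()` on a lazy dataset now fails the same way it does for any PyTorch `IterableDataset` without `__len__`, while iteration is unaffected. A minimal sketch of that behaviour, using a toy `IterableDataset` rather than AllenNLP's actual `AllennlpLazyDataset`:

```python
from torch.utils.data import DataLoader, IterableDataset


class ToyLazyDataset(IterableDataset):
    """Stand-in for a lazy dataset: instances are produced on the fly,
    so the total count is unknown and __len__ is deliberately not implemented."""

    def __iter__(self):
        yield from range(10)


dataset = ToyLazyDataset()
loader = DataLoader(dataset, batch_size=4)

try:
    len(dataset)  # no __len__ -> TypeError, just like a plain IterableDataset
except TypeError:
    print("len() is not supported for lazy datasets")

for batch in loader:  # iteration is unaffected
    print(batch)
```

Callers that need a length have to be prepared for the `TypeError`, which is exactly what the trainer changes below do.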
2 changes: 1 addition & 1 deletion allennlp/commands/train.py
@@ -490,7 +490,7 @@ def run(self) -> Dict[str, Any]:
return self.trainer.train()

def finish(self, metrics: Dict[str, Any]):
- if self.evaluation_data_loader and self.evaluate_on_test:
+ if self.evaluation_data_loader is not None and self.evaluate_on_test:
> Member Author: Need this so `len()` is not called.

logger.info("The model will be evaluated using the best epoch weights.")
test_metrics = training_util.evaluate(
self.model,
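The review comment above ("Need this so `len()` is not called") refers to Python truthiness: `if x:` falls back to `x.__len__()` when no `__bool__` is defined, so a bare truthiness check on a data loader over a lazy dataset would trigger the very `len()` call this PR is trying to avoid. A small illustration with a hypothetical loader class (not AllenNLP's actual `DataLoader`):

```python
class SizelessLoader:
    """Hypothetical stand-in for a data loader over an iterable-style dataset:
    asking for its length raises, as len() does for such loaders."""

    def __len__(self):
        raise TypeError("this loader has no length")


loader = SizelessLoader()

# `if loader:` implicitly calls loader.__len__() because no __bool__ is defined ...
try:
    if loader:
        pass
except TypeError as err:
    print("truthiness check failed:", err)

# ... whereas an explicit None check never touches __len__.
if loader is not None:
    print("explicit `is not None` check is safe")
```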
15 changes: 0 additions & 15 deletions allennlp/data/dataset_readers/dataset_reader.py
@@ -48,15 +48,6 @@ def __init__(
def index_with(self, vocab: Vocabulary):
self.vocab = vocab

- def __len__(self):
-     """
-     We rely in a couple of places that calling len on the dataloader
-     (which in turn calls len on the dataset) doesn't raise an error.
-     In the case that you have an IterableDataset and you call len, the pytorch dataloader
-     actually spits out a warning - but we need actually calling it to not crash.
-     """
-     return 1


class _LazyInstances(AllennlpLazyDataset):
"""
@@ -122,9 +113,6 @@ def __iter__(self) -> Iterator[Instance]:
def index_with(self, vocab: Vocabulary):
self.inner.index_with(vocab)

- def __len__(self):
-     return len(self.inner)


class _DistributedLazyInstances(AllennlpLazyDataset):
def __init__(self, inner: AllennlpLazyDataset) -> None:
@@ -144,9 +132,6 @@ def __iter__(self) -> Iterator[Instance]:
def index_with(self, vocab: Vocabulary):
self.inner.index_with(vocab)

- def __len__(self):
-     return len(self.inner)


class DatasetReader(Registrable):
"""
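The removed docstring explains the old workaround: `__len__` returned a dummy `1` so that `len()` never crashed. The side effect, noted in the changelog, is that PyTorch's `DataLoader` appears to remember the length it was told and then warns once more batches are fetched than that. A rough, hedged reproduction of the problem (assuming roughly PyTorch 1.5-era semantics; the exact warning text and trigger point vary by version):

```python
import warnings

from torch.utils.data import DataLoader, IterableDataset


class OldStyleLazyDataset(IterableDataset):
    """Mimics the removed workaround: a dummy __len__ that always reports 1."""

    def __iter__(self):
        yield from range(10)

    def __len__(self):
        return 1  # the old hack: keep len() from crashing


loader = DataLoader(OldStyleLazyDataset(), batch_size=2)
_ = len(loader)  # e.g. a trainer asking for a progress-bar total

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    for _batch in loader:
        pass

# Once more batches are fetched than the reported length, PyTorch starts warning.
print(f"caught {len(caught)} warnings")
```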
3 changes: 1 addition & 2 deletions allennlp/modules/token_embedders/embedding.py
@@ -630,11 +630,10 @@ def __next__(self) -> str:
return next(self._iterator)

def __len__(self) -> Optional[int]:
""" Hack for tqdm: no need for explicitly passing `total=file.num_tokens` """
if self.num_tokens:
return self.num_tokens
raise AttributeError(
- 'an object of type EmbeddingsTextFile has "len()" only if the underlying '
+ "an object of type EmbeddingsTextFile implements `__len__` only if the underlying "
"text file declares the number of tokens (i.e. the number of lines following)"
"in the first line. That is not the case of this particular instance."
)
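For context on the reworded error message: `EmbeddingsTextFile.__len__` deliberately raises `AttributeError` when the file does not declare its token count, because tqdm guards its `len()` lookup with a `try/except` over `TypeError` and `AttributeError` and simply falls back to a progress bar with no total (true of recent tqdm releases; treat the exact behaviour as an assumption). A toy sketch of that interaction, where `TokensFile` is hypothetical and not the real `EmbeddingsTextFile`:

```python
from tqdm import tqdm


class TokensFile:
    """Toy stand-in for an embeddings text file: the token count is only
    known if the file declared it on its first line."""

    def __init__(self, num_tokens=None):
        self.num_tokens = num_tokens

    def __iter__(self):
        yield from (f"token{i}" for i in range(5))

    def __len__(self):
        if self.num_tokens:
            return self.num_tokens
        # tqdm treats TypeError/AttributeError from len() as "total unknown",
        # so raising here just hides the progress total instead of crashing.
        raise AttributeError("token count was not declared in the first line")


for _ in tqdm(TokensFile(num_tokens=5)):  # declared count: bar shows 5/5
    pass

for _ in tqdm(TokensFile()):  # no declared count: bar has no total
    pass
```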
10 changes: 6 additions & 4 deletions allennlp/training/learning_rate_schedulers/slanted_triangular.py
@@ -1,5 +1,5 @@
import logging
- from typing import List
+ from typing import List, Optional

from overrides import overrides
import torch
@@ -33,7 +33,7 @@ class SlantedTriangular(LearningRateScheduler):
This argument does not get an entry in a configuration file for the object.
num_epochs : `int`, required.
The total number of epochs for which the model should be trained.
- num_steps_per_epoch : `int`, required.
+ num_steps_per_epoch : `Optional[int]`, optional (default = `None`)
The number of steps (updates, batches) per training epoch.
cut_frac : `float`, optional (default = `0.1`).
The fraction of the steps to increase the learning rate.
@@ -53,7 +53,7 @@ def __init__(
self,
optimizer: torch.optim.Optimizer,
num_epochs: int,
- num_steps_per_epoch: int,
+ num_steps_per_epoch: Optional[int] = None,
cut_frac: float = 0.1,
ratio: int = 32,
last_epoch: int = -1,
@@ -138,7 +138,9 @@ def get_values(self):
self.batch_num_total_epoch_end[-1] / (len(self.batch_num_total_epoch_end) - 1)
)
else:
- actual_num_steps_per_epoch = max(self.num_steps_per_epoch, self.last_batch_num_total)
+ actual_num_steps_per_epoch = max(
+     self.num_steps_per_epoch or 1, self.last_batch_num_total
> Member: Why does it default to 1 here? Isn't that a problem when it happens?

> Member Author: This gives the same behavior as before. I'm not sure if that's an issue.

+ )

if self.freezing_current:
# if we still freeze, we restrict the schedule to the current epoch
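On the review question about `or 1`: when `num_steps_per_epoch` is `None` (lazy data, so the per-epoch step count is unknown), `None or 1` evaluates to `1`, and the `max()` then effectively defers to the running batch count, which is the pre-existing behaviour the author refers to. A tiny sketch of the expression:

```python
def actual_steps_per_epoch(num_steps_per_epoch, last_batch_num_total):
    """Mirrors the expression in get_values(): fall back to 1 when the
    per-epoch step count is unknown, then take the larger of the two."""
    return max(num_steps_per_epoch or 1, last_batch_num_total)


# Known epoch length: the configured value wins until training overshoots it.
print(actual_steps_per_epoch(100, 42))   # 100
print(actual_steps_per_epoch(100, 150))  # 150

# Lazy data, length unknown: `None or 1` -> 1, so the running batch count wins.
print(actual_steps_per_epoch(None, 42))  # 42
```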
22 changes: 16 additions & 6 deletions allennlp/training/trainer.py
@@ -361,7 +361,7 @@ def __init__(
self.optimizer = optimizer

if patience is None: # no early stopping
- if validation_data_loader:
+ if validation_data_loader is not None:
> Member Author: Need this so `len()` is not called.

logger.warning(
"You provided a validation dataset but patience was set to None, "
"meaning that early stopping is disabled"
@@ -508,9 +508,15 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:

logger.info("Training")

- num_training_batches = math.ceil(
-     len(self.data_loader) / self._num_gradient_accumulation_steps
- )
+ num_training_batches: Union[int, float]
+ try:
+     len_data_loader = len(self.data_loader)
+     num_training_batches = math.ceil(
+         len_data_loader / self._num_gradient_accumulation_steps
+     )
+ except TypeError:
+     num_training_batches = float("inf")

# Having multiple tqdm bars in case of distributed training will be a mess. Hence only the master's
# progress is shown
if self._master:
@@ -1061,8 +1067,12 @@ def from_partial_objects(
if not optimizer_:
optimizer_ = Optimizer.default(parameters)

- batches_per_epoch = len(data_loader) # returns "1" instead of TypeError for _LazyInstances
- batches_per_epoch = math.ceil(batches_per_epoch / num_gradient_accumulation_steps)
+ batches_per_epoch: Optional[int]
+ try:
+     batches_per_epoch = len(data_loader)
+     batches_per_epoch = math.ceil(batches_per_epoch / num_gradient_accumulation_steps)
+ except TypeError:
+     batches_per_epoch = None

moving_average_ = moving_average.construct(parameters=parameters)
learning_rate_scheduler_ = learning_rate_scheduler.construct(
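Both trainer changes follow the same pattern: ask the loader for its length and fall back when it raises `TypeError`, which is what a `DataLoader` over an iterable-style (lazy) dataset does. A self-contained sketch of that pattern with a hypothetical helper (the names are illustrative, not AllenNLP's API):

```python
import math
from typing import Optional


def batches_per_epoch(data_loader, num_gradient_accumulation_steps: int = 1) -> Optional[int]:
    """Hypothetical helper mirroring the pattern above: return the number of
    optimizer steps per epoch, or None when the loader's length is unknown."""
    try:
        num_batches = len(data_loader)
    except TypeError:
        # Loaders over iterable-style (lazy) datasets cannot report a length.
        return None
    return math.ceil(num_batches / num_gradient_accumulation_steps)


print(batches_per_epoch([0] * 10, num_gradient_accumulation_steps=4))  # 3
print(batches_per_epoch(iter(range(10))))                              # None (no __len__)
```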
2 changes: 1 addition & 1 deletion allennlp/training/util.py
@@ -333,7 +333,7 @@ def evaluate(

iterator = iter(data_loader)
logger.info("Iterating over dataset")
- generator_tqdm = Tqdm.tqdm(iterator, total=len(data_loader))
+ generator_tqdm = Tqdm.tqdm(iterator)

# Number of batches in instances.
batch_count = 0
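The `evaluate()` change simply stops passing `total=` to tqdm, so the progress bar works for lazy loaders at the cost of losing the total for sized ones. One possible alternative, sketched here purely as an illustration and not as what the PR does, is to pass the total only when the loader can report one:

```python
from tqdm import tqdm


def evaluation_progress(data_loader):
    """Alternative sketch (not what the PR does): keep the progress-bar total
    when the loader can report a length, and omit it for lazy loaders."""
    try:
        total = len(data_loader)
    except TypeError:
        total = None  # lazy / iterable-style loader: length unknown
    return tqdm(iter(data_loader), total=total)


for _ in evaluation_progress(range(100)):        # sized: bar shows x/100
    pass

for _ in evaluation_progress(iter(range(100))):  # unsized: bar has no total
    pass
```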