Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Fix ShouldValidateCallback (#5536)
Browse files Browse the repository at this point in the history
* Fix bug in should validate callback

* Add test for should validate callback

* Update changelog

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
  • Loading branch information
JohnGiorgi and dirkgr authored Jan 12, 2022
1 parent b0b3ad4 commit a711703
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed a spurious error message "'torch.cuda' has no attribute '_check_driver'" that would be appear in the logs
when a `ConfigurationError` for missing GPU was raised.
- Load model on CPU post training to save GPU memory.
- Fixed a bug in `ShouldValidateCallback` that leads to valuation occuring after the first epoch regardless of `validation_start` value.
- Fixed a bug in `ShouldValidateCallback` that leads to valuation occuring every `validation_interval + 1` epochs, instead of every `validation_interval` epochs.

### Removed

Expand Down
16 changes: 12 additions & 4 deletions allennlp/training/callbacks/should_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def __init__(
self._validation_start = validation_start
self._validation_interval = validation_interval

def on_start(
self, trainer: "GradientDescentTrainer", is_primary: bool = True, **kwargs
) -> None:
trainer._should_validate_this_epoch = self._should_validate(epoch=0)

def on_epoch(
self,
trainer: "GradientDescentTrainer",
Expand All @@ -33,9 +38,12 @@ def on_epoch(
is_primary: bool = True,
**kwargs,
) -> None:
trainer._should_validate_this_epoch = self._should_validate(epoch=epoch + 1)

def _should_validate(self, epoch: int) -> bool:
should_validate = True
if self._validation_start is not None and epoch < self._validation_start:
trainer._should_validate_this_epoch = False
should_validate = False
elif self._validation_interval is not None and epoch % self._validation_interval != 0:
trainer._should_validate_this_epoch = False
else:
trainer._should_validate_this_epoch = True
should_validate = False
return should_validate
10 changes: 7 additions & 3 deletions tests/training/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,16 +1361,20 @@ def test_should_validate_callback(self):
)
trainer.train()

# Doesn't satisfy 'validation_start' or 'validation_interval'
# Shouldn't validate on the first epoch as it's before the 'validation_start'
callback.on_start(trainer)
assert not trainer._should_validate_this_epoch

# Satisfies 'validation_interval' but not 'validation_start'
callback.on_epoch(trainer, metrics={}, epoch=1)
assert not trainer._should_validate_this_epoch

# Satisfies 'validation_start' but not 'validation_interval'
# Doesn't satisfy 'validation_start' or 'validation_interval'
callback.on_epoch(trainer, metrics={}, epoch=2)
assert not trainer._should_validate_this_epoch

# Satisfies both 'validation_start' and 'validation_interval'
callback.on_epoch(trainer, metrics={}, epoch=4)
callback.on_epoch(trainer, metrics={}, epoch=5)
assert trainer._should_validate_this_epoch


Expand Down

0 comments on commit a711703

Please sign in to comment.