Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Fix should validate callback train end (#5542)
Browse files Browse the repository at this point in the history
* Add on_end method to ShouldValidateCallback

* Add test for _should_validate_this_epoch at end of training

* Update changelog

* Use epoch argument if provided
  • Loading branch information
JohnGiorgi authored Jan 27, 2022
1 parent 2cdb874 commit 2deacfe
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed a spurious error message "'torch.cuda' has no attribute '_check_driver'" that would be appear in the logs
when a `ConfigurationError` for missing GPU was raised.
- Load model on CPU post training to save GPU memory.
- Fixed a bug in `ShouldValidateCallback` that leads to valuation occuring after the first epoch regardless of `validation_start` value.
- Fixed a bug in `ShouldValidateCallback` that leads to valuation occuring every `validation_interval + 1` epochs, instead of every `validation_interval` epochs.
- Fixed a bug in `ShouldValidateCallback` that leads to validation occuring after the first epoch regardless of `validation_start` value.
- Fixed a bug in `ShouldValidateCallback` that leads to validation occuring every `validation_interval + 1` epochs, instead of every `validation_interval` epochs.
- Fixed a bug in `ShouldValidateCallback` that leads to validation never occuring at the end of training.

### Removed

Expand Down
11 changes: 11 additions & 0 deletions allennlp/training/callbacks/should_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def on_epoch(
) -> None:
trainer._should_validate_this_epoch = self._should_validate(epoch=epoch + 1)

def on_end(
self,
trainer: "GradientDescentTrainer",
metrics: Dict[str, Any] = None,
epoch: int = None,
is_primary: bool = True,
**kwargs,
) -> None:
epoch = epoch + 1 if epoch is not None else trainer._epochs_completed
trainer._should_validate_this_epoch = self._should_validate(epoch=epoch)

def _should_validate(self, epoch: int) -> bool:
should_validate = True
if self._validation_start is not None and epoch < self._validation_start:
Expand Down
6 changes: 5 additions & 1 deletion tests/training/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,11 @@ def test_should_validate_callback(self):
assert not trainer._should_validate_this_epoch

# Satisfies both 'validation_start' and 'validation_interval'
callback.on_epoch(trainer, metrics={}, epoch=5)
callback.on_epoch(trainer, metrics={}, epoch=3)
assert trainer._should_validate_this_epoch

# Check that final validation happens on the last epoch
callback.on_end(trainer)
assert trainer._should_validate_this_epoch


Expand Down

0 comments on commit 2deacfe

Please sign in to comment.