This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit: Add ERROR callback event (#2983)
* add TRAINING_FAIL event

* address comments

* remove unnecessary diff

* oops, revert my messing around

* been a long day

* use decorator + add test

* replace decorator with sys.excepthook

* another attempt

* pylint

* pylint

* pylint and mypy

* simplify

* remove unused EVENTS
epwalsh authored and schmmd committed Sep 6, 2019
commit 7cfaab4 (1 parent: ce50407)
Showing 3 changed files with 68 additions and 4 deletions.
allennlp/tests/training/callback_trainer_test.py (48 additions, 2 deletions)
@@ -5,7 +5,7 @@
 import os
 import re
 import time
-from typing import Dict, Iterable
+from typing import Dict, Iterable, Optional

 import torch
 import responses
@@ -26,7 +26,7 @@
 from allennlp.models.model import Model
 from allennlp.training.callback_trainer import CallbackTrainer
 from allennlp.training.callbacks import (
-    Events,
+    Events, Callback, handle_event,
     LogToTensorboard, Checkpoint, Validate, PostToUrl, GradientNormAndClip,
     UpdateLearningRate, UpdateMomentum, TrackMetrics, UpdateMovingAverage
 )
@@ -912,3 +912,49 @@ def test_restored_training_returns_best_epoch_metrics_even_if_no_better_epoch_is
         assert training_metrics["best_validation_loss"] == restored_metrics["best_validation_loss"]
         assert training_metrics["best_epoch"] == 0
         assert training_metrics["validation_loss"] > restored_metrics["validation_loss"]
+
+    def test_handle_errors(self):
+        # pylint: disable=unused-argument,no-self-use
+        class ErrorTest(Callback):
+            """
+            A callback with three triggers
+            * at BATCH_START, it raises a RuntimeError
+            * at TRAINING_END, it sets a finished flag to True
+            * at ERROR, it captures `trainer.exception`
+            """
+            def __init__(self) -> None:
+                self.exc: Optional[Exception] = None
+                self.finished_training = None
+
+            @handle_event(Events.BATCH_START)
+            def raise_exception(self, trainer):
+                raise RuntimeError("problem starting batch")
+
+            @handle_event(Events.TRAINING_END)
+            def finish_training(self, trainer):
+                self.finished_training = True
+
+            @handle_event(Events.ERROR)
+            def capture_error(self, trainer):
+                self.exc = trainer.exception
+
+        error_test = ErrorTest()
+        callbacks = self.default_callbacks() + [error_test]
+
+        original_trainer = CallbackTrainer(self.model,
+                                           self.instances,
+                                           self.iterator,
+                                           self.optimizer,
+                                           callbacks=callbacks,
+                                           num_epochs=1, serialization_dir=self.TEST_DIR)
+
+        with pytest.raises(RuntimeError):
+            # pylint: disable=not-callable
+            original_trainer.train()
+
+        # The callback should have captured the exception.
+        assert error_test.exc is not None
+        assert error_test.exc.args == ("problem starting batch",)
+
+        # The "finished" flag should never have been set to True.
+        assert not error_test.finished_training
allennlp/training/callback_trainer.py (17 additions, 2 deletions)
@@ -5,6 +5,7 @@
 import logging
 import time
 import datetime
+import functools
 import math
 from typing import Dict, Optional, List, Union, Any, Iterable
 import torch
@@ -27,6 +28,18 @@

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

+def handle_errors(method):
+    @functools.wraps(method)
+    def train_and_handle_errors(self: 'CallbackTrainer') -> Dict[str, Any]:
+        try:
+            return method(self)
+        except Exception as exc:
+            self.exception = exc
+            self.handler.fire_event(Events.ERROR)
+            raise
+
+    return train_and_handle_errors
+

 @TrainerBase.register("callback")
 class CallbackTrainer(TrainerBase):
@@ -122,6 +135,9 @@ def __init__(self,
         self.shuffle = shuffle
         self.handler = CallbackHandler(callbacks, self)

+        # For capturing errors that occur during the train loop.
+        self.exception: Optional[Exception] = None
+
     def generate_training_batches(self):
         """
         Generates one epoch worth of training data. Stores it in trainer instance variables
@@ -196,7 +212,6 @@ def train_one_batch_group(self, batch_group: List[TensorDict]) -> str:

         return training_util.description_from_metrics(self.train_metrics)

-
     def train_one_epoch(self) -> None:
         """
         Trains the model for a single epoch.
@@ -223,7 +238,7 @@ def train_one_epoch(self) -> None:
         self.handler.fire_event(Events.VALIDATE)
         self.handler.fire_event(Events.EPOCH_END)

-
+    @handle_errors
     def train(self) -> Dict[str, Any]:
         """
         Trains the supplied model with the supplied parameters.
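The heart of the change is the `handle_errors` decorator above: `train()` is wrapped so that any exception is stored on `self.exception`, the ERROR event is fired, and the original exception is then re-raised. Below is a minimal, self-contained sketch of the same wrap-and-re-raise pattern outside AllenNLP; `ToyTrainer` and `ToyHandler` are illustrative stand-ins only (not AllenNLP classes), and note that the real `CallbackHandler.fire_event` takes just the event name.

```python
# Minimal sketch of the wrap-and-re-raise pattern from the diff above.
# ToyTrainer/ToyHandler are hypothetical stand-ins, not AllenNLP classes.
import functools
from typing import Any, Callable, Dict, List, Optional


class ToyHandler:
    """Fires a named event on every registered callback (stand-in for CallbackHandler)."""
    def __init__(self, callbacks: List[Callable[[str, "ToyTrainer"], None]]) -> None:
        self.callbacks = callbacks

    def fire_event(self, event: str, trainer: "ToyTrainer") -> None:
        for callback in self.callbacks:
            callback(event, trainer)


def handle_errors(method: Callable[["ToyTrainer"], Dict[str, Any]]) -> Callable[["ToyTrainer"], Dict[str, Any]]:
    @functools.wraps(method)
    def wrapper(self: "ToyTrainer") -> Dict[str, Any]:
        try:
            return method(self)
        except Exception as exc:
            self.exception = exc                    # stash the error for callbacks to inspect
            self.handler.fire_event("ERROR", self)  # let callbacks react first ...
            raise                                   # ... then propagate the original exception
    return wrapper


class ToyTrainer:
    def __init__(self, callbacks: List[Callable[[str, "ToyTrainer"], None]]) -> None:
        self.handler = ToyHandler(callbacks)
        self.exception: Optional[Exception] = None

    @handle_errors
    def train(self) -> Dict[str, Any]:
        raise RuntimeError("boom")                  # simulate a failure mid-training


if __name__ == "__main__":
    seen = []
    trainer = ToyTrainer([lambda event, t: seen.append((event, t.exception))])
    try:
        trainer.train()
    except RuntimeError:
        pass
    assert seen == [("ERROR", trainer.exception)]   # ERROR fired before the exception escaped
```

Re-raising after firing the event keeps the existing failure behavior (the exception still reaches the caller), while giving callbacks a chance to inspect or report the error first.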
allennlp/training/callbacks/events.py (3 additions, 0 deletions)
@@ -6,6 +6,7 @@ class Events:
     BATCH_START = "BATCH_START"

     FORWARD = "FORWARD"
+
     BACKWARD = "BACKWARD"

     BATCH_END = "BATCH_END"
@@ -15,3 +16,5 @@
     EPOCH_END = "EPOCH_END"

     TRAINING_END = "TRAINING_END"
+
+    ERROR = "ERROR"
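With `Events.ERROR` defined, any callback can subscribe to it through `handle_event`, as the new test does. As a usage sketch (not part of this commit), a hypothetical `LogOnError` callback could log the captured exception before it propagates:

```python
# Hypothetical callback built on the new ERROR event; LogOnError is not part
# of this commit, only an illustration based on the ErrorTest callback above.
import logging
from typing import Optional

from allennlp.training.callbacks import Callback, Events, handle_event

logger = logging.getLogger(__name__)


class LogOnError(Callback):
    def __init__(self) -> None:
        self.exception: Optional[Exception] = None

    @handle_event(Events.ERROR)
    def log_error(self, trainer) -> None:
        # trainer.exception is populated by the handle_errors wrapper before
        # Events.ERROR is fired, so it is safe to read here.
        self.exception = trainer.exception
        logger.error("Training failed: %r", trainer.exception)
```

Such a callback would be passed to `CallbackTrainer` through the `callbacks` argument, exactly like `ErrorTest` in the test above.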
