From 7cfaab475ccb39d6d59a7c9ce1d7c4d85cb5ae5c Mon Sep 17 00:00:00 2001 From: Evan Pete Walsh Date: Fri, 6 Sep 2019 12:13:45 -0700 Subject: [PATCH] Add ERROR callback event (#2983) * add TRAINING_FAIL event * address comments * remove unnecessary diff * oops, revert my messing around * been a long day * use decorator + add test * replace decorator with sys.excepthook * another attempt * pylint * pylint * pylint and mypy * simplify * remove unused EVENTS --- .../tests/training/callback_trainer_test.py | 50 ++++++++++++++++++- allennlp/training/callback_trainer.py | 19 ++++++- allennlp/training/callbacks/events.py | 3 ++ 3 files changed, 68 insertions(+), 4 deletions(-) diff --git a/allennlp/tests/training/callback_trainer_test.py b/allennlp/tests/training/callback_trainer_test.py index a7e8269b034..a5d8db93b7a 100644 --- a/allennlp/tests/training/callback_trainer_test.py +++ b/allennlp/tests/training/callback_trainer_test.py @@ -5,7 +5,7 @@ import os import re import time -from typing import Dict, Iterable +from typing import Dict, Iterable, Optional import torch import responses @@ -26,7 +26,7 @@ from allennlp.models.model import Model from allennlp.training.callback_trainer import CallbackTrainer from allennlp.training.callbacks import ( - Events, + Events, Callback, handle_event, LogToTensorboard, Checkpoint, Validate, PostToUrl, GradientNormAndClip, UpdateLearningRate, UpdateMomentum, TrackMetrics, UpdateMovingAverage ) @@ -912,3 +912,49 @@ def test_restored_training_returns_best_epoch_metrics_even_if_no_better_epoch_is assert training_metrics["best_validation_loss"] == restored_metrics["best_validation_loss"] assert training_metrics["best_epoch"] == 0 assert training_metrics["validation_loss"] > restored_metrics["validation_loss"] + + def test_handle_errors(self): + # pylint: disable=unused-argument,no-self-use + class ErrorTest(Callback): + """ + A callback with three triggers + * at BATCH_START, it raises a RuntimeError + * at TRAINING_END, it sets a finished flag to True + * at ERROR, it captures `trainer.exception` + """ + def __init__(self) -> None: + self.exc: Optional[Exception] = None + self.finished_training = None + + @handle_event(Events.BATCH_START) + def raise_exception(self, trainer): + raise RuntimeError("problem starting batch") + + @handle_event(Events.TRAINING_END) + def finish_training(self, trainer): + self.finished_training = True + + @handle_event(Events.ERROR) + def capture_error(self, trainer): + self.exc = trainer.exception + + error_test = ErrorTest() + callbacks = self.default_callbacks() + [error_test] + + original_trainer = CallbackTrainer(self.model, + self.instances, + self.iterator, + self.optimizer, + callbacks=callbacks, + num_epochs=1, serialization_dir=self.TEST_DIR) + + with pytest.raises(RuntimeError): + # pylint: disable=not-callable + original_trainer.train() + + # The callback should have captured the exception. + assert error_test.exc is not None + assert error_test.exc.args == ("problem starting batch",) + + # The "finished" flag should never have been set to True. + assert not error_test.finished_training diff --git a/allennlp/training/callback_trainer.py b/allennlp/training/callback_trainer.py index 68e03872ecc..49797f10c77 100644 --- a/allennlp/training/callback_trainer.py +++ b/allennlp/training/callback_trainer.py @@ -5,6 +5,7 @@ import logging import time import datetime +import functools import math from typing import Dict, Optional, List, Union, Any, Iterable import torch @@ -27,6 +28,18 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name +def handle_errors(method): + @functools.wraps(method) + def train_and_handle_errors(self: 'CallbackTrainer') -> Dict[str, Any]: + try: + return method(self) + except Exception as exc: + self.exception = exc + self.handler.fire_event(Events.ERROR) + raise + + return train_and_handle_errors + @TrainerBase.register("callback") class CallbackTrainer(TrainerBase): @@ -122,6 +135,9 @@ def __init__(self, self.shuffle = shuffle self.handler = CallbackHandler(callbacks, self) + # For capturing errors that occur during the train loop. + self.exception: Optional[Exception] = None + def generate_training_batches(self): """ Generates one epoch worth of training data. Stores it in trainer instance variables @@ -196,7 +212,6 @@ def train_one_batch_group(self, batch_group: List[TensorDict]) -> str: return training_util.description_from_metrics(self.train_metrics) - def train_one_epoch(self) -> None: """ Trains the model for a single epoch. @@ -223,7 +238,7 @@ def train_one_epoch(self) -> None: self.handler.fire_event(Events.VALIDATE) self.handler.fire_event(Events.EPOCH_END) - + @handle_errors def train(self) -> Dict[str, Any]: """ Trains the supplied model with the supplied parameters. diff --git a/allennlp/training/callbacks/events.py b/allennlp/training/callbacks/events.py index a42cd0109db..293f0c7e058 100644 --- a/allennlp/training/callbacks/events.py +++ b/allennlp/training/callbacks/events.py @@ -6,6 +6,7 @@ class Events: BATCH_START = "BATCH_START" FORWARD = "FORWARD" + BACKWARD = "BACKWARD" BATCH_END = "BATCH_END" @@ -15,3 +16,5 @@ class Events: EPOCH_END = "EPOCH_END" TRAINING_END = "TRAINING_END" + + ERROR = "ERROR"