From 03e059fd31301d2fdf1af527e42a9fa9bb177097 Mon Sep 17 00:00:00 2001
From: Lai Wei
Date: Fri, 15 Mar 2019 22:16:42 -0700
Subject: [PATCH] [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346)

* base class for estimator and eventhandler

* add license

* add event handlers

* fix pylint

* improve arg check

* fix pylint

* add unit tests
---
 python/mxnet/gluon/estimator/__init__.py      |  21 ++
 python/mxnet/gluon/estimator/estimator.py     | 267 +++++++++++++++
 python/mxnet/gluon/estimator/event_handler.py | 307 ++++++++++++++++++
 .../unittest/test_gluon_event_handler.py      |  92 ++++++
 4 files changed, 687 insertions(+)
 create mode 100644 python/mxnet/gluon/estimator/__init__.py
 create mode 100644 python/mxnet/gluon/estimator/estimator.py
 create mode 100644 python/mxnet/gluon/estimator/event_handler.py
 create mode 100644 tests/python/unittest/test_gluon_event_handler.py

diff --git a/python/mxnet/gluon/estimator/__init__.py b/python/mxnet/gluon/estimator/__init__.py
new file mode 100644
index 000000000000..58600dadffb4
--- /dev/null
+++ b/python/mxnet/gluon/estimator/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=wildcard-import
+"""Gluon Estimator Module"""
+from .estimator import *
+from .event_handler import *
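
For reference, the two wildcard imports above re-export everything this patch defines, so user code can import directly from the package. A minimal sketch, assuming the classes introduced later in this patch:

    from mxnet.gluon.estimator import Estimator
    from mxnet.gluon.estimator import CheckpointHandler, EarlyStoppingHandler
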
diff --git a/python/mxnet/gluon/estimator/estimator.py b/python/mxnet/gluon/estimator/estimator.py
new file mode 100644
index 000000000000..159f7e220427
--- /dev/null
+++ b/python/mxnet/gluon/estimator/estimator.py
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Gluon Estimator"""
+
+import warnings
+
+from .event_handler import LoggingHandler
+from ... import gluon, autograd
+from ...context import Context, cpu, gpu, num_gpus
+from ...io import DataIter
+from ...metric import EvalMetric, Loss
+
+__all__ = ['Estimator']
+
+
+class Estimator(object):
+    """Estimator Class for easy model training
+
+    :py:class:`Estimator` can be used to facilitate the training & validation process
+
+
+    Parameters
+    ----------
+    net : Block
+        The network to train
+    loss : Loss or list of Loss
+        Loss (objective) functions to calculate during training
+    metrics : EvalMetric or list of EvalMetric
+        Metrics for evaluating models
+    initializer : Initializer
+        Initializer to initialize the network
+    trainers : Trainer or list of Trainer
+        Trainers to apply optimizers on network parameters
+    context : Context or list of Context
+        Devices to run the training on
+    """
+
+    def __init__(self, net,
+                 loss=None,
+                 metrics=None,
+                 initializer=None,
+                 trainers=None,
+                 context=None):
+
+        self.net = net
+        self.stop_training = False
+
+        if isinstance(loss, gluon.loss.Loss):
+            self.loss = [loss]
+        else:
+            self.loss = loss or []
+        for l in self.loss:
+            if not isinstance(l, gluon.loss.Loss):
+                raise ValueError("loss must be a Loss or a list of Loss, refer to gluon.loss.Loss")
+
+        if isinstance(metrics, EvalMetric):
+            self.metrics = [metrics]
+        else:
+            self.metrics = metrics or []
+        for metric in self.metrics:
+            if not isinstance(metric, EvalMetric):
+                raise ValueError("metrics must be an EvalMetric or a list of EvalMetric, "
+                                 "refer to mxnet.metric.EvalMetric")
+
+        self.initializer = initializer
+        # store training statistics
+        self.train_stats = {}
+        self.train_stats['epochs'] = []
+        self.train_stats['learning_rate'] = []
+        # current step of the epoch
+        self.train_stats['step'] = ''
+        for metric in self.metrics:
+            # record a history of metrics over each epoch
+            self.train_stats['train_' + metric.name] = []
+            # only record the latest metric numbers after each batch
+            self.train_stats['batch_' + metric.name] = 0.
+        self.loss_metrics = []
+        # using the metric wrapper for loss to record loss value
+        for l in self.loss:
+            self.loss_metrics.append(Loss(l.name))
+            self.train_stats['train_' + l.name] = []
+            # only record the latest loss numbers after each batch
+            self.train_stats['batch_' + l.name] = 0.
+
+        # handle context
+        if isinstance(context, Context):
+            self.context = [context]
+        elif isinstance(context, list) and all([isinstance(c, Context) for c in context]):
+            self.context = context
+        elif not context:
+            if num_gpus() > 0:
+                # only use 1 GPU by default
+                if num_gpus() > 1:
+                    warnings.warn("You have multiple GPUs, gpu(0) will be used by default. "
+                                  "To utilize all your GPUs, specify context as a list of gpus, "
+                                  "e.g. context=[mx.gpu(0), mx.gpu(1)]")
+                self.context = [gpu(0)]
+            else:
+                self.context = [cpu()]
+        else:
+            raise ValueError("context must be a Context or a list of Context, "
+                             "e.g. mx.cpu() or [mx.gpu(0), mx.gpu(1)]")
+
+        # initialize the network
+        if self.initializer:
+            if self._is_initialized():
+                # if already initialized, re-init with user specified initializer
+                warnings.warn("Network already initialized, re-initializing with %s. "
+                              "You don't need to pass initializer if you already "
+                              "initialized your net." % type(self.initializer).__name__)
+                self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True)
+            else:
+                # initialize with user specified initializer
+                self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=False)
+        else:
+            if not self._is_initialized():
+                self.net.initialize(ctx=self.context)
+
+        # handle trainers
+        if isinstance(trainers, gluon.Trainer):
+            self.trainers = [trainers]
+        else:
+            self.trainers = trainers or []
+        if not self.trainers:
+            warnings.warn("No trainer specified, default SGD optimizer "
+                          "with learning rate 0.001 is used.")
+            self.trainers = [gluon.Trainer(self.net.collect_params(),
+                                           'sgd', {'learning_rate': 0.001})]
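
Before the training loop, a minimal sketch of how this constructor is meant to be called; the toy network, loss, metric, and context below are illustrative, not part of this patch:

    import mxnet as mx
    from mxnet.gluon import nn, loss
    from mxnet.gluon.estimator import Estimator

    net = nn.Dense(10)
    est = Estimator(net,
                    loss=loss.SoftmaxCrossEntropyLoss(),
                    metrics=mx.metric.Accuracy(),
                    context=mx.cpu())
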
" + "You don't need to pass initializer if you already " + "initialized your net."% type(self.initializer).__name__) + self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True) + else: + # initialize with user specified initializer + self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=False) + else: + if not self._is_initialized(): + self.net.initialize(ctx=self.context) + + # handle trainers + if isinstance(trainers, gluon.Trainer): + self.trainers = [trainers] + else: + self.trainers = trainers or [] + if not self.trainers: + warnings.warn("No trainer specified, default SGD optimizer " + "with learning rate 0.001 is used.") + self.trainers = [gluon.Trainer(self.net.collect_params(), + 'sgd', {'learning_rate': 0.001})] + + def _is_initialized(self): + param_dict = self.net.collect_params() + for param in param_dict: + try: + param_dict[param].list_ctx() + except RuntimeError: + return False + return True + + def _batch_fn(self, batch, ctx, is_iterator=False): + if is_iterator: + data = batch.data[0] + label = batch.label[0] + else: + data = batch[0] + label = batch[1] + data = gluon.utils.split_and_load(data, ctx_list=ctx, batch_axis=0) + label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0) + return data, label + + def fit(self, train_data, + epochs=1, + batch_size=None, + event_handlers=None, + batch_fn=None): + """Main training loop + + Parameters + ---------- + train_data : DataLoader or DataIter + training data with data and labels + val_data : DataLoader or DataIter + validation data with data and labels + epochs : int, default 1 + number of epochs to iterate on the training data. + batch_size : int + number of samples per gradient update. + default will be 32 per device + event_handlers : EventHandler or list of EventHandler + list of EventHandlers to apply during training + batch_fn : function + custom batch function to extract data and label + from a data batch and load into contexts(devices) + """ + + + self.epochs = epochs + if not batch_size: + batch_size = 32 * len(self.context) + + event_handlers = event_handlers or [] + # provide default logging handler + if not event_handlers or \ + not any(isinstance(handler, LoggingHandler) for handler in event_handlers): + event_handlers.append(LoggingHandler(self)) + + # training begin + for handler in event_handlers: + handler.train_begin() + + for epoch in range(epochs): + # epoch begin + self.train_stats['epochs'].append(epoch) + self.train_stats['learning_rate'].append(self.trainers[0].learning_rate) + + for handler in event_handlers: + handler.epoch_begin() + + for metric in self.metrics + self.loss_metrics: + metric.reset() + + for i, batch in enumerate(train_data): + if not batch_fn: + if isinstance(train_data, gluon.data.DataLoader): + data, label = self._batch_fn(batch, self.context) + elif isinstance(train_data, DataIter): + data, label = self._batch_fn(batch, self.context, is_iterator=True) + else: + raise ValueError("You are using a custom iteration, please also provide " + "batch_fn to extract data and label") + else: + data, label = batch_fn(batch, self.context) + + # batch begin + for handler in event_handlers: + handler.batch_begin() + + with autograd.record(): + pred = [self.net(x) for x in data] + losses = [] + for loss in self.loss: + losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)]) + + for loss in losses: + for l in loss: + l.backward() + + # update metrics + for metric in self.metrics: + metric.update(label, pred) + 
diff --git a/python/mxnet/gluon/estimator/event_handler.py b/python/mxnet/gluon/estimator/event_handler.py
new file mode 100644
index 000000000000..0162c36993f3
--- /dev/null
+++ b/python/mxnet/gluon/estimator/event_handler.py
@@ -0,0 +1,307 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Gluon EventHandlers for Estimators"""
+
+__all__ = ['EventHandler', 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
+import logging
+import os
+import time
+import warnings
+
+import numpy as np
+
+
+class EventHandler(object):
+    """Base class for event handlers
+
+    :py:class:`EventHandler` can perform user defined functions at
+    different stages of training: train begin, epoch begin, batch begin,
+    batch end, epoch end, train end.
+
+    Parameters
+    ----------
+    estimator : Estimator
+        The :py:class:`Estimator` to get training statistics
+    """
+    def __init__(self, estimator):
+        self._estimator = estimator
+
+    def train_begin(self):
+        pass
+
+    def train_end(self):
+        pass
+
+    def batch_begin(self):
+        pass
+
+    def batch_end(self):
+        pass
+
+    def epoch_begin(self):
+        pass
+
+    def epoch_end(self):
+        pass
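
Subclasses override only the callbacks they need. A minimal sketch of a custom handler; the decay schedule is illustrative and relies on gluon.Trainer.set_learning_rate, which is not part of this patch:

    class LearningRateDecayHandler(EventHandler):
        """Illustrative handler: decay the learning rate after each epoch."""

        def __init__(self, estimator, decay=0.9):
            super(LearningRateDecayHandler, self).__init__(estimator)
            self.decay = decay

        def epoch_end(self):
            # shrink the learning rate of every trainer by the decay factor
            for trainer in self._estimator.trainers:
                trainer.set_learning_rate(trainer.learning_rate * self.decay)
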
+
+
+class LoggingHandler(EventHandler):
+    """Basic Logging Handler that applies to every Gluon estimator by default.
+
+    :py:class:`LoggingHandler` logs hyper-parameters, training statistics,
+    and other useful information during training
+
+    Parameters
+    ----------
+    estimator : Estimator
+        The :py:class:`Estimator` to get training statistics
+    file_name : str
+        file name to save the logs
+    file_location : str
+        file location to save the logs
+    """
+
+    def __init__(self, estimator, file_name=None, file_location=None):
+        super(LoggingHandler, self).__init__(estimator)
+        self.logger = logging.getLogger(__name__)
+        self.logger.setLevel(logging.INFO)
+        stream_handler = logging.StreamHandler()
+        self.logger.addHandler(stream_handler)
+        # save logs to file only if file name or location is specified
+        if file_name or file_location:
+            file_name = file_name or 'estimator_log'
+            file_location = file_location or './'
+            file_handler = logging.FileHandler(os.path.join(file_location, file_name))
+            self.logger.addHandler(file_handler)
+
+    def train_begin(self):
+        pass
+
+    def train_end(self):
+        pass
+
+    def batch_begin(self):
+        self.batch_start = time.time()
+
+    def batch_end(self):
+        batch_time = time.time() - self.batch_start
+        epoch = self._estimator.train_stats['epochs'][-1]
+        step = self._estimator.train_stats['step']
+        msg = '[Epoch %d] [Step %s] time/step: %.3fs ' % (epoch, step, batch_time)
+        for key in self._estimator.train_stats.keys():
+            if key.startswith('batch_'):
+                msg += key[6:] + ': ' + '%.4f ' % self._estimator.train_stats[key]
+        self.logger.info(msg)
+
+    def epoch_begin(self):
+        self.epoch_start = time.time()
+
+    def epoch_end(self):
+        epoch_time = time.time() - self.epoch_start
+        epoch = self._estimator.train_stats['epochs'][-1]
+        msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
+        for key in self._estimator.train_stats.keys():
+            if key.startswith('train_') or key.startswith('test_'):
+                msg += key + ': ' + '%.4f ' % self._estimator.train_stats[key][epoch]
+        self.logger.info(msg)
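
LoggingHandler is attached automatically by Estimator.fit, so passing one explicitly is only needed to opt into file logging. A sketch, where est and train_data are assumed from the earlier sketches and the file name and location are illustrative:

    handler = LoggingHandler(est, file_name='training.log', file_location='/tmp')
    est.fit(train_data, event_handlers=[handler], epochs=5)
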
+
+
+class CheckpointHandler(EventHandler):
+    """Save the model after every epoch.
+
+    :py:class:`CheckpointHandler` saves the network parameters every epoch
+
+    Parameters
+    ----------
+    estimator : Estimator
+        The :py:class:`Estimator` to get training statistics
+    filepath : str
+        file name to save the parameters, it can contain directories,
+        for example: ./saved_model/resnet.params
+    monitor : str
+        the metric to monitor
+    verbose : int, default 0
+        verbosity mode
+    save_best_only : bool
+        if True, only save the parameters if the monitored value improved
+    mode : str, default 'auto'
+        one of {auto, min, max}, if `save_best_only=True`, the comparison to make
+        to determine whether the monitored value has improved
+    period : int, default 1
+        interval in epochs between saving the network
+    """
+
+    def __init__(self, estimator,
+                 filepath,
+                 monitor='val_loss',
+                 verbose=0,
+                 save_best_only=False,
+                 mode='auto',
+                 period=1):
+        super(CheckpointHandler, self).__init__(estimator)
+        self.monitor = monitor
+        self.verbose = verbose
+        self.filepath = filepath
+        self.save_best_only = save_best_only
+        self.period = period
+        self.epochs_since_last_save = 0
+        self.logger = logging.getLogger(__name__)
+
+        if mode not in ['auto', 'min', 'max']:
+            warnings.warn('ModelCheckpoint mode %s is unknown, '
+                          'fallback to auto mode.' % mode,
+                          RuntimeWarning)
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+            self.best = np.Inf
+        elif mode == 'max':
+            self.monitor_op = np.greater
+            self.best = -np.Inf
+        else:
+            # use greater for accuracy and less otherwise
+            if 'acc' in self.monitor:
+                self.monitor_op = np.greater
+                self.best = -np.Inf
+            else:
+                self.monitor_op = np.less
+                self.best = np.Inf
+
+    def epoch_end(self):
+        epoch = self._estimator.train_stats['epochs'][-1]
+        # add extension for weights
+        if '.params' not in self.filepath:
+            self.filepath += '.params'
+        self.epochs_since_last_save += 1
+        if self.epochs_since_last_save >= self.period:
+            self.epochs_since_last_save = 0
+            if self.save_best_only:
+                # check if monitor exists in train_stats
+                if self.monitor not in self._estimator.train_stats:
+                    warnings.warn('Unable to find %s in training statistics, make sure '
+                                  'you are passing one of the metric names as monitor' % self.monitor,
+                                  RuntimeWarning)
+                    self._estimator.net.save_parameters(self.filepath)
+                else:
+                    current = self._estimator.train_stats[self.monitor][-1]
+                    if self.monitor_op(current, self.best):
+                        if self.verbose > 0:
+                            self.logger.info('\n[Epoch %d] %s improved from %0.5f to %0.5f,'
+                                             ' saving model to %s',
+                                             epoch, self.monitor, self.best, current, self.filepath)
+                        self.best = current
+                        self._estimator.net.save_parameters(self.filepath)
+                    else:
+                        if self.verbose > 0:
+                            self.logger.info('\n[Epoch %d] %s did not improve from %0.5f, skipping save model',
+                                             epoch, self.monitor, self.best)
+            else:
+                if self.verbose > 0:
+                    self.logger.info('\nEpoch %d: saving model to %s', epoch, self.filepath)
+                self._estimator.net.save_parameters(self.filepath)
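
A sketch of the save-best-only path above; 'train_accuracy' follows the 'train_' + metric.name convention from estimator.py, est and train_data are assumed from the earlier sketches, and the target directory is assumed to exist:

    checkpoint = CheckpointHandler(est, filepath='./net.params',
                                   monitor='train_accuracy',
                                   save_best_only=True, mode='max', verbose=1)
    est.fit(train_data, event_handlers=[checkpoint], epochs=10)
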
+
+
+class EarlyStoppingHandler(EventHandler):
+    """Early stop training if the monitored value is not improving
+
+    Parameters
+    ----------
+    estimator : Estimator
+        The :py:class:`Estimator` to get training statistics
+    monitor : str
+        the metric to monitor
+    min_delta : float, default 0
+        minimal change in the monitored value to be considered an improvement
+    patience : int, default 0
+        number of epochs to wait for improvement before terminating training
+    mode : str, default 'auto'
+        one of {auto, min, max}, the comparison to make
+        to determine whether the monitored value has improved
+    baseline : float
+        baseline value to compare the monitored value with
+    """
+
+    def __init__(self, estimator,
+                 monitor='val_loss',
+                 min_delta=0,
+                 patience=0,
+                 mode='auto',
+                 baseline=None):
+        super(EarlyStoppingHandler, self).__init__(estimator)
+
+        self.monitor = monitor
+        self.baseline = baseline
+        self.patience = patience
+        self.min_delta = min_delta
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.logger = logging.getLogger(__name__)
+
+        if mode not in ['auto', 'min', 'max']:
+            warnings.warn('EarlyStopping mode %s is unknown, '
+                          'fallback to auto mode.' % mode,
+                          RuntimeWarning)
+            mode = 'auto'
+
+        if mode == 'min':
+            self.monitor_op = np.less
+        elif mode == 'max':
+            self.monitor_op = np.greater
+        else:
+            if 'acc' in self.monitor:
+                self.monitor_op = np.greater
+            else:
+                self.monitor_op = np.less
+
+        if self.monitor_op == np.greater:
+            self.min_delta *= 1
+        else:
+            self.min_delta *= -1
+
+    def train_begin(self):
+        self.wait = 0
+        self.stopped_epoch = 0
+        if self.baseline is not None:
+            self.best = self.baseline
+        else:
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def epoch_end(self):
+        epoch = self._estimator.train_stats['epochs'][-1]
+        if self.monitor not in self._estimator.train_stats:
+            warnings.warn('Unable to find %s in training statistics, make sure '
+                          'you are passing one of the metric names as monitor' % self.monitor,
+                          RuntimeWarning)
+        else:
+            current = self._estimator.train_stats[self.monitor][-1]
+            if current is None:
+                return
+
+            if self.monitor_op(current - self.min_delta, self.best):
+                self.best = current
+                self.wait = 0
+            else:
+                self.wait += 1
+                if self.wait >= self.patience:
+                    self.stopped_epoch = epoch
+                    self._estimator.stop_training = True
+
+    def train_end(self):
+        if self.stopped_epoch > 0:
+            self.logger.info('Epoch %d: early stopping due to %s not improving', self.stopped_epoch, self.monitor)
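
A sketch of wiring the handler above into training; with patience=3 training stops three epochs after 'train_accuracy' last improved (est and train_data are assumed from the earlier sketches):

    early_stop = EarlyStoppingHandler(est, monitor='train_accuracy',
                                      patience=3, mode='max')
    est.fit(train_data, event_handlers=[early_stop], epochs=50)
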
diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py
new file mode 100644
index 000000000000..a551594d6430
--- /dev/null
+++ b/tests/python/unittest/test_gluon_event_handler.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+import mxnet as mx
+from mxnet import nd
+from mxnet.gluon import nn, loss
+from mxnet.gluon.estimator import estimator, event_handler
+
+def _get_test_network():
+    net = nn.Sequential()
+    net.add(nn.Dense(128, activation='relu', in_units=100, flatten=False),
+            nn.Dense(64, activation='relu', in_units=128),
+            nn.Dense(10, activation='relu', in_units=64))
+    return net
+
+def _get_test_data():
+    return mx.io.NDArrayIter(data=nd.ones((32, 100)), label=nd.random.randint(0, 10, (32, 1)))
+
+
+def test_checkpoint_handler():
+    tmpdir = tempfile.mkdtemp()
+    file_path = os.path.join(tmpdir, "model.params")
+    test_data = _get_test_data()
+
+    save_best_only = False
+    mode = 'auto'
+
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    checkpoint_handler = [event_handler.CheckpointHandler(est, file_path,
+                                                          save_best_only=save_best_only,
+                                                          mode=mode)]
+    est.fit(test_data, event_handlers=checkpoint_handler, epochs=1)
+    assert os.path.isfile(file_path)
+    os.remove(file_path)
+
+def test_early_stopping():
+    test_data = _get_test_data()
+
+    mode = 'max'
+    monitor = 'train_accuracy'
+    patience = 0
+
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
+                                                         patience=patience,
+                                                         mode=mode)]
+    est.fit(test_data, event_handlers=early_stopping, epochs=1)
+
+    mode = 'auto'
+    monitor = 'train_accuracy'
+    patience = 2
+    early_stopping = [event_handler.EarlyStoppingHandler(est, monitor,
+                                                         patience=patience,
+                                                         mode=mode)]
+    est.fit(test_data, event_handlers=early_stopping, epochs=1)
+
+def test_logging():
+    tmpdir = tempfile.mkdtemp()
+    test_data = _get_test_data()
+    file_name = 'test_log'
+    output_file = os.path.join(tmpdir, file_name)
+
+    net = _get_test_network()
+    ce_loss = loss.SoftmaxCrossEntropyLoss()
+    acc = mx.metric.Accuracy()
+    est = estimator.Estimator(net, loss=ce_loss, metrics=acc)
+    logging_handler = [event_handler.LoggingHandler(est, file_name=file_name, file_location=tmpdir)]
+    est.fit(test_data, event_handlers=logging_handler, epochs=1)
+    assert os.path.isfile(output_file)
+    os.remove(output_file)
\ No newline at end of file
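
Taken together, the tests above suggest the intended end-to-end flow. A minimal self-contained sketch of the API this patch introduces; the network, toy data, hyper-parameters, and checkpoint path are illustrative, and the checkpoint directory is assumed writable:

    import mxnet as mx
    from mxnet import nd
    from mxnet.gluon import nn, loss
    from mxnet.gluon.estimator import estimator, event_handler

    net = nn.Sequential()
    net.add(nn.Dense(10, in_units=100))
    est = estimator.Estimator(net,
                              loss=loss.SoftmaxCrossEntropyLoss(),
                              metrics=mx.metric.Accuracy())
    handlers = [event_handler.CheckpointHandler(est, './net.params',
                                                monitor='train_accuracy',
                                                save_best_only=True, mode='max'),
                event_handler.EarlyStoppingHandler(est, monitor='train_accuracy',
                                                   patience=2, mode='max')]
    # toy data: 32 samples of 100 features with random integer labels
    train_data = mx.io.NDArrayIter(data=nd.ones((32, 100)),
                                   label=nd.random.randint(0, 10, (32, 1)))
    est.fit(train_data, event_handlers=handlers, epochs=5)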