-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNet-1340][Fit API]Update train stats #14494
Changes from all commits
d71eba9
a8c2c7f
29c68b4
ad4041b
b8ec43c
904cdc7
c6fe873
6e690ca
55b102e
feac6e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
import copy | ||
import warnings | ||
|
||
from .event_handler import LoggingHandler | ||
from .event_handler import EventHandler, LoggingHandler | ||
from ... import gluon, autograd | ||
from ...context import Context, cpu, gpu, num_gpus | ||
from ...io import DataIter | ||
|
@@ -39,27 +39,26 @@ class Estimator(object): | |
Parameters | ||
---------- | ||
loss : Loss or list of Loss | ||
loss : gluon.loss.Loss or list of gluon.loss.Loss | ||
Loss(objective functions) to calculate during training | ||
metrics : EvalMetric or list of EvalMetric | ||
Metrics for evaluating models | ||
initializer : Initializer | ||
initializer to initialize the network | ||
trainers : Trainer or list of Trainer | ||
Trainers to apply optimizers on network parameters | ||
trainer : Trainer | ||
Trainer to apply optimizer on network parameters | ||
context : Context or list of Context | ||
devices to run the training on | ||
""" | ||
|
||
def __init__(self, net, | ||
loss=None, | ||
loss, | ||
metrics=None, | ||
initializer=None, | ||
trainers=None, | ||
trainer=None, | ||
context=None): | ||
|
||
self.net = net | ||
self.stop_training = False | ||
|
||
if isinstance(loss, gluon.loss.Loss): | ||
self.loss = [loss] | ||
|
@@ -86,27 +85,14 @@ def __init__(self, net, | |
|
||
# store training statistics | ||
self.train_stats = {} | ||
self.train_stats['epochs'] = [] | ||
self.train_stats['learning_rate'] = [] | ||
# current step of the epoch | ||
self.train_stats['step'] = '' | ||
for metric in self.train_metrics: | ||
# record a history of metrics over each epoch | ||
self.train_stats['train_' + metric.name] = [] | ||
# only record the latest metric numbers after each batch | ||
self.train_stats['batch_' + metric.name] = 0. | ||
for metric in self.val_metrics: | ||
self.train_stats['val_' + metric.name] = [] | ||
|
||
# separate train and validation | ||
self.train_loss_metrics = [] | ||
self.val_loss_metrics = [] | ||
# using the metric wrapper for loss to record loss value | ||
for l in self.loss: | ||
self.train_loss_metrics.append(Loss(l.name)) | ||
self.val_loss_metrics.append(Loss(l.name)) | ||
self.train_stats['train_' + l.name] = [] | ||
self.train_stats['val_' + l.name] = [] | ||
# only record the latest loss numbers after each batch | ||
self.train_stats['batch_' + l.name] = 0. | ||
|
||
# handle context | ||
if isinstance(context, Context): | ||
|
@@ -127,15 +113,14 @@ def __init__(self, net, | |
raise ValueError("context must be a Context or a list of Context, " | ||
"refer to mxnet.Context:{}".format(context)) | ||
|
||
|
||
# initialize the network | ||
self.initializer = initializer | ||
if self.initializer: | ||
if self._is_initialized(): | ||
# if already initialized, re-init with user specified initializer | ||
warnings.warn("Network already initialized, re-initializing with %s. " | ||
"You don't need to pass initializer if you already " | ||
"initialized your net."% type(self.initializer).__name__) | ||
"initialized your net." % type(self.initializer).__name__) | ||
self.net.initialize(init=self.initializer, ctx=self.context, force_reinit=True) | ||
else: | ||
# initialize with user specified initializer | ||
|
@@ -144,16 +129,17 @@ def __init__(self, net, | |
if not self._is_initialized(): | ||
self.net.initialize(ctx=self.context) | ||
|
||
# handle trainers | ||
if isinstance(trainers, gluon.Trainer): | ||
self.trainers = [trainers] | ||
elif not trainers: | ||
# handle trainer | ||
if not trainer: | ||
warnings.warn("No trainer specified, default SGD optimizer " | ||
"with learning rate 0.001 is used.") | ||
self.trainers = [gluon.Trainer(self.net.collect_params(), | ||
'sgd', {'learning_rate': 0.001})] | ||
self.trainer = gluon.Trainer(self.net.collect_params(), | ||
'sgd', {'learning_rate': 0.001}) | ||
elif not isinstance(trainer, gluon.Trainer): | ||
raise ValueError("Trainer must be a Gluon Trainer instance, refer to " | ||
"gluon.Trainer:{}".format(trainer)) | ||
else: | ||
raise ValueError("Invalid trainer specified, please provide a valid gluon.Trainer") | ||
self.trainer = trainer | ||
|
||
def _is_initialized(self): | ||
param_dict = self.net.collect_params() | ||
|
@@ -212,8 +198,12 @@ def evaluate(self, | |
# update metrics | ||
for metric in self.val_metrics: | ||
metric.update(label, pred) | ||
name, value = metric.get() | ||
self.train_stats['val_' + name] = value | ||
for loss, loss_metric, in zip(losses, self.val_loss_metrics): | ||
loss_metric.update(0, [l for l in loss]) | ||
name, value = loss_metric.get() | ||
self.train_stats['val_' + name] = value | ||
|
||
def fit(self, train_data, | ||
val_data=None, | ||
|
@@ -241,27 +231,38 @@ def fit(self, train_data, | |
from a data batch and load into contexts(devices) | ||
""" | ||
|
||
|
||
self.epochs = epochs | ||
self.max_epoch = epochs | ||
if not batch_size: | ||
batch_size = 32 * len(self.context) | ||
self.batch_size = 32 * len(self.context) | ||
else: | ||
self.batch_size = batch_size | ||
self.stop_training = False | ||
self.samples = None | ||
self.batch_idx = 0 | ||
|
||
event_handlers = event_handlers or [] | ||
# provide default logging handler | ||
if not event_handlers or \ | ||
not any(isinstance(handler, LoggingHandler) for handler in event_handlers): | ||
event_handlers.append(LoggingHandler(self)) | ||
event_handlers.append(LoggingHandler()) | ||
|
||
# training begin | ||
train_begin, epoch_begin, batch_begin, \ | ||
batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) | ||
|
||
# passing estimator to event handlers so they can access estimator information | ||
# when a event is triggered | ||
for handler in event_handlers: | ||
handler.estimator = self | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @nswamy This avoids asking the user to pass the estimator during event handler construction; reference: #14462 (comment) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am wondering how the user of a handler will know that an estimator will be initialized here? Also, can you add a setter and getter for the estimator in Handler, and not call handler.setEstimator(e) if handler.getEstimator() is not None? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @nswamy When the user calls est.fit(xxx, event_handlers=XXX), this already associates the event handlers with an estimator instance. I'm just helping the user pass this estimator so they don't need to do so during event handler construction. |
||
|
||
# training begin | ||
for handler in train_begin: | ||
handler.train_begin() | ||
|
||
for epoch in range(epochs): | ||
for epoch in range(self.max_epoch): | ||
# epoch begin | ||
self.train_stats['epochs'].append(epoch) | ||
self.train_stats['learning_rate'].append(self.trainers[0].learning_rate) | ||
self.current_epoch = epoch | ||
|
||
for handler in event_handlers: | ||
for handler in epoch_begin: | ||
handler.epoch_begin() | ||
|
||
for metric in self.train_metrics + self.train_loss_metrics: | ||
|
@@ -282,7 +283,7 @@ def fit(self, train_data, | |
data, label = batch_fn(batch, self.context) | ||
|
||
# batch begin | ||
for handler in event_handlers: | ||
for handler in batch_begin: | ||
handler.batch_begin() | ||
|
||
with autograd.record(): | ||
|
@@ -298,42 +299,64 @@ def fit(self, train_data, | |
# update train metrics | ||
for metric in self.train_metrics: | ||
metric.update(label, pred) | ||
self.train_stats['batch_' + metric.name] = metric.get()[1] | ||
# get metric name and current value and update train stats | ||
name, value = metric.get() | ||
self.train_stats['train_' + name] = value | ||
|
||
# update loss | ||
for loss, loss_metric, in zip(losses, self.train_loss_metrics): | ||
loss_metric.update(0, [l for l in loss]) | ||
self.train_stats['batch_' + loss_metric.name] = loss_metric.get()[1] | ||
|
||
try: | ||
completed_samples = len(train_data._dataset) if i == len(train_data._dataset) - 1 \ | ||
else batch_size * (i + 1) | ||
# We need to check if this is the last batch in the current epoch and select | ||
# the value to print appropriately | ||
self.train_stats['step'] = "{}/{}".format(completed_samples, len(train_data._dataset)) | ||
except AttributeError: | ||
self.train_stats['step'] = i | ||
name, value = loss_metric.get() | ||
self.train_stats['train_' + name] = value | ||
|
||
for trainer in self.trainers: | ||
trainer.step(batch_size) | ||
self.batch_idx = i | ||
# record trained samples v.s. total samples if using Gluon DataLoader | ||
if isinstance(train_data, gluon.data.DataLoader): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You might need to rebase this line with the fit-api branch again. |
||
self.samples = "{}/{}".format(self.batch_size * (i + 1), len(train_data._dataset)) | ||
|
||
self.trainer.step(self.batch_size) | ||
# batch end | ||
for handler in event_handlers: | ||
for handler in batch_end: | ||
handler.batch_end() | ||
|
||
if val_data: | ||
self.evaluate(val_data, batch_fn) | ||
|
||
for metric in self.train_metrics + self.train_loss_metrics: | ||
self.train_stats['train_' + metric.name].append(metric.get()[1]) | ||
for metric in self.val_metrics + self.val_loss_metrics: | ||
self.train_stats['val_' + metric.name].append(metric.get()[1]) | ||
|
||
# epoch end | ||
for handler in event_handlers: | ||
for handler in epoch_end: | ||
handler.epoch_end() | ||
|
||
if self.stop_training: | ||
break | ||
|
||
# train end | ||
for handler in event_handlers: | ||
for handler in train_end: | ||
handler.train_end() | ||
|
||
def _categorize_handlers(self, event_handlers):
    """
    Sort event handlers into one list per training event.

    A handler is placed in an event's list only when its class exposes a
    different method for that event than ``EventHandler`` does, so the
    training loop never invokes an inherited no-op hook.

    Parameters
    ----------
    event_handlers : iterable of handler objects
        Handlers to inspect; each handler's class must expose all six
        event methods.

    Returns
    -------
    tuple of list
        ``(train_begin, epoch_begin, batch_begin, batch_end, epoch_end,
        train_end)``; each list keeps the handlers in their input order.
    """
    # order matters: it fixes both the check order and the return order
    event_names = ('train_begin', 'epoch_begin', 'batch_begin',
                   'batch_end', 'epoch_end', 'train_end')
    buckets = {name: [] for name in event_names}

    for handler in event_handlers:
        handler_cls = handler.__class__
        for name in event_names:
            # equal to the base-class attribute means the hook was not overridden
            inherited = getattr(handler_cls, name) == getattr(EventHandler, name)
            if not inherited:
                buckets[name].append(handler)

    return tuple(buckets[name] for name in event_names)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to set self._estimator = None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the Estimator class; shouldn't only event handlers have self._estimator?