Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose s…
Browse files Browse the repository at this point in the history
…upport for Gluon fit() API (#14587)

* Retrieve Batch size and Logging verbose support for Gluon fit() API

* NIT changes

* Addressed review comments: shifted the batch size code to a separate method, sentence correction

* Modified unittest

* removed redundant parameter

* Resolve CI test failure

* Only support DataLoader for now; future PRs will include a DataIter-to-DataLoader converter

* Get the number of samples from the shape attribute instead of the length, for lower space complexity

* Simplified batch size retrieval code

* removed batch_size parameter from fit() method and fixed the tests

* Verbose exception handling

* Assign named constants to the verbose levels

* Modified exception message

* Resolved undefined class reference

* Addressed review comments: Modified verbose level names, docs, variable names

* Update estimator.py
  • Loading branch information
karan6181 authored and szha committed May 20, 2019
1 parent ab7039e commit d4f7744
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 81 deletions.
43 changes: 21 additions & 22 deletions python/mxnet/gluon/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@

import copy
import warnings

from .event_handler import EventHandler, LoggingHandler
from ... import gluon, autograd
from ...context import Context, cpu, gpu, num_gpus
from ...io import DataIter
from ...metric import EvalMetric, Loss, Accuracy

__all__ = ['Estimator']
Expand Down Expand Up @@ -168,7 +166,7 @@ def evaluate(self,
Parameters
----------
val_data : DataLoader or DataIter
val_data : DataLoader
validation data with data and labels
batch_fn : function
custom batch function to extract data and label
Expand All @@ -182,13 +180,10 @@ def evaluate(self,
if not batch_fn:
if isinstance(val_data, gluon.data.DataLoader):
data, label = self._batch_fn(batch, self.context)
elif isinstance(val_data, DataIter):
data, label = self._batch_fn(batch, self.context, is_iterator=True)
else:
raise ValueError("You are using a custom iteration, please also provide "
"batch_fn to extract data and label. Alternatively, you "
"can provide the data as gluon.data.DataLoader or "
"mx.io.DataIter")
"can provide the data as gluon.data.DataLoader.")
else:
data, label = batch_fn(batch, self.context)
pred = [self.net(x) for x in data]
Expand All @@ -208,16 +203,17 @@ def evaluate(self,
def fit(self, train_data,
val_data=None,
epochs=1,
batch_size=None,
event_handlers=None,
batch_fn=None):
"""Main training loop
"""Trains the model on a given dataset for a specified
number of epochs. Also, the batch size is inferred from the
DataLoader's batch_size.
Parameters
----------
train_data : DataLoader or DataIter
train_data : DataLoader
training data with data and labels
val_data : DataLoader or DataIter
val_data : DataLoader
validation data with data and labels
epochs : int, default 1
number of epochs to iterate on the training data.
Expand All @@ -232,19 +228,18 @@ def fit(self, train_data,
"""

self.max_epoch = epochs
if not batch_size:
self.batch_size = 32 * len(self.context)
else:
self.batch_size = batch_size
self.stop_training = False
self.samples = None
self.processed_samples = None
self.batch_idx = 0

event_handlers = event_handlers or []
# provide default logging handler
if not event_handlers or \
not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
event_handlers.append(LoggingHandler())
warnings.warn("No Event Handler specified, default `LoggingHandler()` "
"is used with verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH. "
"Please look at gluon.estimator.event_handler for more detail.")

train_begin, epoch_begin, batch_begin, \
batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
Expand All @@ -261,6 +256,8 @@ def fit(self, train_data,
for epoch in range(self.max_epoch):
# epoch begin
self.current_epoch = epoch
# Number of samples trained after every batch
completed_samples = 0

for handler in epoch_begin:
handler.epoch_begin()
Expand All @@ -272,16 +269,15 @@ def fit(self, train_data,
if not batch_fn:
if isinstance(train_data, gluon.data.DataLoader):
data, label = self._batch_fn(batch, self.context)
elif isinstance(train_data, DataIter):
data, label = self._batch_fn(batch, self.context, is_iterator=True)
else:
raise ValueError("You are using a custom iteration, please also provide "
"batch_fn to extract data and label. Alternatively, you "
"can provide the data as gluon.data.DataLoader or "
"mx.io.DataIter")
"can provide the data as gluon.data.DataLoader")
else:
data, label = batch_fn(batch, self.context)

batch_size = batch[0].shape[0]

# batch begin
for handler in batch_begin:
handler.batch_begin()
Expand Down Expand Up @@ -309,12 +305,15 @@ def fit(self, train_data,
name, value = loss_metric.get()
self.train_stats['train_' + name] = value

completed_samples += batch_size

self.batch_idx = i
# record trained samples v.s. total samples if using Gluon DataLoader
if isinstance(train_data, gluon.data.DataLoader):
self.samples = "{}/{}".format(self.batch_size * (i + 1), len(train_data._dataset))
self.processed_samples = "{}/{}".format(completed_samples,
len(train_data._dataset))

self.trainer.step(self.batch_size)
self.trainer.step(batch_size)
# batch end
for handler in batch_end:
handler.batch_end()
Expand Down
61 changes: 39 additions & 22 deletions python/mxnet/gluon/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,27 @@ class LoggingHandler(EventHandler):
file name to save the logs
file_location: str
file location to save the logs
verbose: int, default LOG_VERBOSITY_PER_EPOCH
Limit the granularity of metrics displayed during training process
verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch
verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch
"""

def __init__(self, file_name=None, file_location=None):
LOG_VERBOSITY_PER_EPOCH = 1
LOG_VERBOSITY_PER_BATCH = 2

def __init__(self, file_name=None, file_location=None, verbose=LOG_VERBOSITY_PER_EPOCH):
super(LoggingHandler, self).__init__()
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
stream_handler = logging.StreamHandler()
self.logger.addHandler(stream_handler)
if verbose not in [self.LOG_VERBOSITY_PER_EPOCH, self.LOG_VERBOSITY_PER_BATCH]:
raise ValueError("verbose level must be either LOG_VERBOSITY_PER_EPOCH or "
"LOG_VERBOSITY_PER_BATCH, received %s. "
"E.g: LoggingHandler(verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH)"
% verbose)
self.verbose = verbose
# save logger to file only if file name or location is specified
if file_name or file_location:
file_name = file_name or 'estimator_log'
Expand All @@ -118,33 +131,37 @@ def train_end(self):
self.logger.info(msg)

def batch_begin(self):
self.batch_start = time.time()
if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
self.batch_start = time.time()

def batch_end(self):
batch_time = time.time() - self.batch_start
epoch = self.estimator.current_epoch
batch = self.estimator.batch_idx
msg = '[Epoch %d] [Batch %d] ' % (epoch, batch)
if self.estimator.samples:
msg += '[Samples %s] ' % (self.estimator.samples)
msg += 'time/batch: %.3fs ' % batch_time
for key in self.estimator.train_stats:
# only log current training loss & metric after each batch
if key.startswith('train_'):
msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key]
self.logger.info(msg)
if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
batch_time = time.time() - self.batch_start
epoch = self.estimator.current_epoch
batch = self.estimator.batch_idx
msg = '[Epoch %d] [Batch %d] ' % (epoch, batch)
if self.estimator.processed_samples:
msg += '[Samples %s] ' % (self.estimator.processed_samples)
msg += 'time/batch: %.3fs ' % batch_time
for key in self.estimator.train_stats:
# only log current training loss & metric after each batch
if key.startswith('train_'):
msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key]
self.logger.info(msg)

def epoch_begin(self):
self.epoch_start = time.time()
if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH:
self.epoch_start = time.time()

def epoch_end(self):
epoch_time = time.time() - self.epoch_start
epoch = self.estimator.current_epoch
msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
# log every result in train stats including train/validation loss & metrics
for key in self.estimator.train_stats:
msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
self.logger.info(msg)
if self.verbose >= self.LOG_VERBOSITY_PER_EPOCH:
epoch_time = time.time() - self.epoch_start
epoch = self.estimator.current_epoch
msg = '\n[Epoch %d] finished in %.3fs: ' % (epoch, epoch_time)
# log every result in train stats including train/validation loss & metrics
for key in self.estimator.train_stats:
msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
self.logger.info(msg)


class CheckpointHandler(EventHandler):
Expand Down
12 changes: 5 additions & 7 deletions tests/nightly/estimator/test_estimator_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,12 @@ def test_estimator_cpu():
est = estimator.Estimator(net=net,
loss=loss,
metrics=mx.metric.Accuracy(),
trainers=trainer,
trainer=trainer,
context=context)
# Call fit()
est.fit(train_data=train_data,
val_data=val_data,
epochs=1,
batch_size=1)
epochs=1)

def test_estimator_gpu():
'''
Expand All @@ -131,15 +130,14 @@ def test_estimator_gpu():
est = estimator.Estimator(net=net,
loss=loss,
metrics=acc,
trainers=trainer,
trainer=trainer,
context=context)
# Call fit()
est.fit(train_data=train_data,
val_data=test_data,
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)

assert est.train_stats['train_'+acc.name][num_epochs-1] > 0.80
assert est.train_stats['train_'+acc.name] > 0.80

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='test gluon estimator')
Expand Down
6 changes: 3 additions & 3 deletions tests/nightly/estimator/test_sentiment_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ def run(net, train_dataloader, test_dataloader, **kwargs):

# Define estimator
est = estimator.Estimator(net=net, loss=loss, metrics=acc,
trainers=trainer, context=ctx)
trainer=trainer, context=ctx)
# Begin training
est.fit(train_data=train_dataloader, val_data=test_dataloader,
epochs=num_epochs, batch_size=batch_size)
epochs=num_epochs)
return est


Expand Down Expand Up @@ -252,7 +252,7 @@ def test_estimator_gpu(**kwargs):

est = run(net, train_dataloader, test_dataloader, **kwargs)

assert est.train_stats['train_accuracy'][num_epochs - 1] > 0.70
assert est.train_stats['train_accuracy'] > 0.70


parser = argparse.ArgumentParser(description='test gluon estimator')
Expand Down
45 changes: 19 additions & 26 deletions tests/python/unittest/test_gluon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,18 @@ def test_fit():
dataset = gluon.data.dataset.ArrayDataset(in_data, out_data)
train_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
est.fit(train_data=train_dataloader,
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)

# Input dataiter
train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
est.fit(train_data=train_dataiter,
epochs=num_epochs,
batch_size=batch_size)
with assert_raises(ValueError):
est.fit(train_data=train_dataiter,
epochs=num_epochs)

# Input NDArray
with assert_raises(ValueError):
est.fit(train_data=[in_data, out_data],
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)


def test_validation():
Expand All @@ -94,22 +92,20 @@ def test_validation():
val_dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size)
est.fit(train_data=train_dataloader,
val_data=val_dataloader,
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)

# Input dataiter
train_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
val_dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size)
est.fit(train_data=train_dataiter,
val_data=val_dataiter,
epochs=num_epochs,
batch_size=batch_size)
with assert_raises(ValueError):
est.fit(train_data=train_dataiter,
val_data=val_dataiter,
epochs=num_epochs)
# Input NDArray
with assert_raises(ValueError):
est.fit(train_data=[in_data, out_data],
val_data=[in_data, out_data],
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)


@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
Expand All @@ -131,8 +127,7 @@ def test_initializer():
metrics=acc,
context=ctx)
est.fit(train_data=train_data,
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)

# different initializer for net and estimator
net = get_model()
Expand All @@ -148,8 +143,7 @@ def test_initializer():
context=ctx)
assert 'Network already initialized' in str(w[-1].message)
est.fit(train_data=train_data,
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)


@unittest.skipIf(sys.version_info.major < 3, 'Test on python 3')
Expand All @@ -174,8 +168,7 @@ def test_trainer():
context=ctx)
assert 'No trainer specified' in str(w[-1].message)
est.fit(train_data=train_data,
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)

# input invalid trainer
trainer = 'sgd'
Expand Down Expand Up @@ -206,8 +199,7 @@ def test_metric():
trainer=trainer,
context=ctx)
est.fit(train_data=train_data,
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)
# input list of metrics
metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()]
est = Estimator(net=net,
Expand All @@ -216,8 +208,7 @@ def test_metric():
trainer=trainer,
context=ctx)
est.fit(train_data=train_data,
epochs=num_epochs,
batch_size=batch_size)
epochs=num_epochs)
# input invalid metric
with assert_raises(ValueError):
est = Estimator(net=net,
Expand Down Expand Up @@ -260,7 +251,9 @@ def test_context():
loss=loss,
metrics=metrics)
# input list of context
ctx = [mx.gpu(0), mx.gpu(1)]
# num_gpus() yields a device count (an int), so build one GPU context per
# index via range(); fall back to a single CPU context when no GPU is present.
gpus = mx.context.num_gpus()
ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()]
net = get_model()
est = Estimator(net=net,
loss=loss,
metrics=metrics,
Expand Down
Loading

0 comments on commit d4f7744

Please sign in to comment.