Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix doc gen #12

Merged
merged 5 commits into from
Jul 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions docs/tutorials/gluon/fit_api_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,12 @@ est = estimator.Estimator(net=resnet_18_v1,
trainer=trainer,
context=ctx)

# Magic line
est.fit(train_data=train_data_loader,
# ignore warnings for nightly test on CI only
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Magic line
est.fit(train_data=train_data_loader,
epochs=num_epochs)
```

Expand Down Expand Up @@ -224,11 +228,15 @@ checkpoint_handler = CheckpointHandler(model_dir='./',
save_best=True) # Save the best model in terms of
# Let's instantiate another handler which we defined above
loss_record_handler = LossRecordHandler()
# Magic line
est.fit(train_data=train_data_loader,
val_data=val_data_loader,
epochs=num_epochs,
event_handlers=[checkpoint_handler, loss_record_handler]) # Add the event handlers
# ignore warnings for nightly test on CI only
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Magic line
est.fit(train_data=train_data_loader,
val_data=val_data_loader,
epochs=num_epochs,
event_handlers=[checkpoint_handler, loss_record_handler]) # Add the event handlers
```

Training begin: using optimizer SGD with current learning rate 0.0400 <!--notebook-skip-line-->
Expand Down
26 changes: 17 additions & 9 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,28 +334,36 @@ def fit(self, train_data,
def _prepare_default_handlers(self, val_data, event_handlers):
event_handlers = event_handlers or []
default_handlers = []
train_metrics, val_metrics = self.prepare_loss_and_metrics()
self.prepare_loss_and_metrics()

# no need to add to default handler check as StoppingHandler does not use metrics
event_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))
default_handlers.append("StoppingHandler")

if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
event_handlers.append(MetricHandler(train_metrics=train_metrics))
event_handlers.append(MetricHandler(train_metrics=self.train_metrics))
default_handlers.append("MetricHandler")

if val_data and not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
val_metrics=val_metrics))
default_handlers.append("ValidationHandler")
if not any(isinstance(handler, ValidationHandler) for handler in event_handlers):
# no validation handler
if val_data:
# add default validation handler if validation data found
event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate,
val_metrics=self.val_metrics))
default_handlers.append("ValidationHandler")
val_metrics = self.val_metrics
else:
# set validation metrics to None if no validation data and no validation handler
val_metrics = []

if not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
event_handlers.append(LoggingHandler(train_metrics=train_metrics,
event_handlers.append(LoggingHandler(train_metrics=self.train_metrics,
val_metrics=val_metrics))
default_handlers.append("LoggingHandler")

# if there is a mix of user defined event handlers and default event handlers
# they should have the same set of loss and metrics
if default_handlers:
if default_handlers and len(event_handlers) > len(default_handlers):
msg = "You are training with the following default event handlers: %s. " \
"They use loss and metrics from estimator.prepare_loss_and_metrics(). " \
"Please use the same set of metrics for all your other handlers." % \
Expand All @@ -374,7 +382,7 @@ def _prepare_default_handlers(self, val_data, event_handlers):
# remove None metric references
references = set([ref for ref in references if ref])
for metric in references:
if metric not in train_metrics + val_metrics:
if metric not in self.train_metrics + self.val_metrics:
msg = "We have added following default handlers for you: %s and used " \
"estimator.prepare_loss_and_metrics() to pass metrics to " \
"those handlers. Please use the same set of metrics " \
Expand Down
7 changes: 4 additions & 3 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from ....metric import EvalMetric
from ....metric import Loss as metric_loss

__all__ = ['StoppingHandler', 'MetricHandler', 'ValidationHandler',
__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd','BatchBegin', 'BatchEnd',
'StoppingHandler', 'MetricHandler', 'ValidationHandler',
'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']

class TrainBegin(object):
Expand Down Expand Up @@ -513,8 +514,8 @@ def _save_symbol(self, estimator):
sym = estimator.net._cached_graph[1]
sym.save(symbol_file)
else:
self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock"
"to construct your model, can call net.hybridize() before passing to"
self.logger.info("Model architecture(symbol file) is not saved, please use HybridBlock "
"to construct your model, can call net.hybridize() before passing to "
"Estimator in order to save model architecture as %s.", symbol_file)

def _save_params_and_trainer(self, estimator, file_prefix):
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_gluon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

import sys
import unittest
import warnings

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.contrib.estimator import *
from mxnet.gluon.contrib.estimator.event_handler import *
from nose.tools import assert_raises


Expand Down