-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNet-1340][Fit API]Update train stats #14494
Conversation
if isinstance(self._optimizer, opt.Optimizer): | ||
return self._optimizer | ||
else: | ||
raise UserWarning("Optimizer has not been initialized yet") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if the user sets a custom optimizer here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a custom optimizer should still inherit the base Optimizer class. Gluon trainer does the check: /~https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/trainer.py#L123
self.trainer = gluon.Trainer(self.net.collect_params(), | ||
'sgd', {'learning_rate': 0.001}) | ||
elif not isinstance(trainer, gluon.Trainer): | ||
raise ValueError("Trainer must be a Gluon Trainer instance, refer to gluon.trainer") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since you say refer to gluon.trainer, you could probably also add a url to gluon.trainer docs : http://mxnet.incubator.apache.org/versions/master/api/python/gluon/gluon.html#trainer
self.train_stats['step'] = i | ||
self.train_history.batch_idx = i | ||
# record trained samples v.s. total samples if using Gluon DataLoader | ||
if isinstance(train_data, gluon.data.DataLoader): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might need to rebase this line with the fit-api branch again.
/~https://github.com/apache/incubator-mxnet/blob/fit-api/python/mxnet/gluon/estimator/estimator.py#L245
@mxnet-label-bot add[pr-awaiting-review, Gluon] |
for handler in event_handlers: | ||
handler.estimator = self |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nswamy This will avoid to ask user passing estimator during event handler construction, reference: #14462 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am wondering how the user of handler will know that an estimator will be initialized here? Also can you have a setter and getter for the estimator in Handler and not call handler.setEstimator(e) if handler.getEstimator() is not None.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nswamy when user call est.fit(xxx, event_handlers=XXX), this will already associate the event handlers with an estimator instance. I m just helping the user to pass this estimator so they don't need to do so during event handler construction.
The getter and setter are already implemented through the property interface. handler.estimator=self
is actually the setter method of property estimator.
handler.train_end() | ||
|
||
def categorize_handlers(self, event_handlers): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit-> categorize_handlers
-> _ categorize_handlers
. Don't think we need this exposed to users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
batch_end = [] | ||
epoch_end = [] | ||
train_end = [] | ||
base_handler = EventHandler() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you are checking class methods, do you need to create an instance here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
metrics=None, | ||
initializer=None, | ||
trainers=None, | ||
trainer=None, | ||
context=None): | ||
|
||
self.net = net |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you want to set self._estimator = None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the estimator class, only event handlers should have self._estimator?
for handler in event_handlers: | ||
handler.estimator = self |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am wondering how the user of handler will know that an estimator will be initialized here? Also can you have a setter and getter for the estimator in Handler and not call handler.setEstimator(e) if handler.getEstimator() is not None.
@nswamy I have addressed the comments, could you take another look? thanks! |
@roywei Can you look at the CI failures ? |
@piyushghai it's all due to R package failure, i think it's ok. We will rebase before merging to master, and hopefully the R test will pass. |
* add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers
* add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers
* add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers
* add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers
* [MXNet-1334][Fit API]base class for estimator and eventhandler (#14346) * base class for estimator and eventhandler * add license * add event handlers * fix pylint * improve arg check * fix pylint * add unit tests * Fixed issue where the estimator was printing beyond the dataset size … (#14464) * Fixed issue where the estimator was printing beyond the dataset size for the last batch * Added comments * Nudge to CI * [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (#14442) * added estimator unittests * add more tests for estimator * added validation logic * added error handlers, unittests * improve val stats * fix pylint * fix pylint * update unit test * fix tests * fix tests * updated metrics, val logic * trigger ci * trigger ci * update metric, batch_fn error handler * update context logic, add default metric * [MXNet-1340][Fit API]Update train stats (#14494) * add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers * [MXNet-1375][Fit API]Added RNN integration test for fit() API (#14547) * Added RNN integration test for fit() API * Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports * CPU test doesn't require nvidiadocker container * Modified the structure by removing the redundant code * [MXNet-1343][Fit API]Add CNN integration test for fit() API (#14405) * added cnn intg tests for fit api * updated cnn intg tests * added functions for nightly test * updated runtime_function * updated intg tests * updated init, datapath, refs * added validation data * update cpu test * refactor code * updated context * [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (#14587) * Retrieve Batch size and Logging verbose support for Gluon fit() API * NIT changes * Addressed review comments: shifted the batch size code to a separate method, sentence correction * Modified unittest * removed redundant parameter * Resolve CI test failure * only support DataLoader for now, future PRs will include DataIter to DataLoader converter * Get the number of samples from shape attribute instead of length due to low space complexity * Simplified batch size retrieval code * removed batch_size parameter from fit() method and fixed the tests * Verbose exception handling * Assigning constant to a verbose * Modified exception message * Resolved undefined class reference * Addressed review comments: Modified verbose level names, docs, variable names * Update estimator.py * move estimator to contrib (#14633) * move to gluon contrib (#14635) * [Fit API] improve event handlers (#14685) * improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests * [MXNET-1396][Fit-API] Update default handler logic (#14765) * move to nightly for binaries * update default handler * fix pylint * trigger ci * trigger ci * [Fit API] update estimator (#14849) * address comments * add comment * check available context * fix bug * change cpu check * [Fit-API] Adress PR comments (#14885) * address comments * update checkpoint * test symbol save * address comments * add resume * update doc and resume checkpoint * update docs * trigger ci * trigger ci
* add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers
* [MXNet-1334][Fit API]base class for estimator and eventhandler (apache#14346) * base class for estimator and eventhandler * add license * add event handlers * fix pylint * improve arg check * fix pylint * add unit tests * Fixed issue where the estimator was printing beyond the dataset size … (apache#14464) * Fixed issue where the estimator was printing beyond the dataset size for the last batch * Added comments * Nudge to CI * [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (apache#14442) * added estimator unittests * add more tests for estimator * added validation logic * added error handlers, unittests * improve val stats * fix pylint * fix pylint * update unit test * fix tests * fix tests * updated metrics, val logic * trigger ci * trigger ci * update metric, batch_fn error handler * update context logic, add default metric * [MXNet-1340][Fit API]Update train stats (apache#14494) * add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers * [MXNet-1375][Fit API]Added RNN integration test for fit() API (apache#14547) * Added RNN integration test for fit() API * Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports * CPU test doesn't require nvidiadocker container * Modified the structure by removing the redundant code * [MXNet-1343][Fit API]Add CNN integration test for fit() API (apache#14405) * added cnn intg tests for fit api * updated cnn intg tests * added functions for nightly test * updated runtime_function * updated intg tests * updated init, datapath, refs * added validation data * update cpu test * refactor code * updated context * [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (apache#14587) * Retrieve Batch size and Logging verbose support for Gluon fit() API * NIT changes * Addressed review comments: shifted the batch size code to a separate method, sentence correction * Modified unittest * removed redundant parameter * Resolve CI test failure * only support DataLoader for now, future PRs will include DataIter to DataLoader converter * Get the number of samples from shape attribute instead of length due to low space complexity * Simplified batch size retrieval code * removed batch_size parameter from fit() method and fixed the tests * Verbose exception handling * Assigning constant to a verbose * Modified exception message * Resolved undefined class reference * Addressed review comments: Modified verbose level names, docs, variable names * Update estimator.py * move estimator to contrib (apache#14633) * move to gluon contrib (apache#14635) * [Fit API] improve event handlers (apache#14685) * improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests * [MXNET-1396][Fit-API] Update default handler logic (apache#14765) * move to nightly for binaries * update default handler * fix pylint * trigger ci * trigger ci * [Fit API] update estimator (apache#14849) * address comments * add comment * check available context * fix bug * change cpu check * [Fit-API] Adress PR comments (apache#14885) * address comments * update checkpoint * test symbol save * address comments * add resume * update doc and resume checkpoint * update docs * trigger ci * trigger ci
* add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers
Description
In the previous Fit-API design doc , training statistics was stored in a dictionary, some values are stored as a list such learning rates and training accuraccy over epochs. users has to understand the underlying data structure inorder to access the statistics.
This PR improves how event handlers access train stats, and reduced empty method call in event handlers to improve efficiency.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments