From fd201b3e959139b39cd504133c84a77eaeda2d54 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sun, 20 Sep 2015 11:52:16 -0700
Subject: [PATCH] Update model to add save/load and periodic checkpointing.
 Removed WaitForAll in the threaded engine. Still do not know the cause of
 WaitForVar stalling.

---
 python/mxnet/__init__.py             |   3 +-
 python/mxnet/base.py                 |  41 +++
 python/mxnet/io.py                   |  15 +-
 python/mxnet/metric.py               |   4 +-
 python/mxnet/model.py                | 371 +++++++++++++++++++++------
 python/mxnet/ndarray.py              |  16 +-
 python/mxnet/optimizer.py            |  19 +-
 python/mxnet/symbol.py               |  20 +-
 src/engine/threaded_engine_pooled.cc |   3 -
 tests/python/train/test_mlp.py       |  54 +++-
 tests/python/train/test_mlp_old.py   | 107 --------
 11 files changed, 411 insertions(+), 242 deletions(-)
 delete mode 100644 tests/python/train/test_mlp_old.py

diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index b5429a7bd816..89dd2c09da79 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -22,6 +22,7 @@
 from . import model
 from . import initializer
 from . import visualization
-import atexit
+# use viz as short for mx.visualization
+from . import visualization as viz
 
 __version__ = "0.1.0"
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 2f7d15919681..54d128d83e78 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -179,3 +179,44 @@ def ctypes2numpy_shared(cptr, shape):
     dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
     return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)
+
+
+def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
+    """Convert ctypes-returned docstring information into a parameters docstring.
+
+    Parameters
+    ----------
+    num_args : mx_uint
+        Number of arguments.
+
+    arg_names : ctypes.POINTER(ctypes.c_char_p)
+        Argument names.
+
+    arg_types : ctypes.POINTER(ctypes.c_char_p)
+        Argument type information.
+
+    arg_descs : ctypes.POINTER(ctypes.c_char_p)
+        Argument description information.
+
+    remove_dup : boolean, optional
+        Whether to remove duplicated arguments.
+
+    Returns
+    -------
+    docstr : str
+        Python docstring of the parameter section.
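+
+    Examples
+    --------
+    The returned string follows the layout below (the entry shown is
+    illustrative)::
+
+        Parameters
+        ----------
+        num_hidden : int
+            Number of hidden units.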
+ """ + param_keys = set() + param_str = [] + for i in range(num_args.value): + key = py_str(arg_names[i]) + if key in param_keys and remove_dup: + continue + param_keys.add(key) + type_info = py_str(arg_types[i]) + ret = '%s : %s' % (key, type_info) + if len(arg_descs[i]) != 0: + ret += '\n ' + py_str(arg_descs[i]) + param_str.append(ret) + doc_str = ('Parameters\n' + + '----------\n' + + '%s\n') + doc_str = doc_str % ('\n'.join(param_str)) + return doc_str diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 62e92bd020d5..cb55df71aa3f 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -8,7 +8,7 @@ from .base import _LIB from .base import c_array, c_str, mx_uint, py_str from .base import DataIterHandle, NDArrayHandle -from .base import check_call +from .base import check_call, ctypes2docstring from .ndarray import NDArray class DataIter(object): @@ -99,24 +99,17 @@ def _make_io_iterator(handle): ctypes.byref(arg_types), \ ctypes.byref(arg_descs))) iter_name = py_str(name.value) - param_str = [] - for i in range(num_args.value): - ret = '%s : %s' % (arg_names[i], arg_types[i]) - if len(arg_descs[i]) != 0: - ret += '\n ' + py_str(arg_descs[i]) - param_str.append(ret) + param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs) doc_str = ('%s\n\n' + - 'Parameters\n' + - '----------\n' + '%s\n' + 'name : string, required.\n' + ' Name of the resulting data iterator.\n\n' + 'Returns\n' + '-------\n' + - 'iterator: Iterator\n'+ + 'iterator: DataIter\n'+ ' The result iterator.') - doc_str = doc_str % (desc.value, '\n'.join(param_str)) + doc_str = doc_str % (desc.value, param_str) def creator(*args, **kwargs): """Create an iterator. diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index ff100c12c191..ad0aa55d332a 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -47,8 +47,8 @@ def __init__(self): def update(self, pred, label): pred = pred.asnumpy() label = label.asnumpy().astype('int32') - y = np.argmax(pred, axis=1) - self.sum_metric += np.sum(y == label) + py = np.argmax(pred, axis=1) + self.sum_metric += np.sum(py == label) self.num_inst += label.size diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 726be0d7eb45..d757ce5ded27 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -1,13 +1,15 @@ -# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals, no-member -# pylint: disable=too-many-branches, too-many-statements, unused-argument, unused-variable +# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals +# pylint: disable=too-many-branches, too-many-statements, unused-argument """MXNet model module""" import numpy as np import time +import logging from . import io from . import nd +from . import symbol as sym from . import optimizer as opt from . import metric -from .context import Context +from .context import Context, cpu from .initializer import Xavier @@ -20,11 +22,63 @@ SKLEARN_INSTALLED = False +def _check_arguments(symbol): + """Check the argument names of symbol. + + This function checks the duplication of arguments in Symbol. + The check is done for feedforward net for now. + + Parameters + ---------- + symbol : Symbol + The network configuration + + Returns + ------- + data_index : int + Index position of data. 
+    label_index : int
+        Index position of label.
+    """
+    arg_names = symbol.list_arguments()
+    data_index, label_index = None, None
+    arg_set = set()
+    for index, name in enumerate(arg_names):
+        if name.endswith('label'):
+            if label_index is not None:
+                raise ValueError('Two arguments with suffix \"label\", ' +
+                                 'only accept one label in config for now, ' +
+                                 'arguments are %s' % str(arg_names))
+            label_index = index
+        if name.endswith('data'):
+            if data_index is not None:
+                raise ValueError('Two arguments with suffix \"data\", ' +
+                                 'only accept one input data in config for now, ' +
+                                 'arguments are %s' % str(arg_names))
+            data_index = index
+        if name in arg_set:
+            raise ValueError(('Found duplicated argument name \"%s\", ' +
+                              'please make the weight name non-duplicated (using the name argument), ' +
+                              'arguments are %s') % (name, str(arg_names)))
+        arg_set.add(name)
+
+    aux_names = symbol.list_auxiliary_states()
+    for name in aux_names:
+        if name in arg_set:
+            raise ValueError(
+                ('Found duplicated auxiliary param name \"%s\", ' +
+                 'please make the weight name non-duplicated (using the name argument), ' +
+                 'arguments are %s, auxiliary params are %s'
+                ) % (name, str(arg_names), str(aux_names)))
+
+    return (data_index, label_index)
+
+
 def _train(symbol, ctx, input_shape,
            arg_params, aux_params,
            begin_round, end_round, optimizer,
            train_data, eval_data=None, eval_metric=None,
-           iter_end_callback=None, verbose=True):
+           iter_end_callback=None, logger=None):
     """Internal training function.
 
     Parameters
     ----------
@@ -62,18 +116,20 @@ def _train(symbol, ctx, input_shape,
     eval_metric : EvalMetric
         An evaluation function.
 
-    iter_end_callback : callable(iteration, arg_params, aux_states)
+    iter_end_callback : callable(iteration, symbol, arg_params, aux_states)
         A callback that is invoked at end of each iteration.
         This can be used to checkpoint the model each iteration.
 
-    verbose : boolean
-        Whether print message during training.
+    logger : logging logger, optional
+        When not specified, the default logger will be used.
 
     Notes
     -----
     This function will update the NDArrays in arg_params and aux_states in place.
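+
+    Examples
+    --------
+    A sketch of a compatible ``iter_end_callback`` (the body is illustrative)::
+
+        def my_checkpoint(iteration, symbol, arg_params, aux_params):
+            save_checkpoint('mymodel', iteration + 1, symbol, arg_params, aux_params)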
""" assert(len(ctx) == 1) + if logger is None: + logger = logging # bind the symbol train_exec = symbol.simple_bind(ctx[0], data=input_shape, grad_req='write') arg_names = symbol.list_arguments() @@ -82,32 +138,20 @@ def _train(symbol, ctx, input_shape, grad_arrays = train_exec.grad_arrays aux_arrays = train_exec.aux_arrays # copy initialized parameters to executor parameters - for key, weight in list(zip(arg_names, arg_arrays)): + for key, weight in zip(arg_names, arg_arrays): if key in arg_params: arg_params[key].copyto(weight) - for key, weight in list(zip(aux_names, aux_arrays)): + for key, weight in zip(aux_names, aux_arrays): if key in aux_params: aux_params[key].copyto(weight) # setup helper data structures - label_array = None - data_array = None - for name, arr in list(zip(symbol.list_arguments(), arg_arrays)): - if name.endswith('label'): - assert label_array is None - label_array = arr - if name.endswith('data'): - assert data_array is None - data_array = arr - assert data_array is not None - assert label_array is not None - + data_index, label_index = _check_arguments(symbol) + data_array, label_array = arg_arrays[data_index], arg_arrays[label_index] out_array = train_exec.outputs[0] out_cpu_array = nd.zeros(out_array.shape) - arg_blocks = list(zip(arg_names, arg_arrays, grad_arrays)) + arg_blocks = zip(arg_arrays, grad_arrays) for i in range(begin_round, end_round): - if verbose: - print("Epoch %d:" % i) # training phase tic = time.time() train_data.reset() @@ -121,18 +165,18 @@ def _train(symbol, ctx, input_shape, out_array.copyto(out_cpu_array) train_exec.backward() # update the parameters - for key, weight, grad in arg_blocks: + for index, block in enumerate(arg_blocks): + weight, grad = block if grad is not None: - optimizer.update(key, weight, grad) + optimizer.update(index, weight, grad) # evaluate at end, so out_cpu_array can lazy copy eval_metric.update(out_cpu_array, label) name, value = eval_metric.get() - print ('Train %s:\t%f' % (name, value)) - + logger.info('Iteration[%d] Train-%s=%f', i, name, value) toc = time.time() - if verbose: - print("Time: %.3f" % (toc - tic)) + logger.info('Iteration[%d] Time cost=%.3f', i, (toc - tic)) + # evaluation phase if eval_data is not None: eval_metric.reset() @@ -145,21 +189,113 @@ def _train(symbol, ctx, input_shape, eval_metric.update(out_array, label) name, value = eval_metric.get() - print ('Validation %s:\t%f' % (name, value)) + logger.info('Iteration[%d] Validation-%s=%f', i, name, value) if iter_end_callback or i + 1 == end_round: # copy data back to cpu - for key, weight, gard in arg_blocks: + for key, weight in zip(arg_names, arg_arrays): if key in arg_params: weight.copyto(arg_params[key]) - for key, arr in list(zip(aux_names, aux_arrays)): + for key, arr in zip(aux_names, aux_arrays): arr.copyto(aux_params[key]) if iter_end_callback: - iter_end_callback(i, arg_params, aux_arrays) + iter_end_callback(i, symbol, arg_params, aux_params) # end of the function return +def save_checkpoint(prefix, iteration, symbol, arg_params, aux_params): + """Checkpoint the model data into file. + + Parameters + ---------- + prefix : str + Prefix of model name. + + iteration : int + The iteration number of the model. + + symbol : Symbol + The input symbol + + arg_params : dict of str to NDArray + Model parameter, dict of name to NDArray of net's weights. + + aux_params : dict of str to NDArray + Model parameter, dict of name to NDArray of net's auxiliary states. + + Notes + ----- + - ``prefix-symbol.json`` will be saved for symbol. 
+    - ``prefix-iteration.params`` will be saved for parameters.
+    """
+    symbol.save('%s-symbol.json' % prefix)
+    save_dict = {('arg:%s' % k) : v for k, v in arg_params.items()}
+    save_dict.update({('aux:%s' % k) : v for k, v in aux_params.items()})
+    param_name = '%s-%04d.params' % (prefix, iteration)
+    nd.save(param_name, save_dict)
+    logging.info('Saved checkpoint to \"%s\"', param_name)
+
+
+def load_checkpoint(prefix, iteration):
+    """Load model checkpoint from file.
+
+    Parameters
+    ----------
+    prefix : str
+        Prefix of model name.
+
+    iteration : int
+        Iteration number of model we would like to load.
+
+    Returns
+    -------
+    symbol : Symbol
+        The symbol configuration of computation network.
+
+    arg_params : dict of str to NDArray
+        Model parameter, dict of name to NDArray of net's weights.
+
+    aux_params : dict of str to NDArray
+        Model parameter, dict of name to NDArray of net's auxiliary states.
+
+    Notes
+    -----
+    - Symbol will be loaded from ``prefix-symbol.json``.
+    - Parameters will be loaded from ``prefix-iteration.params``.
+    """
+    symbol = sym.load('%s-symbol.json' % prefix)
+    save_dict = nd.load('%s-%04d.params' % (prefix, iteration))
+    arg_params = {}
+    aux_params = {}
+    for k, v in save_dict.items():
+        tp, name = k.split(':', 1)
+        if tp == 'arg':
+            arg_params[name] = v
+        if tp == 'aux':
+            aux_params[name] = v
+    return (symbol, arg_params, aux_params)
+
+
+def do_checkpoint(prefix):
+    """Callback to checkpoint the model to prefix every iteration.
+
+    Parameters
+    ----------
+    prefix : str
+        The file prefix to checkpoint to.
+
+    Returns
+    -------
+    callback : function
+        The callback function that can be passed as iter_end_callback to fit.
+    """
+    def _callback(iter_no, s, arg, aux):
+        """The checkpoint function."""
+        save_checkpoint(prefix, iter_no + 1, s, arg, aux)
+    return _callback
+
+
 class FeedForward(BASE_ESTIMATOR):
     """Model class of MXNet for training and predicting feedforward nets.
 
     Parameters
     ----------
@@ -170,13 +306,10 @@ class FeedForward(BASE_ESTIMATOR):
     symbol : Symbol
         The symbol configuration of computation network.
 
-    ctx : Context or list of Context
+    ctx : Context or list of Context, optional
         The device context of training and prediction.
         To use multi GPU training, pass in a list of gpu contexts.
 
-    input_shape : tuple
-        Shape of input data batch.
-
     num_round : int, optional
         Training parameter, number of training rounds (iterations).
 
@@ -189,66 +322,86 @@ class FeedForward(BASE_ESTIMATOR):
     arg_params : dict of str to NDArray, optional
         Model parameter, dict of name to NDArray of net's weights.
 
-    aux_states : dict of str to NDArray, optional
+    aux_params : dict of str to NDArray, optional
         Model parameter, dict of name to NDArray of net's auxiliary states.
 
     **kwargs : dict
         The additional keyword arguments passed to optimizer.
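+
+    Examples
+    --------
+    A minimal sketch of the intended workflow (assumes ``import mxnet as mx``,
+    a configured Symbol ``softmax``, and an existing DataIter
+    ``train_dataiter``; hyperparameter values are illustrative)::
+
+        model = FeedForward(softmax, mx.cpu(), num_round=4,
+                            learning_rate=0.01, momentum=0.9)
+        model.fit(X=train_dataiter)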
""" - def __init__(self, symbol, ctx, input_shape, + def __init__(self, symbol, ctx=None, num_round=None, optimizer='sgd', initializer=Xavier(), - arg_params=None, aux_states=None, + arg_params=None, aux_params=None, **kwargs): # basic configuration self.symbol = symbol - if isinstance(ctx, Context): + if ctx is None: + ctx = [cpu()] + elif isinstance(ctx, Context): ctx = [ctx] self.ctx = ctx - self.input_shape = input_shape # training parameters self.num_round = num_round - if isinstance(optimizer, str): - batch_size = input_shape[0] - optimizer = opt.create(optimizer, rescale_grad=(1.0/batch_size), **kwargs) + self.kwargs = kwargs.copy() self.optimizer = optimizer self.initializer = initializer # model parameters self.arg_params = arg_params - self.aux_states = aux_states + self.aux_params = aux_params # internal helper state self._pred_exec = None self._pred_exec_input = None - def _init_params(self): + @staticmethod + def _is_data_arg(name): + """Check if name is a data argument.""" + return name.endswith('data') or name.endswith('label') + + @staticmethod + def _get_input_shape(data): + """Get input shape from data iterator.""" + data.reset() + data.next() + input_shape = data.getdata().shape + data.reset() + return input_shape + + def _init_params(self, input_shape): """Use initializer to initialize the parameters.""" - is_data_arg = lambda x: x.endswith('data') or x.endswith('label') - arg_shapes, _, aux_shapes = self.symbol.infer_shape(data=self.input_shape) + arg_shapes, _, aux_shapes = self.symbol.infer_shape(data=input_shape) if self.arg_params is None: arg_names = self.symbol.list_arguments() self.arg_params = {k : nd.zeros(s) for k, s in list(zip(arg_names, arg_shapes)) - if not is_data_arg(k)} - if self.aux_states is None: + if not self._is_data_arg(k)} + if self.aux_params is None: aux_names = self.symbol.list_auxiliary_states() - self.aux_states = {k : nd.zeros(s) for k, s in list(zip(aux_names, aux_shapes))} + self.aux_params = {k : nd.zeros(s) for k, s in list(zip(aux_names, aux_shapes))} for k, v in self.arg_params.items(): self.initializer(k, v) - for k, v in self.aux_states.items(): + for k, v in self.aux_params.items(): self.initializer(k, v) - def _init_predictor(self): + def __getstate__(self): + this = self.__dict__.copy() + this['_pred_exec'] = None + return this + + def __setstate__(self, state): + self.__dict__.update(state) + + def _init_predictor(self, input_shape): """Initialize the predictor module for running prediction.""" if self._pred_exec is not None: return # for now only use the first device pred_exec = self.symbol.simple_bind( - self.ctx[0], grad_req='null', data=self.input_shape) + self.ctx[0], grad_req='null', data=input_shape) + for name, value in list(zip(self.symbol.list_arguments(), pred_exec.arg_arrays)): - if name not in self.arg_datas: + if not self._is_data_arg(name): assert name in self.arg_params self.arg_params[name].copyto(value) - else: - assert self._pred_exec_input is None - self._pred_exec_input = value + data_index, _ = _check_arguments(self.symbol) + self._pred_exec_input = pred_exec.arg_arrays[data_index] self._pred_exec = pred_exec def predict(self, X): @@ -264,15 +417,18 @@ def predict(self, X): The predicted value of the output. 
""" assert isinstance(X, io.DataIter) - self._init_predictor() + self._init_predictor(self._get_input_shape(X)) outputs = [] - for data, label in X: - data.copyto(self.pred_exec_input) + + X.reset() + for data, _ in X: + data.copyto(self._pred_exec_input) self._pred_exec.forward() - outputs.extend(self._pred_exec.outputs[0].asnumpy()) + outputs.append(self._pred_exec.outputs[0].asnumpy()) return np.concatenate(outputs) - def fit(self, X, y=None, eval_data=None, eval_metric='acc', verbose=True): + def fit(self, X, y=None, eval_data=None, eval_metric='acc', + iter_end_callback=None, logger=None): """fit the model Parameters @@ -289,18 +445,87 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', verbose=True): eval_metric : function Evaluation metric function. - verbose : boolean - Whether print information during training. + iter_end_callback : callable(iteration, symbol, arg_params, aux_states) + A callback that is invoked at end of each iteration. + This can be used to checkpoint model each iteration. + + logger : logging logger, optional + When not specified, default logger will be used. """ + input_shape = self._get_input_shape(X) if self.arg_params is None: - self._init_params() + self._init_params(input_shape) + # setup metric if isinstance(eval_metric, str): eval_metric = metric.create(eval_metric) - - _train(self.symbol, self.ctx, self.input_shape, - self.arg_params, self.aux_states, + # setup optimizer + optimizer = self.optimizer + if isinstance(optimizer, str): + batch_size = input_shape[0] + optimizer = opt.create(optimizer, rescale_grad=(1.0/batch_size), **(self.kwargs)) + # do training + _train(self.symbol, self.ctx, input_shape, + self.arg_params, self.aux_params, begin_round=0, end_round=self.num_round, - optimizer=self.optimizer, + optimizer=optimizer, train_data=X, eval_data=eval_data, eval_metric=eval_metric, - verbose=verbose) + iter_end_callback=iter_end_callback, + logger=logger) + + def save(self, prefix, iteration=None): + """Checkpoint the model checkpoint into file. + + You can also use pickle to do the job if you only work on python. + The advantage of load/save is the file is language agnostic. + This means the file saved using save can be loaded by other language binding of mxnet. + You also get the benefit being able to directly load/save from cloud storage(S3, HDFS) + + Parameters + ---------- + prefix : str + Prefix of model name. + + See Also + -------- + Symbol.load : the method to load the model back. + + Notes + ----- + - ``prefix-symbol.json`` will be saved for symbol. + - ``prefix-iteration.params`` will be saved for parameters. + """ + if iteration is None: + iteration = self.num_round + assert iteration is not None + save_checkpoint(prefix, iteration, self.symbol, self.arg_params, self.aux_params) + + @staticmethod + def load(prefix, iteration, ctx=None): + """Load model checkpoint from file. + + Parameters + ---------- + prefix : str + Prefix of model name. + + iteration : int + Iteration number of model we would like to load. + + ctx : Context or list of Context, optional + The device context of training and prediction. + + Returns + ------- + model : FeedForward + The loaded model that can be used for prediction. + + Notes + ----- + - ``prefix-symbol.json`` will be saved for symbol. + - ``prefix-iteration.params`` will be saved for parameters. 
+ """ + symbol, arg_params, aux_params = load_checkpoint(prefix, iteration) + return FeedForward(symbol, ctx=ctx, + arg_params=arg_params, aux_params=aux_params) + diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 921c00fccb35..9af18b937b4c 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -10,7 +10,7 @@ from .base import c_array, py_str, c_str from .base import mx_uint, mx_float, NDArrayHandle, FunctionHandle from .base import ctypes2buffer -from .base import check_call +from .base import check_call, ctypes2docstring from .context import Context def _new_empty_handle(): @@ -36,7 +36,6 @@ def _new_alloc_handle(shape, ctx, delay_alloc): a new empty ndarray handle """ hdl = NDArrayHandle() - print ctx.device_typeid check_call(_LIB.MXNDArrayCreate( c_array(mx_uint, shape), len(shape), @@ -530,17 +529,8 @@ def _make_ndarray_function(handle): ctypes.byref(arg_types), ctypes.byref(arg_descs))) func_name = py_str(name.value) - - param_str = [] - for i in range(num_args.value): - ret = '%s : %s' % (py_str(arg_names[i]), py_str(arg_types[i])) - if len(arg_descs[i]) != 0: - ret += '\n ' + py_str(arg_descs[i]) - param_str.append(ret) - + param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs) doc_str = ('%s\n\n' + - 'Parameters\n' + - '----------\n' + '%s\n' + 'out : NDArray, optional\n' + ' The output NDArray to hold the result.\n\n'+ @@ -548,7 +538,7 @@ def _make_ndarray_function(handle): '-------\n' + 'out : NDArray\n'+ ' The output of binary function.') - doc_str = doc_str % (py_str(desc.value), '\n'.join(param_str)) + doc_str = doc_str % (py_str(desc.value), param_str) # Definition of internal functions. def binary_ndarray_function(lhs, rhs, out=None): diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 8118e23f2bf6..8cc3d1b4f241 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -41,25 +41,28 @@ def __init__(self, learning_rate=0.01, momentum=0.0, self.rescale_grad = rescale_grad self.momentums = {} - def update(self, key, weight, grad): + def update(self, index, weight, grad): """Update the parameters. Parameters ---------- - key : str - The name of the parameter. 
-        weight: NDArray
+        index : int
+            A unique integer key used to index the parameters.
+
+        weight : NDArray
             weight ndarray
-        grad: NDArray
+
+        grad : NDArray
             grad ndarray
+
         """
         # TODO(bing) implement wd_bias, wd_gamma, wd_beta
         assert(isinstance(weight, NDArray))
         assert(isinstance(grad, NDArray))
-        if key not in self.momentums:
-            self.momentums[key] = zeros(grad.shape, grad.context)
-        mom = self.momentums[key]
+        if index not in self.momentums:
+            self.momentums[index] = zeros(grad.shape, grad.context)
+        mom = self.momentums[index]
         mom[:] *= self.momentum
         mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)
         weight[:] += mom
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index 006bc66a5223..a7d53f28fc1a 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -12,7 +12,7 @@
 from .base import _LIB
 from .base import c_array, c_str, mx_uint, py_str, string_types
 from .base import NDArrayHandle, ExecutorHandle, SymbolHandle
-from .base import check_call
+from .base import check_call, ctypes2docstring
 from .context import Context
 from .ndarray import NDArray, zeros
 from .executor import Executor
@@ -680,7 +680,6 @@ def _make_atomic_symbol_function(handle):
     arg_types = ctypes.POINTER(ctypes.c_char_p)()
     arg_descs = ctypes.POINTER(ctypes.c_char_p)()
-
     check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
         handle, ctypes.byref(name), ctypes.byref(desc),
         ctypes.byref(num_args),
@@ -688,25 +687,14 @@ def _make_atomic_symbol_function(handle):
         ctypes.byref(arg_types),
         ctypes.byref(arg_descs),
         ctypes.byref(key_var_num_args)))
+    param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
     key_var_num_args = py_str(key_var_num_args.value)
     func_name = py_str(name.value)
-    param_str = []
-    for i in range(num_args.value):
-        key = py_str(arg_names[i])
-        if key == key_var_num_args:
-            continue
-        ret = '%s : %s' % (key, py_str(arg_types[i]))
-        if len(arg_descs[i]) != 0:
-            ret += '\n    ' + py_str(arg_descs[i])
-        param_str.append(ret)
     desc = py_str(desc.value)
     if key_var_num_args:
-        desc = '\nThis function support variable length of positional input.'
-
+        desc += '\nThis function supports variable-length positional input.'
     doc_str = ('%s\n\n' +
-               'Parameters\n' +
-               '----------\n' +
                '%s\n' +
                'name : string, required.\n' +
                '    Name of the resulting symbol.\n\n' +
                'Returns\n' +
                '-------\n' +
                'symbol: Symbol\n'+
                '    The result symbol.')
-    doc_str = doc_str % (desc, '\n'.join(param_str))
+    doc_str = doc_str % (desc, param_str)
 
     def creator(*args, **kwargs):
         """Activation Operator of Neural Net.
diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc
index a7d44ea018a3..d806c382390c 100644
--- a/src/engine/threaded_engine_pooled.cc
+++ b/src/engine/threaded_engine_pooled.cc
@@ -28,9 +28,6 @@ class ThreadedEnginePooled : public ThreadedEngine {
         io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {}
 
   ~ThreadedEnginePooled() noexcept(false) {
-    // wait until all the tasks are completed.
- // TODO(hotpxl) think if this is the correct thing to do - this->WaitForAll(); streams_.Finalize(); task_queue_.SignalForKill(); io_task_queue_.SignalForKill(); diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index 29e9448cb819..eb3502de958a 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -1,8 +1,9 @@ # pylint: skip-file import mxnet as mx import numpy as np -import os, gzip +import os, sys import pickle as pickle +import logging from common import get_data # symbol net @@ -15,12 +16,12 @@ fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10) softmax = mx.symbol.Softmax(data = fc3, name = 'sm') -# infer shape -data_shape = (batch_size, 784) - -model = mx.model.FeedForward(softmax, mx.cpu(), data_shape, - num_round=9, learning_rate=0.1, wd=0.0004, - momentum=0) +num_round = 4 +prefix = './mlp' +model = mx.model.FeedForward(softmax, mx.cpu(), + num_round=num_round, + learning_rate=0.01, wd=0.0004, + momentum=0.9) #check data get_data.GetMNIST_ubyte() @@ -36,8 +37,45 @@ batch_size=batch_size, shuffle=True, flat=True, silent=False) def test_mlp(): + # print logging by default + logging.basicConfig(level=logging.DEBUG) + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + logging.getLogger('').addHandler(console) + model.fit(X=train_dataiter, - eval_data=val_dataiter) + eval_data=val_dataiter, + iter_end_callback=mx.model.do_checkpoint(prefix)) + prob = model.predict(val_dataiter) + val_dataiter.reset() + y = np.concatenate([label.asnumpy() for _, label in val_dataiter]).astype('int') + py = np.argmax(prob, axis=1) + acc1 = float(np.sum(py == y)) / len(y) + logging.info('final accuracy = %f', acc1) + assert(acc1 > 0.95) + + # pickle the model + smodel = pickle.dumps(model) + model2 = pickle.loads(smodel) + prob2 = model2.predict(val_dataiter) + assert np.sum(np.abs(prob - prob2)) == 0 + + # load model from checkpoint + model3 = mx.model.FeedForward.load(prefix, num_round) + prob3 = model3.predict(val_dataiter) + assert np.sum(np.abs(prob - prob3)) == 0 + + # save model explicitly + model.save(prefix, 128) + model4 = mx.model.FeedForward.load(prefix, 128) + prob4 = model4.predict(val_dataiter) + assert np.sum(np.abs(prob - prob4)) == 0 + + for i in range(num_round): + os.remove('%s-%04d.params' % (prefix, i + 1)) + os.remove('%s-symbol.json' % prefix) + os.remove('%s-0128.params' % prefix) + if __name__ == "__main__": test_mlp() diff --git a/tests/python/train/test_mlp_old.py b/tests/python/train/test_mlp_old.py deleted file mode 100644 index 651c85842a6a..000000000000 --- a/tests/python/train/test_mlp_old.py +++ /dev/null @@ -1,107 +0,0 @@ -# pylint: skip-file -import mxnet as mx -import numpy as np -import os, gzip -import pickle as pickle -from common import get_data - -def CalAcc(out, label): - pred = np.argmax(out, axis=1) - return np.sum(pred == label) * 1.0 / out.shape[0] - -# symbol net -batch_size = 100 -data = mx.symbol.Variable('data') -fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128) -act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") -fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64) -act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") -fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10) -softmax = mx.symbol.Softmax(data = fc3, name = 'sm') -args_list = softmax.list_arguments() -# infer shape -data_shape = (batch_size, 784) -arg_shapes, out_shapes, aux_shapes = 
softmax.infer_shape(data=data_shape) -arg_narrays = [mx.nd.empty(shape) for shape in arg_shapes] -grad_narrays = [mx.nd.empty(shape) for shape in arg_shapes] -inputs = dict(zip(args_list, arg_narrays)) -np.random.seed(0) -# set random weight -for name, narray in inputs.items(): - if "weight" in name: - narray[:] = np.random.uniform(-0.07, 0.07, narray.shape) - if "bias" in name: - narray[:] = 0.0 - -# bind executer -# TODO(bing): think of a better bind interface -executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays) -# update - -out_narray = executor.outputs[0] -grad_narray = mx.nd.empty(out_narray.shape) - -epoch = 9 -lr = 0.1 -wd = 0.0004 - -def Update(grad, weight): - weight[:] -= lr * grad / batch_size - -block = list(zip(grad_narrays, arg_narrays)) - -#check data -get_data.GetMNIST_ubyte() - -train_dataiter = mx.io.MNISTIter( - image="data/train-images-idx3-ubyte", - label="data/train-labels-idx1-ubyte", - input_shape=(784,), - batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10) -val_dataiter = mx.io.MNISTIter( - image="data/t10k-images-idx3-ubyte", - label="data/t10k-labels-idx1-ubyte", - input_shape=(784,), - batch_size=batch_size, shuffle=True, flat=True, silent=False) - -def test_mlp(): - acc_train = 0. - acc_val = 0. - for i in range(epoch): - # train - print("Epoch %d" % i) - train_acc = 0.0 - val_acc = 0.0 - train_nbatch = 0 - val_nbatch = 0 - for data, label in train_dataiter: - label = label.asnumpy().flatten() - inputs["data"][:] = data - inputs["sm_label"][:] = label - executor.forward() - train_acc += CalAcc(out_narray.asnumpy(), label) - train_nbatch += 1 - grad_narray[:] = out_narray - executor.backward([grad_narray]) - - for grad, weight in block: - Update(grad, weight) - - # evaluate - for data, label in val_dataiter: - label = label.asnumpy().flatten() - inputs["data"][:] = data - executor.forward() - val_acc += CalAcc(out_narray.asnumpy(), label) - val_nbatch += 1 - acc_train = train_acc / train_nbatch - acc_val = val_acc / val_nbatch - print("Train Acc: ", train_acc / train_nbatch) - print("Valid Acc: ", val_acc / val_nbatch) - train_dataiter.reset() - val_dataiter.reset() - assert(acc_train > 0.98) - assert(acc_val > 0.97) - -if __name__ == "__main__": - test_mlp()
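
A minimal end-to-end sketch of the save/load workflow added above (it mirrors
tests/python/train/test_mlp.py; the net, data iterators, prefix, and round
count are illustrative):

    import mxnet as mx

    # train with a checkpoint written at the end of every iteration
    model = mx.model.FeedForward(softmax, mx.cpu(), num_round=4,
                                 learning_rate=0.01, momentum=0.9)
    model.fit(X=train_dataiter, eval_data=val_dataiter,
              iter_end_callback=mx.model.do_checkpoint('mlp'))

    # reload the final checkpoint and predict with it
    model2 = mx.model.FeedForward.load('mlp', 4)
    prob = model2.predict(val_dataiter)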