From 67b8ea10b6c96217e60ea6f855dc8b0e895f5c10 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Wed, 20 Feb 2019 16:07:54 -0800 Subject: [PATCH] fix update params --- python/mxnet/model.py | 6 ++++-- tests/python/unittest/test_module.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index c08077cc65f4..efb51096c368 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -181,8 +181,10 @@ def _update_params(param_arrays, grad_arrays, updater, num_device, w, g = p updates[k].append((index*num_device+k, g, w)) for dev_updates in updates: - i, w, g = zip(*dev_updates) - updater(i, w, g) + # update params if param_arrays and grad_arrays are not empty + if dev_updates: + i, w, g = zip(*dev_updates) + updater(i, w, g) def _multiple_callbacks(callbacks, *args, **kwargs): diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index ae38a2297ded..36c1993bf0ff 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -917,6 +917,20 @@ def sym_gen(_): assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == 2 * batch_size) +def test_module_update_no_pragram(): + # test module to do update on layers without params + data_shape = (10, 10) + data = mx.sym.Variable('data') + out = mx.sym.Dropout(data, 0.5) + mod = mx.mod.Module(out) + mod.bind(data_shapes=[('data', data_shape)]) + mod.init_params() + mod.init_optimizer() + data_batch = mx.io.DataBatch([nd.ones(data_shape)]) + mod.forward_backward(data_batch) + mod.update() + assert(mod.get_outputs()[0].shape == data_shape) + if __name__ == '__main__': import nose nose.runmodule()