diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 03fa812f3200..d78d7e59a529 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -79,6 +79,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx): self._aux_dict = None self._output_dict = None self._monitor_callback = None + self._monitor_all = None self._ctx = copy.deepcopy(ctx) self._grad_req = copy.deepcopy(grad_req) self._group2ctx = copy.deepcopy(group2ctx) @@ -253,6 +254,7 @@ def set_monitor_callback(self, callback, monitor_all=False): """ cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p) self._monitor_callback = cb_type(_monitor_callback_wrapper(callback)) + self._monitor_all = monitor_all check_call(_LIB.MXExecutorSetMonitorCallbackEX( self.handle, self._monitor_callback, @@ -477,6 +479,13 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs): executor.arg_arrays = arg_arrays executor.grad_arrays = grad_arrays executor.aux_arrays = aux_arrays + if (self._monitor_callback is not None) and (self._monitor_all is not None): + # rebind callback to the new executor if the callback is valid + check_call(_LIB.MXExecutorSetMonitorCallbackEX( + handle, + self._monitor_callback, + None, + ctypes.c_int(self._monitor_all))) return executor def debug_str(self): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index b22dc7bc156c..e6db0e9fc864 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8364,6 +8364,59 @@ def get_output_names_callback(name, arr): check_name(us_sym, ['data', 'pooling_data', 'pooling_output']) del os.environ['MXNET_SUBGRAPH_BACKEND'] +@with_seed() +def test_monitor_with_variable_input_shape(): + output = {} + + def get_output_min_callback(name, arr): + name = py_str(name) + handle = ctypes.cast(arr, NDArrayHandle) + arr = NDArray(handle, writable=False) + min_val = mx.ndarray.min(arr).asscalar() + if name in output: + output[name] = min(output[name], min_val) + else: + output[name] = min_val + + def check_result(output, names): + assert len(output) > 0 + for k, v in output.items(): + assert k in names + assert v is not None + + is_windows = sys.platform.startswith('win') + if (is_windows): + # Windows doesn't support set environment variable on the fly, so disable it for now + pass + else: + # Disable subgraph in case subgraph will replace symbol + os.environ['MXNET_SUBGRAPH_BACKEND'] = "NONE" + + batch_size = 1 + op_name = 'conv' + dshape = (batch_size, 3, 10, 10) + data = mx.sym.Variable('data', shape=dshape) + sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=1, name=op_name) + + mod = mx.module.Module(symbol=sym, label_names=None) + mod.bind(for_training=False, data_shapes=[('data', dshape)]) + mod.init_params() + mod._exec_group.execs[0].set_monitor_callback(get_output_min_callback, monitor_all=True) + + new_dshape = dshape[:-1] + (dshape[-1] + 4,) + new_data = mx.nd.random.uniform(shape=new_dshape) + new_data = mx.io.NDArrayIter(data=new_data, batch_size=batch_size) + new_data = DummyIter(new_data) + + for batch in new_data: + mod.forward(data_batch=batch, is_train=False) + mx.nd.waitall() + break + + name_list = ['data', 'conv_data', 'conv_weight', 'conv_bias', 'conv_output'] + check_result(output, name_list) + del os.environ['MXNET_SUBGRAPH_BACKEND'] + @with_seed() @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at /~https://github.com/apache/incubator-mxnet/issues/13915") def test_activation():