diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index d9c3d8ecf7ab..f2759f38cd03 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -575,6 +575,7 @@ def convert_pooling(node, **kwargs): pool_type = attrs["pool_type"] stride = eval(attrs["stride"]) if attrs.get("stride") else None global_pool = get_boolean_attribute_value(attrs, "global_pool") + p_value = int(attrs.get('p_value', '2')) pooling_convention = attrs.get('pooling_convention', 'valid') @@ -587,26 +588,48 @@ def convert_pooling(node, **kwargs): pad_dims = list(parse_helper(attrs, "pad", [0, 0])) pad_dims = pad_dims + pad_dims - pool_types = {"max": "MaxPool", "avg": "AveragePool"} - global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"} + pool_types = {"max": "MaxPool", "avg": "AveragePool", "lp": "LpPool"} + global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool", + "lp": "GlobalLpPool"} if global_pool: - node = onnx.helper.make_node( - global_pool_types[pool_type], - input_nodes, # input - [name], - name=name - ) + if pool_type == 'lp': + node = onnx.helper.make_node( + global_pool_types[pool_type], + input_nodes, # input + [name], + p=p_value, + name=name + ) + else: + node = onnx.helper.make_node( + global_pool_types[pool_type], + input_nodes, # input + [name], + name=name + ) else: - node = onnx.helper.make_node( - pool_types[pool_type], - input_nodes, # input - [name], - kernel_shape=kernel, - pads=pad_dims, - strides=stride, - name=name - ) + if pool_type == 'lp': + node = onnx.helper.make_node( + pool_types[pool_type], + input_nodes, # input + [name], + p=p_value, + kernel_shape=kernel, + pads=pad_dims, + strides=stride, + name=name + ) + else: + node = onnx.helper.make_node( + pool_types[pool_type], + input_nodes, # input + [name], + kernel_shape=kernel, + pads=pad_dims, + strides=stride, + name=name + ) return [node] diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 8431c6b1692a..e653fd400a3d 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -395,6 +395,7 @@ def global_lppooling(attrs, inputs, proto_obj): 'kernel': (1, 1), 'pool_type': 'lp', 'p_value': p_value}) + new_attrs = translation_utils._remove_attributes(new_attrs, ['p']) return 'Pooling', new_attrs, inputs def linalg_gemm(attrs, inputs, proto_obj): @@ -684,11 +685,12 @@ def lp_pooling(attrs, inputs, proto_obj): new_attrs = translation_utils._fix_attribute_names(attrs, {'kernel_shape': 'kernel', 'strides': 'stride', - 'pads': 'pad', - 'p_value': p_value + 'pads': 'pad' }) + new_attrs = translation_utils._remove_attributes(new_attrs, ['p']) new_attrs = translation_utils._add_extra_attributes(new_attrs, - {'pooling_convention': 'valid' + {'pooling_convention': 'valid', + 'p_value': p_value }) new_op = translation_utils._fix_pooling('lp', inputs, new_attrs) return new_op, new_attrs, inputs diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py index f63c1e9e8e62..6fd52665ca31 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py @@ -94,6 +94,7 @@ def _fix_pooling(pool_type, inputs, new_attr): stride = new_attr.get('stride') kernel = new_attr.get('kernel') padding = new_attr.get('pad') + p_value = new_attr.get('p_value') # Adding default stride. if stride is None: @@ -138,7 +139,10 @@ def _fix_pooling(pool_type, inputs, new_attr): new_pad_op = symbol.pad(curr_sym, mode='constant', pad_width=pad_width) # Apply pooling without pads. - new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel) + if pool_type == 'lp': + new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel, p_value=p_value) + else: + new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel) return new_pooling_op def _fix_bias(op_name, attrs, num_inputs): diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index 42d6f34b9875..3647f1a1ca2e 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -98,8 +98,11 @@ def forward_pass(sym, arg, aux, data_names, input_data): # create module mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None) mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None) - mod.set_params(arg_params=arg, aux_params=aux, - allow_missing=True, allow_extra=True) + if not arg and not aux: + mod.init_params() + else: + mod.set_params(arg_params=arg, aux_params=aux, + allow_missing=True, allow_extra=True) # run inference batch = namedtuple('Batch', ['data']) mod.forward(batch([mx.nd.array(input_data)]), is_train=False) @@ -345,6 +348,43 @@ def test_ops(op_name, inputs, input_tensors, numpy_op): np.logical_not(input_data[0]).astype(np.float32)) +@with_seed() +def testLpPooling(): + def test_pooling(opname, data, attrs, p): + input1 = np.random.rand(*data).astype("float32") + inputs = [helper.make_tensor_value_info("input1", TensorProto.FLOAT, shape=data)] + sym = mx.sym.Pooling(mx.sym.Variable('input1'), pool_type='lp', p_value=p, **attrs) + lppool_output = forward_pass(sym, None, None, ['input1'], input1) + + lppool_op_tensor = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(lppool_output))] + + if attrs.get('global_pool', False): + lppool_node = [helper.make_node(opname, ["input1"], ["output"], p=p)] + else: + lppool_node = [helper.make_node(opname, ["input1"], ["output"], p=p, **attrs)] + + lppool_graph = helper.make_graph(lppool_node, + opname+"_test", + inputs, + lppool_op_tensor) + + lppool_model = helper.make_model(lppool_graph) + + bkd_rep = backend.prepare(lppool_model) + output = bkd_rep.run([input1]) + + npt.assert_almost_equal(output[0], lppool_output) + + ip = (2, 3, 20, 20) + kernel = (4, 5) + pad = (0, 0) + stride = (1, 1) + + for p_value in range(1, 3): + test_pooling('LpPool', ip, {'kernel': kernel, 'stride': stride, 'pad': pad}, p=p_value) + test_pooling('GlobalLpPool', ip, {'kernel': kernel, 'stride': stride, 'pad': pad, 'global_pool': True}, p=p_value) + + def _assert_sym_equal(lhs, rhs): assert lhs.list_inputs() == rhs.list_inputs() # input names must be identical assert len(lhs.list_outputs()) == len(rhs.list_outputs()) # number of outputs must be identical diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py index 04cfe93b7ed7..0772c6e56f68 100644 --- a/tests/python-pytest/onnx/import/test_cases.py +++ b/tests/python-pytest/onnx/import/test_cases.py @@ -45,7 +45,6 @@ 'test_transpose', 'test_globalmaxpool', 'test_globalaveragepool', - 'test_global_lppooling', 'test_slice_cpu', 'test_slice_neg', 'test_reciprocal', @@ -77,7 +76,6 @@ 'test_averagepool_2d_precomputed_strides', 'test_averagepool_2d_strides', 'test_averagepool_3d', - 'test_LpPool_', 'test_cast', 'test_instancenorm', #pytorch operator tests