diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 3baf10a10d39..d24865d9dcb1 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -586,6 +586,7 @@ def convert_pooling(node, **kwargs): pool_type = attrs["pool_type"] stride = eval(attrs["stride"]) if attrs.get("stride") else None global_pool = get_boolean_attribute_value(attrs, "global_pool") + p_value = attrs.get('p_value', 'None') pooling_convention = attrs.get('pooling_convention', 'valid') @@ -598,26 +599,51 @@ def convert_pooling(node, **kwargs): pad_dims = list(parse_helper(attrs, "pad", [0, 0])) pad_dims = pad_dims + pad_dims - pool_types = {"max": "MaxPool", "avg": "AveragePool"} - global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"} + pool_types = {"max": "MaxPool", "avg": "AveragePool", "lp": "LpPool"} + global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool", + "lp": "GlobalLpPool"} + + if pool_type == 'lp' and p_value == 'None': + raise AttributeError('ONNX requires a p value for LpPool and GlobalLpPool') if global_pool: - node = onnx.helper.make_node( - global_pool_types[pool_type], - input_nodes, # input - [name], - name=name - ) + if pool_type == 'lp': + node = onnx.helper.make_node( + global_pool_types[pool_type], + input_nodes, # input + [name], + p=int(p_value), + name=name + ) + else: + node = onnx.helper.make_node( + global_pool_types[pool_type], + input_nodes, # input + [name], + name=name + ) else: - node = onnx.helper.make_node( - pool_types[pool_type], - input_nodes, # input - [name], - kernel_shape=kernel, - pads=pad_dims, - strides=stride, - name=name - ) + if pool_type == 'lp': + node = onnx.helper.make_node( + pool_types[pool_type], + input_nodes, # input + [name], + p=int(p_value), + kernel_shape=kernel, + pads=pad_dims, + strides=stride, + name=name + ) + else: + node = onnx.helper.make_node( + pool_types[pool_type], + input_nodes, # input + [name], + kernel_shape=kernel, + pads=pad_dims, + strides=stride, + name=name + ) return [node] @@ -1689,3 +1715,26 @@ def convert_logsoftmax(node, **kwargs): name=name ) return [node] + + +@mx_op.register("_sample_multinomial") +def convert_multinomial(node, **kwargs): + """Map MXNet's multinomial operator attributes to onnx's + Multinomial operator and return the created node. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(attrs.get("dtype", 'int32'))] + sample_size = convert_string_to_list(attrs.get("shape", '1')) + if len(sample_size) < 2: + sample_size = sample_size[-1] + else: + raise AttributeError("ONNX currently supports integer sample_size only") + node = onnx.helper.make_node( + "Multinomial", + input_nodes, + [name], + dtype=dtype, + sample_size=sample_size, + name=name, + ) + return [node] diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py index 5b33f9faac11..2a668dc84be8 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py @@ -18,7 +18,7 @@ # coding: utf-8_ # pylint: disable=invalid-name """Operator attributes conversion""" -from ._op_translations import identity, random_uniform, random_normal +from ._op_translations import identity, random_uniform, random_normal, sample_multinomial from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan from ._op_translations import softplus, shape, gather, lp_pooling, size @@ -48,6 +48,7 @@ 'RandomNormal' : random_normal, 'RandomUniformLike' : random_uniform, 'RandomNormalLike' : random_normal, + 'Multinomial' : sample_multinomial, # Arithmetic Operators 'Add' : add, 'Sub' : subtract, diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index ce0e0e51ef79..a061a7ef0027 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -38,6 +38,19 @@ def random_normal(attrs, inputs, proto_obj): new_attr = translation_utils._fix_attribute_names(new_attr, {'mean' : 'loc'}) return 'random_uniform', new_attr, inputs +def sample_multinomial(attrs, inputs, proto_obj): + """Draw random samples from a multinomial distribution.""" + try: + from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE + except ImportError: + raise ImportError("Onnx and protobuf need to be installed. " + + "Instructions to install - /~https://github.com/onnx/onnx") + new_attrs = translation_utils._remove_attributes(attrs, ['seed']) + new_attrs = translation_utils._fix_attribute_names(new_attrs, {'sample_size': 'shape'}) + new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))] + return 'sample_multinomial', new_attrs, inputs + + # Arithmetic Operations def add(attrs, inputs, proto_obj): """Adding two tensors""" @@ -382,6 +395,7 @@ def global_lppooling(attrs, inputs, proto_obj): 'kernel': (1, 1), 'pool_type': 'lp', 'p_value': p_value}) + new_attrs = translation_utils._remove_attributes(new_attrs, ['p']) return 'Pooling', new_attrs, inputs def linalg_gemm(attrs, inputs, proto_obj): @@ -671,11 +685,12 @@ def lp_pooling(attrs, inputs, proto_obj): new_attrs = translation_utils._fix_attribute_names(attrs, {'kernel_shape': 'kernel', 'strides': 'stride', - 'pads': 'pad', - 'p_value': p_value + 'pads': 'pad' }) + new_attrs = translation_utils._remove_attributes(new_attrs, ['p']) new_attrs = translation_utils._add_extra_attributes(new_attrs, - {'pooling_convention': 'valid' + {'pooling_convention': 'valid', + 'p_value': p_value }) new_op = translation_utils._fix_pooling('lp', inputs, new_attrs) return new_op, new_attrs, inputs diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py index f63c1e9e8e62..6fd52665ca31 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py @@ -94,6 +94,7 @@ def _fix_pooling(pool_type, inputs, new_attr): stride = new_attr.get('stride') kernel = new_attr.get('kernel') padding = new_attr.get('pad') + p_value = new_attr.get('p_value') # Adding default stride. if stride is None: @@ -138,7 +139,10 @@ def _fix_pooling(pool_type, inputs, new_attr): new_pad_op = symbol.pad(curr_sym, mode='constant', pad_width=pad_width) # Apply pooling without pads. - new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel) + if pool_type == 'lp': + new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel, p_value=p_value) + else: + new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel) return new_pooling_op def _fix_bias(op_name, attrs, num_inputs): diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py index 6a189b62492d..64aaab0f6d48 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -79,7 +79,6 @@ 'test_softplus' ], 'import': ['test_gather', - 'test_global_lppooling', 'test_softsign', 'test_reduce_', 'test_mean', @@ -89,7 +88,6 @@ 'test_averagepool_2d_precomputed_strides', 'test_averagepool_2d_strides', 'test_averagepool_3d', - 'test_LpPool_', 'test_split_equal', 'test_hardmax' ], diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py index 07ae866b96cf..6a0f8bcd73c2 100644 --- a/tests/python-pytest/onnx/test_node.py +++ b/tests/python-pytest/onnx/test_node.py @@ -56,6 +56,24 @@ def get_rnd(shape, low=-1.0, high=1.0, dtype=np.float32): return np.random.choice(a=[False, True], size=shape).astype(np.float32) +def _fix_attributes(attrs, attribute_mapping): + new_attrs = attrs + attr_modify = attribute_mapping.get('modify', {}) + for k, v in attr_modify.items(): + new_attrs[v] = new_attrs.pop(k, None) + + attr_add = attribute_mapping.get('add', {}) + for k, v in attr_add.items(): + new_attrs[k] = v + + attr_remove = attribute_mapping.get('remove', []) + for k in attr_remove: + if k in new_attrs: + del new_attrs[k] + + return new_attrs + + def forward_pass(sym, arg, aux, data_names, input_data): """ Perform forward pass on given data :param sym: Symbol @@ -118,7 +136,7 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att return model for test in test_cases: - test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific = test + test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test with self.subTest(test_name): names, input_tensors, inputsym = get_input_tensors(inputs) test_op = mxnet_op(*inputsym, **attrs) @@ -131,33 +149,66 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att onnx_name + ".onnx") onnxmodel = load_model(onnxmodelfile) else: - onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, attrs) + onnx_attrs = _fix_attributes(attrs, fix_attrs) + onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, onnx_attrs) bkd_rep = backend.prepare(onnxmodel, operation='export') output = bkd_rep.run(inputs) - npt.assert_almost_equal(output[0], mxnet_output) + if check_value: + npt.assert_almost_equal(output[0], mxnet_output) + + if check_shape: + npt.assert_equal(output[0].shape, outputshape) -# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False) +# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False, +# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name}, +# 'remove': [attr_name], +# 'add': {attr_name: value}, +# check_value=True/False, check_shape=True/False) test_cases = [ - ("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False), - ("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False), - ("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False), + ("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True, + False), + ("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True, + False), + ("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True, + False), ("test_and", mx.sym.broadcast_logical_and, "And", - [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False), + [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False), ("test_xor", mx.sym.broadcast_logical_xor, "Xor", - [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False), + [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False), ("test_or", mx.sym.broadcast_logical_or, "Or", - [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False), - ("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False), - ("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True), + [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False), + ("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False), + ("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True, {}, True, False), ("test_spacetodepth", mx.sym.space_to_depth, "SpaceToDepth", [get_rnd((1, 1, 4, 6))], - {'block_size': 2}, False), + {'block_size': 2}, False, {}, True, False), ("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), get_rnd(1000)], - {'ignore_label': 0, 'use_ignore': False}, True), - ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4,3)), get_rnd((4, 3)), get_rnd(4)], - {'num_hidden': 4, 'name': 'FC'}, True) + {'ignore_label': 0, 'use_ignore': False}, True, {}, True, False), + ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)], + {'num_hidden': 4, 'name': 'FC'}, True, {}, True, False), + ("test_lppool1", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))], + {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp'}, False, + {'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'}, + 'remove': ['pool_type']}, True, False), + ("test_lppool2", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))], + {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp'}, False, + {'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'}, + 'remove': ['pool_type']}, True, False), + ("test_globallppool1", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))], + {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp', 'global_pool': True}, False, + {'modify': {'p_value': 'p'}, + 'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False), + ("test_globallppool2", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))], + {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp', 'global_pool': True}, False, + {'modify': {'p_value': 'p'}, + 'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False), + + # since results would be random, checking for shape alone + ("test_multinomial", mx.sym.sample_multinomial, "Multinomial", + [np.array([0, 0.1, 0.2, 0.3, 0.4]).astype("float32")], + {'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True) ] if __name__ == '__main__':