From 46cbd0f860eac39dc82e87bee21c60395601349a Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Sat, 24 Aug 2019 12:11:23 -0700 Subject: [PATCH 1/2] ONNX import/export: Upsampling --- .../contrib/onnx/mx2onnx/_op_translations.py | 27 +++++++++++++++++++ .../contrib/onnx/onnx2mx/_import_helper.py | 5 ++-- .../contrib/onnx/onnx2mx/_op_translations.py | 21 +++++++++++++++ tests/python-pytest/onnx/test_cases.py | 3 ++- 4 files changed, 53 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 35f4ff451cdb..098b8e5b9278 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2047,3 +2047,30 @@ def convert_broadcast_to(node, **kwargs): ) return [tensor_node, expand_node] + + +@mx_op.register("UpSampling") +def convert_upsample(node, **kwargs): + """Map MXNet's UpSampling operator attributes to onnx's Upsample operator + and return the created node. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + + sample_type = attrs.get('sample_type', 'nearest') + sample_type = 'linear' if sample_type == 'bilinear' else sample_type + scale = convert_string_to_list(attrs.get('scale')) + scaleh = float(scale[0]) + scalew = float(scale[0]) + if len(scale) > 1: + scalew = float(scale[1]) + scale = [1.0, 1.0, scaleh, scalew] + + node = onnx.helper.make_node( + 'Upsample', + input_nodes, + [name], + scales=scale, + mode=sample_type, + name=name + ) + return [node] diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py index cf95bfef09a3..939027b12d51 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py @@ -23,7 +23,7 @@ from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan from ._op_translations import softplus, shape, gather, lp_pooling, size from ._op_translations import ceil, floor, hardsigmoid, global_lppooling -from ._op_translations import concat, hardmax +from ._op_translations import concat, hardmax, upsampling from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm @@ -147,5 +147,6 @@ 'DepthToSpace' : depthtospace, 'SpaceToDepth' : spacetodepth, 'Hardmax' : hardmax, - 'LpNormalization' : lpnormalization + 'LpNormalization' : lpnormalization, + 'Upsample' : upsampling } diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 734b438581a5..4c5b8874d241 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -779,3 +779,24 @@ def lpnormalization(attrs, inputs, proto_obj): axis = int(attrs.get("axis", -1)) new_attrs.update(axis=axis) return 'norm', new_attrs, inputs + + +def upsampling(attrs, inputs, proto_obj): + """Rearranges blocks of spatial data into depth.""" + new_attrs = translation_utils._fix_attribute_names(attrs, {'scales': 'scale', + 'mode': 'sample_type'}) + sample_type = new_attrs.get('sample_type', 'nearest') + if sample_type != 'nearest': + raise NotImplementedError("Operator {} in ONNX supports 'linear' mode " + "for linear, bilinear, trilinear etc. There is no " + "way to distinguish these so far. Therefore, supporting " + "import of only nearest neighbor upsampling for now. " + "/~https://github.com/onnx/onnx/issues/1774. " + "Use contrib.BilinearResize2D for bilinear mode." + .format('UpSample')) + + scale = tuple(new_attrs.get('scale'))[2:] + scale = tuple([int(s) for s in scale]) + mx_op = symbol.UpSampling(inputs[0], scale=scale, sample_type=sample_type) + + return mx_op, new_attrs, inputs diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py index 89b60d15e84f..90136bb01ba6 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -78,7 +78,8 @@ 'test_max_', 'test_softplus', 'test_reduce_', - 'test_split_equal' + 'test_split_equal', + 'test_upsample_n' ], 'import': ['test_gather', 'test_softsign', From 36ecf8dff0f02be1624def320ec69cfaaf5b0897 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Tue, 27 Aug 2019 10:34:59 -0700 Subject: [PATCH 2/2] Re-trigger CI