Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

ONNX import/export: Upsampling #15994

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2055,7 +2055,6 @@ def convert_topk(node, **kwargs):
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = int(attrs.get('axis', '-1'))
k = int(attrs.get('k', '1'))
ret_type = attrs.get('ret_typ')
Expand All @@ -2080,3 +2079,26 @@ def convert_topk(node, **kwargs):
)

return [topk_node]


@mx_op.register("UpSampling")
def convert_upsample(node, **kwargs):
    """Map MXNet's UpSampling operator attributes to onnx's Upsample operator
    and return the created node.
    """
    name, input_nodes, attrs = get_inputs(node, kwargs)

    # MXNet's 'bilinear' is spelled 'linear' in ONNX; 'nearest' maps directly.
    sample_type = attrs.get('sample_type', 'nearest')
    sample_type = 'linear' if sample_type == 'bilinear' else sample_type

    # MXNet gives one or two spatial scale factors; if only one is given it
    # applies to both height and width.
    scale = convert_string_to_list(attrs.get('scale'))
    scaleh = float(scale[0])
    scalew = float(scale[0])
    if len(scale) > 1:
        scalew = float(scale[1])
    # ONNX Upsample expects one scale per input dimension (NCHW layout):
    # batch and channel dimensions are left unscaled.
    scale = [1.0, 1.0, scaleh, scalew]

    node = onnx.helper.make_node(
        'Upsample',
        input_nodes,
        [name],
        scales=scale,
        mode=sample_type,
        name=name
    )
    return [node]
7 changes: 4 additions & 3 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, shape, gather, lp_pooling, size
from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
from ._op_translations import concat, hardmax, topk
from ._op_translations import concat, hardmax, topk, upsampling
from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
Expand Down Expand Up @@ -148,5 +148,6 @@
'SpaceToDepth' : spacetodepth,
'Hardmax' : hardmax,
'LpNormalization' : lpnormalization,
'TopK' : topk
}
'TopK' : topk,
'Upsample' : upsampling,
}
21 changes: 21 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,3 +787,24 @@ def topk(attrs, inputs, proto_obj):
{'ret_typ': 'both',
'dtype': 'int64'})
return 'topk', new_attrs, inputs


def upsampling(attrs, inputs, proto_obj):
    """Upsample the input by per-dimension scale factors (nearest neighbor only).

    Maps ONNX's Upsample operator to MXNet's UpSampling. Raises
    NotImplementedError for any mode other than 'nearest', since ONNX's
    single 'linear' mode cannot be disambiguated into bilinear/trilinear here.
    """
    new_attrs = translation_utils._fix_attribute_names(attrs, {'scales': 'scale',
                                                               'mode': 'sample_type'})
    sample_type = new_attrs.get('sample_type', 'nearest')
    if sample_type != 'nearest':
        raise NotImplementedError("Operator {} in ONNX supports 'linear' mode "
                                  "for linear, bilinear, trilinear etc. There is no "
                                  "way to distinguish these so far. Therefore, supporting "
                                  "import of only nearest neighbor upsampling for now. "
                                  "/~https://github.com/onnx/onnx/issues/1774. "
                                  "Use contrib.BilinearResize2D for bilinear mode."
                                  .format('UpSample'))

    # ONNX supplies one scale per input dimension (e.g. N, C, H, W); MXNet's
    # UpSampling takes only the spatial scales, so drop the first two entries.
    scale = tuple(new_attrs.get('scale'))[2:]
    scale = tuple(int(s) for s in scale)
    mx_op = symbol.UpSampling(inputs[0], scale=scale, sample_type=sample_type)

    return mx_op, new_attrs, inputs
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@
'test_softplus',
'test_reduce_',
'test_split_equal',
'test_top_k'
'test_top_k',
'test_upsample_n'
],
'import': ['test_gather',
'test_softsign',
Expand Down