Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-897] ONNX import/export: Size #13112

Merged
merged 1 commit into from
Dec 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,3 +1645,11 @@ def convert_logical_not(node, **kwargs):
and return the created node.
"""
return create_basic_op_node('Not', node, kwargs)


@mx_op.register("size_array")
def convert_size(node, **kwargs):
"""Map MXNet's size_array operator attributes to onnx's Size operator
and return the created node.
"""
return create_basic_op_node('Size', node, kwargs)
3 changes: 2 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ._op_translations import identity, random_uniform, random_normal
from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, shape, gather, lp_pooling
from ._op_translations import softplus, shape, gather, lp_pooling, size
from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
from ._op_translations import concat
from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
Expand Down Expand Up @@ -139,6 +139,7 @@
'Softplus' : softplus,
'Tan' : _tan,
'Shape' : shape,
'Size' : size,
'Gather' : gather,
'HardSigmoid' : hardsigmoid,
'LpPool' : lp_pooling,
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,10 @@ def shape(attrs, inputs, proto_obj):
"""Returns shape of input array."""
return 'shape_array', attrs, inputs

def size(attrs, inputs, proto_obj):
"""Returns array containing size of data."""
return "size_array", attrs, inputs

def reduce_l2(attrs, inputs, proto_obj):
"""Reduce input tensor by l2 normalization."""
new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
Expand Down
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/export/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@
'test_depthtospace',
'test_hardsigmoid',
'test_instancenorm',
'test_shape'
'test_shape',
'test_size'
]

BASIC_MODEL_TESTS = [
Expand Down
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/import/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@
'test_operator_maxpool',
'test_operator_params',
'test_operator_permute2',
'test_depthtospace'
'test_depthtospace',
'test_size'
]

BASIC_MODEL_TESTS = [
Expand Down