Skip to content

Commit

Permalink
[Numpy] Bugfix of slice operator export (MXNet to ONNX) v2 (apache#18535
Browse files Browse the repository at this point in the history
)

* fixed get_inputs() for onnx slice operator export

* added unit test for onnx slice operator export

* implement get_inputs with_shapes helper

* update slice ops to use with_shapes

* added verbose parameter for get_outputs()

Co-authored-by: Andrey Stotskiy <andrey.stotskiy@tevian.ru>
  • Loading branch information
2 people authored and bgawrych committed Jun 23, 2020
1 parent e1e4e2b commit f3e5d99
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 9 deletions.
28 changes: 22 additions & 6 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def convert_string_to_list(string_val):

return result_list


def get_boolean_attribute_value(attrs, attr_name):
""" Helper function to convert a string version
of Boolean attributes to integer for ONNX.
Expand All @@ -126,21 +127,35 @@ def get_boolean_attribute_value(attrs, attr_name):
"""
return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0

def get_inputs(node, kwargs):

def get_inputs(node, kwargs, with_shapes=False):
"""Helper function to get inputs"""
name = node["name"]
proc_nodes = kwargs["proc_nodes"]
index_lookup = kwargs["index_lookup"]
graph_shapes = kwargs["graph_shapes"]
inputs = node["inputs"]
attrs = node.get("attrs", {})

input_nodes = []
input_shapes = []
for ip in inputs:
input_node_id = index_lookup[ip[0]]
input_nodes.append(proc_nodes[input_node_id].name)
try:
# ip[1] defines which output index to use
input_nodes.append(proc_nodes[input_node_id].output[ip[1]])
except AttributeError:
# fallback to the name attribute as output if the output attribute does not exist (e.g. for data nodes)
input_nodes.append(proc_nodes[input_node_id].name)

input_shapes.append(graph_shapes.get(input_nodes[-1]))

if with_shapes:
return name, input_nodes, input_shapes, attrs

return name, input_nodes, attrs


def create_basic_op_node(op_name, node, kwargs):
"""Helper function to create a basic operator
node that doesn't contain op specific attrs"""
Expand All @@ -154,6 +169,7 @@ def create_basic_op_node(op_name, node, kwargs):
)
return [node]


@mx_op.register("null")
def convert_weights_and_inputs(node, **kwargs):
"""Helper function to convert weights and inputs.
Expand Down Expand Up @@ -1565,15 +1581,15 @@ def convert_slice_axis(node, **kwargs):
"""Map MXNet's slice_axis operator attributes to onnx's Slice operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)
name, input_nodes, input_shapes, attrs = get_inputs(node, kwargs, with_shapes=True)

axes = int(attrs.get("axis"))
starts = int(attrs.get("begin"))
ends = attrs.get("end", None)
if not ends or ends == 'None':
# ONNX doesn't support None for ends. Since ends=None depicts
# length of dimension, passing dimension in this case.
in_shape = kwargs['in_shape'][0]
in_shape = input_shapes[0]
ends = in_shape[axes]

export_nodes = []
Expand Down Expand Up @@ -1612,7 +1628,7 @@ def convert_slice_channel(node, **kwargs):
operator based on squeeze_axis attribute
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)
name, input_nodes, input_shapes, attrs = get_inputs(node, kwargs, with_shapes=True)

num_outputs = int(attrs.get("num_outputs"))
axis = int(attrs.get("axis", 1))
Expand All @@ -1628,7 +1644,7 @@ def convert_slice_channel(node, **kwargs):
)
return [node]
elif squeeze_axis == 0 and num_outputs > 1:
in_shape = kwargs.get('in_shape')[0]
in_shape = input_shapes[0]
split = in_shape[axis] // num_outputs
node = onnx.helper.make_node(
"Split",
Expand Down
11 changes: 8 additions & 3 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,15 @@ def split_params(sym, params):
return arg_params, aux_params

@staticmethod
def get_outputs(sym, params, in_shape, in_label):
def get_outputs(sym, params, in_shape, in_label, verbose=True):
""" Infer output shapes and return dictionary of output name to shape
:param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer shape on
:param dic of (str, nd.NDArray) params:
:param list of tuple(int, ...) in_shape: list of all input shapes
:param in_label: name of label typically used in loss that may be left in graph. This name is
removed from list of inputs required by symbol
:param verbose: If false, info logging messages are deactivated
:return: dictionary of output name to shape
:rtype: dict of (str, tuple(int, ...))
"""
Expand All @@ -142,7 +143,8 @@ def get_outputs(sym, params, in_shape, in_label):
if name.endswith('_output'):
out_names.append(name[:-len('_output')])
else:
logging.info("output '%s' does not end with '_output'", name)
if verbose:
logging.info("output '%s' does not end with '_output'", name)
out_names.append(name)

assert len(out_shapes) == len(out_names)
Expand Down Expand Up @@ -203,8 +205,9 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
onnx_processed_outputs = []
index_lookup = []

# Determine output shape
# Determine output and internal shapes
graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label)
graph_shapes = MXNetGraph.get_outputs(sym.get_internals(), params, in_shape, output_label, verbose=False)

graph_input_idx = 0
for idx, node in enumerate(mx_graph):
Expand All @@ -230,6 +233,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
in_shape=in_shape[graph_input_idx],
in_type=in_type,
proc_nodes=all_processed_nodes,
graph_shapes=graph_shapes,
initializer=initializer,
index_lookup=index_lookup)
graph_input_idx += 1
Expand All @@ -244,6 +248,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
in_shape=in_shape,
in_type=in_type,
proc_nodes=all_processed_nodes,
graph_shapes=graph_shapes,
initializer=initializer,
index_lookup=index_lookup,
idx=idx
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/onnx/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from mxnet import nd, sym
from mxnet.test_utils import set_default_context
from mxnet.gluon import nn
from mxnet.gluon import HybridBlock
from mxnet.contrib import onnx as onnx_mxnet
import mxnet as mx

Expand Down Expand Up @@ -80,6 +81,16 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=
mx.test_utils.assert_almost_equal(out, imp_out, atol=1e-5, rtol=1e-5)


class SplitConcatBlock(HybridBlock):
"""Block which creates two splits and later concatenates them"""
def __init__(self, name):
super(SplitConcatBlock, self).__init__(name)

def hybrid_forward(self, F, x):
splits = F.split(x, axis=1, num_outputs=2)
return F.concat(*splits)


class TestExport(unittest.TestCase):
""" Tests ONNX export.
"""
Expand Down Expand Up @@ -126,3 +137,17 @@ def test_onnx_export_extra_params(self):
net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
_check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])})

@with_seed()
def test_onnx_export_slice(self):
net = nn.HybridSequential(prefix='slice_net')
with net.name_scope():
net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"), nn.Dense(10))
_check_onnx_export(net)

@with_seed()
def test_onnx_export_slice_changing_shape(self):
net = nn.HybridSequential(prefix='slice_net_changing_shape')
with net.name_scope():
net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"),
nn.Dense(50, activation='relu'), SplitConcatBlock("splitConcat2"), nn.Dense(10))
_check_onnx_export(net)

0 comments on commit f3e5d99

Please sign in to comment.