This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Bugfix of slice operator export (MXNet to ONNX) v2 #18535

Merged: 5 commits, Jun 18, 2020
python/mxnet/contrib/onnx/mx2onnx/_op_translations.py (22 additions, 6 deletions)
@@ -118,6 +118,7 @@ def convert_string_to_list(string_val):
 
     return result_list
 
+
 def get_boolean_attribute_value(attrs, attr_name):
     """ Helper function to convert a string version
     of Boolean attributes to integer for ONNX.
@@ -126,21 +127,35 @@ def get_boolean_attribute_value(attrs, attr_name):
     """
     return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0
 
-def get_inputs(node, kwargs):
+
+def get_inputs(node, kwargs, with_shapes=False):
     """Helper function to get inputs"""
     name = node["name"]
     proc_nodes = kwargs["proc_nodes"]
     index_lookup = kwargs["index_lookup"]
+    graph_shapes = kwargs["graph_shapes"]
     inputs = node["inputs"]
     attrs = node.get("attrs", {})
 
     input_nodes = []
+    input_shapes = []
     for ip in inputs:
         input_node_id = index_lookup[ip[0]]
-        input_nodes.append(proc_nodes[input_node_id].name)
+        try:
+            # ip[1] defines which output index to use
+            input_nodes.append(proc_nodes[input_node_id].output[ip[1]])
+        except AttributeError:
+            # fallback to the name attribute as output if the output attribute does not exist (e.g. for data nodes)
+            input_nodes.append(proc_nodes[input_node_id].name)
+
+        input_shapes.append(graph_shapes.get(input_nodes[-1]))
+
+    if with_shapes:
+        return name, input_nodes, input_shapes, attrs
 
     return name, input_nodes, attrs
 
+
 def create_basic_op_node(op_name, node, kwargs):
     """Helper function to create a basic operator
     node that doesn't contain op specific attrs"""
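The try/except above is what lets get_inputs work uniformly over ONNX NodeProto-like entries, which expose an output list, and plain data nodes, which only carry a name. A minimal runnable sketch of that fallback, using hypothetical stand-in node types rather than real ONNX protos:

from collections import namedtuple

# Hypothetical stand-ins: an op node exposes .output (a list of output
# names), while a data node only exposes .name, so indexing .output raises
# AttributeError, mirroring the fallback in get_inputs above.
OpNode = namedtuple("OpNode", ["name", "output"])
DataNode = namedtuple("DataNode", ["name"])

def resolve_input_name(proc_node, out_idx):
    try:
        return proc_node.output[out_idx]  # pick the requested output index
    except AttributeError:
        return proc_node.name             # data node: fall back to its name

op = OpNode("split1", ["split1_output0", "split1_output1"])
print(resolve_input_name(op, 1))                # -> split1_output1
print(resolve_input_name(DataNode("data"), 0))  # -> data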
@@ -154,6 +169,7 @@ def create_basic_op_node(op_name, node, kwargs):
     )
     return [node]
 
+
 @mx_op.register("null")
 def convert_weights_and_inputs(node, **kwargs):
     """Helper function to convert weights and inputs.
@@ -1603,15 +1619,15 @@ def convert_slice_axis(node, **kwargs):
     """Map MXNet's slice_axis operator attributes to onnx's Slice operator
     and return the created node.
     """
-    name, input_nodes, attrs = get_inputs(node, kwargs)
+    name, input_nodes, input_shapes, attrs = get_inputs(node, kwargs, with_shapes=True)
 
     axes = int(attrs.get("axis"))
     starts = int(attrs.get("begin"))
     ends = attrs.get("end", None)
     if not ends or ends == 'None':
         # ONNX doesn't support None for ends. Since ends=None depicts
         # length of dimension, passing dimension in this case.
-        in_shape = kwargs['in_shape'][0]
+        in_shape = input_shapes[0]
         ends = in_shape[axes]
 
     export_nodes = []
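The ends fallback matters because the attribute form of ONNX Slice needs concrete ends values, while MXNet's slice_axis accepts end=None to mean "up to the end of the axis". A small sketch of the substitution, assuming an illustrative (3, 224, 224) input shape inferred for the op:

# Slice axis 1 from begin=10 with no explicit end, as slice_axis allows.
in_shape = (3, 224, 224)   # inferred shape of the op's first input
axes = 1
starts = 10
ends = None
if not ends or ends == 'None':
    # ONNX Slice cannot express "until the end", so substitute the axis length.
    ends = in_shape[axes]
print(starts, ends)  # -> 10 224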
@@ -1650,7 +1666,7 @@ def convert_slice_channel(node, **kwargs):
     operator based on squeeze_axis attribute
     and return the created node.
     """
-    name, input_nodes, attrs = get_inputs(node, kwargs)
+    name, input_nodes, input_shapes, attrs = get_inputs(node, kwargs, with_shapes=True)
 
     num_outputs = int(attrs.get("num_outputs"))
     axis = int(attrs.get("axis", 1))
@@ -1666,7 +1682,7 @@
         )
         return [node]
     elif squeeze_axis == 0 and num_outputs > 1:
-        in_shape = kwargs.get('in_shape')[0]
+        in_shape = input_shapes[0]
         split = in_shape[axis] // num_outputs
         node = onnx.helper.make_node(
             "Split",
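In the squeeze_axis == 0 branch, each of the num_outputs parts gets size in_shape[axis] // num_outputs, which is exactly why the converter now needs the inferred per-input shape instead of the graph-level in_shape. A hedged sketch of the Split node this branch emits (tensor names are illustrative, and the split attribute form assumes the pre-opset-13 Split):

import onnx

in_shape = (1, 100)        # illustrative inferred input shape
axis, num_outputs = 1, 2
split = in_shape[axis] // num_outputs   # -> 50 elements per output

node = onnx.helper.make_node(
    "Split",
    inputs=["fc1_output"],          # illustrative input tensor name
    outputs=["slice0", "slice1"],   # one output name per part
    axis=axis,
    split=[split] * num_outputs,    # split sizes as an attribute
)
print(node)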
python/mxnet/contrib/onnx/mx2onnx/export_onnx.py (8 additions, 3 deletions)
@@ -116,14 +116,15 @@ def split_params(sym, params):
         return arg_params, aux_params
 
     @staticmethod
-    def get_outputs(sym, params, in_shape, in_label):
+    def get_outputs(sym, params, in_shape, in_label, verbose=True):
         """ Infer output shapes and return dictionary of output name to shape
 
         :param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer shape on
        :param dic of (str, nd.NDArray) params:
         :param list of tuple(int, ...) in_shape: list of all input shapes
         :param in_label: name of label typically used in loss that may be left in graph. This name is
             removed from list of inputs required by symbol
+        :param verbose: If false, info logging messages are deactivated
         :return: dictionary of output name to shape
         :rtype: dict of (str, tuple(int, ...))
         """
@@ -142,7 +143,8 @@ def get_outputs(sym, params, in_shape, in_label):
             if name.endswith('_output'):
                 out_names.append(name[:-len('_output')])
             else:
-                logging.info("output '%s' does not end with '_output'", name)
+                if verbose:
+                    logging.info("output '%s' does not end with '_output'", name)
                 out_names.append(name)
 
         assert len(out_shapes) == len(out_names)
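The verbose guard exists because the graph_shapes computation added in the next hunk runs get_outputs over every internal node, where names lacking the '_output' suffix are routine and would otherwise flood the log. The suffix normalization itself is straightforward:

# Sketch of the suffix handling above: internal outputs such as 'fc1_output'
# map back to 'fc1'; other names pass through unchanged (and are logged at
# info level only when verbose=True).
for name in ["fc1_output", "data"]:
    if name.endswith('_output'):
        print(name[:-len('_output')])  # -> fc1
    else:
        print(name)                    # -> data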
@@ -203,8 +205,9 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
         onnx_processed_outputs = []
         index_lookup = []
 
-        # Determine output shape
+        # Determine output and internal shapes
         graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label)
+        graph_shapes = MXNetGraph.get_outputs(sym.get_internals(), params, in_shape, output_label, verbose=False)
 
         graph_input_idx = 0
         for idx, node in enumerate(mx_graph):
@@ -230,6 +233,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     in_shape=in_shape[graph_input_idx],
                     in_type=in_type,
                     proc_nodes=all_processed_nodes,
+                    graph_shapes=graph_shapes,
                     initializer=initializer,
                     index_lookup=index_lookup)
                 graph_input_idx += 1
@@ -244,6 +248,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     in_shape=in_shape,
                     in_type=in_type,
                     proc_nodes=all_processed_nodes,
+                    graph_shapes=graph_shapes,
                     initializer=initializer,
                     index_lookup=index_lookup,
                     idx=idx
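This second get_outputs call is the core of the fix: shape inference on the top-level symbol only yields the graph's final outputs, while sym.get_internals() exposes every intermediate output, so each node's input shape can be looked up by name instead of reusing the graph-level in_shape. A hedged sketch of the difference (layer names and shapes are illustrative):

import mxnet as mx

# Build a small symbolic graph containing a split, the case this PR fixes.
data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data, num_hidden=100, name='fc1')
splits = mx.sym.split(fc1, axis=1, num_outputs=2, name='split1')
out = mx.sym.concat(*[splits[i] for i in range(2)], name='concat1')

# Infer shapes over the internals: this yields one shape per intermediate
# output, the mapping that graph_shapes hands to get_inputs().
internals = out.get_internals()
_, out_shapes, _ = internals.infer_shape(data=(1, 10))
for name, shape in zip(internals.list_outputs(), out_shapes):
    print(name, shape)   # e.g. split1_output0 (1, 50), split1_output1 (1, 50)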
tests/python/unittest/onnx/mxnet_export_test.py (25 additions, 0 deletions)
@@ -28,6 +28,7 @@
 from mxnet import nd, sym
 from mxnet.test_utils import set_default_context
 from mxnet.gluon import nn
+from mxnet.gluon import HybridBlock
 from mxnet.contrib import onnx as onnx_mxnet
 import mxnet as mx
 
@@ -80,6 +81,16 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=
             mx.test_utils.assert_almost_equal(out, imp_out, atol=1e-5, rtol=1e-5)
 
 
+class SplitConcatBlock(HybridBlock):
+    """Block which creates two splits and later concatenates them"""
+    def __init__(self, name):
+        super(SplitConcatBlock, self).__init__(name)
+
+    def hybrid_forward(self, F, x):
+        splits = F.split(x, axis=1, num_outputs=2)
+        return F.concat(*splits)
+
+
 class TestExport(unittest.TestCase):
     """ Tests ONNX export.
     """
@@ -126,3 +137,17 @@ def test_onnx_export_extra_params(self):
         net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
         _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])})
 
+    @with_seed()
+    def test_onnx_export_slice(self):
+        net = nn.HybridSequential(prefix='slice_net')
+        with net.name_scope():
+            net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"), nn.Dense(10))
+        _check_onnx_export(net)
+
+    @with_seed()
+    def test_onnx_export_slice_changing_shape(self):
+        net = nn.HybridSequential(prefix='slice_net_changing_shape')
+        with net.name_scope():
+            net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"),
+                    nn.Dense(50, activation='relu'), SplitConcatBlock("splitConcat2"), nn.Dense(10))
+        _check_onnx_export(net)
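For completeness, a hedged end-to-end sketch of the scenario these tests exercise, assuming MXNet 1.x with the onnx package installed; file names and shapes are illustrative, and the split/concat pattern is the one that previously tripped the exporter:

import numpy as np
import mxnet as mx
from mxnet.gluon import nn
from mxnet.contrib import onnx as onnx_mxnet

# Dense -> split/concat -> Dense, the shape-sensitive pattern under test.
net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Dense(100, activation='relu'),
            nn.HybridLambda(lambda F, x: F.concat(*F.split(x, axis=1, num_outputs=2))),
            nn.Dense(10))
net.initialize()
net.hybridize()
net(mx.nd.zeros((1, 10)))          # one forward pass to build the cached graph
net.export('slice_net', epoch=0)   # writes slice_net-symbol.json / slice_net-0000.params

onnx_path = onnx_mxnet.export_model('slice_net-symbol.json', 'slice_net-0000.params',
                                    [(1, 10)], np.float32, 'slice_net.onnx')
print(onnx_path)  # -> slice_net.onnx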