From 372cb0a20c244b84940d928c1e5d61d4c2e94853 Mon Sep 17 00:00:00 2001 From: vandanavk Date: Tue, 20 Nov 2018 14:08:37 -0800 Subject: [PATCH] ONNX export: Add Flatten before Gemm --- .../contrib/onnx/mx2onnx/_op_translations.py | 11 ++++++ .../mxnet/contrib/onnx/mx2onnx/export_onnx.py | 39 ++++++++++--------- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 73ca07be76ee..65ca2c25b2f2 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs): fcnode = [] + op_name = "flatten_" + str(kwargs["idx"]) + flatten_node = onnx.helper.make_node( + 'Flatten', + inputs=[input_nodes[0]], + outputs=[op_name], + name=op_name + ) + + input_nodes[0] = op_name + fcnode.append(flatten_node) + if no_bias: data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')] bias_name = "bias" + str(kwargs["idx"]) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index b02d970f9c2d..c1f327e20ba2 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -231,6 +231,10 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) # Determine output shape output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label) + output_suffix = '_output' + output_names = [ + o[:-len(output_suffix)] for o in sym.list_outputs() if o.endswith(output_suffix)] + weights = MXNetGraph.convert_weights_to_numpy(params) mx_graph = json.loads(sym.tojson())["nodes"] @@ -294,26 +298,25 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) # If converted node is NodeProto, add it in processed nodes list elif isinstance(converted_node, NodeProto): onnx_processed_nodes.append(converted_node) - if idx == (len(mx_graph) - 1): - # If converted node doesnt have name, use it from output field - if not converted_node.name: - onnx_processed_outputs.append( - make_tensor_value_info( - name=converted_node.output[0], - elem_type=in_type, - shape=output_shape - ) + # If converted node doesnt have name, use it from output field + if not converted_node.name and idx == (len(mx_graph) - 1): + onnx_processed_outputs.append( + make_tensor_value_info( + name=converted_node.output[0], + elem_type=in_type, + shape=output_shape ) - else: - onnx_processed_outputs.append( - make_tensor_value_info( - name=converted_node.name, - elem_type=in_type, - shape=output_shape - ) + ) + elif converted_node.name in output_names: + onnx_processed_outputs.append( + make_tensor_value_info( + name=converted_node.name, + elem_type=in_type, + shape=output_shape ) - if verbose: - logging.info("Output node is: %s", converted_node.name) + ) + if verbose: + logging.info("Output node is: %s", converted_node.name) elif isinstance(converted_node, TensorProto): raise ValueError("Did not expect TensorProto") else: