Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
changing argument names and making it similar to module
Browse files Browse the repository at this point in the history
  • Loading branch information
Roshrini committed Nov 1, 2018
1 parent ac2d76a commit d20dfe1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
14 changes: 7 additions & 7 deletions python/mxnet/contrib/onnx/mx2onnx/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ._export_helper import load_module


def export_model(sym, params, input_shape, input_type=np.float32, output_label=None, output_shape=None,
def export_model(sym, params, input_shape, input_type=np.float32, label_names=None, label_shapes=None,
onnx_file_path='model.onnx', verbose=False):
"""Exports the MXNet model file, passed as a parameter, into ONNX model.
Accepts both symbol,parameter objects as well as json and params filepaths as input.
Expand All @@ -49,10 +49,10 @@ def export_model(sym, params, input_shape, input_type=np.float32, output_label=N
Input shape of the model e.g [(1,3,224,224)]
input_type : data type
Input data type e.g. np.float32
output_label : List of str
Optional list of output node labels
output_shape : List of tuple
Input shape of the model e.g [(1,3,224,224)]
label_names : List of str
Optional list of label e.g. ['regression_label']
label_shapes : List of tuple
Optional a list of (name, shape) pairs e.g [('regression_label', (1,3,224,224))]
onnx_file_path : str
Path where to save the generated onnx file
verbose : Boolean
Expand All @@ -79,11 +79,11 @@ def export_model(sym, params, input_shape, input_type=np.float32, output_label=N
sym_obj, params_obj = load_module(sym, params)
onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
output_label, output_shape, verbose=verbose)
label_names, label_shapes, verbose=verbose)
elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
output_label, output_shape, verbose=verbose)
label_names, label_shapes, verbose=verbose)
else:
raise ValueError("Input sym and params should either be files or objects")

Expand Down
37 changes: 25 additions & 12 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,11 @@ def forward_pass(inputs, sym, arg_params, aux_params, label_name):
test_mod.forward(io.DataBatch(data_forward))
result = [i.asnumpy().shape for i in test_mod.get_outputs()]

return result
result_shape = []
for idx, label in enumerate(label_name):
result_shape.append((label, result[idx]))

return result_shape


@staticmethod
Expand Down Expand Up @@ -193,7 +197,17 @@ def convert_weights_to_numpy(weights_dict):
return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
for k, v in weights_dict.items()])

def create_onnx_graph_proto(self, sym, params, in_shape, in_type, out_label=None, out_shape=None, verbose=False):
@staticmethod
def verify_provided_labels(data_names, data_shapes, name, throw):
"""Check that input labels matches input data shape."""
actual = [x[0] for x in data_shapes]
if sorted(data_names) != sorted(actual):
msg = "Data provided by %s_shapes don't match names specified by %s_names (%s vs. %s)" % (
name, name, str(data_shapes), str(data_names))
if throw:
raise ValueError(msg)

def create_onnx_graph_proto(self, sym, params, in_shape, in_type, label_names=None, label_shapes=None, verbose=False):
"""Convert MXNet graph to ONNX graph
Parameters
Expand Down Expand Up @@ -233,16 +247,14 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, out_label=None
output_suffix = '_output'
output_names = [o[:-len(output_suffix)] for o in sym.list_outputs() if o.endswith(output_suffix)]

if not out_label:
if not label_names:
label_names = [output_name + '_label' for output_name in output_names]
else:
label_names = out_label

# Determine output shape
if not out_shape:
label_shapes = MXNetGraph.infer_output_shape(sym, params, in_shape, label_names)
if not label_shapes:
label_shapes = MXNetGraph.infer_output_shape(sym, params, in_shape, label_names)
else:
label_shapes = out_shape
MXNetGraph.verify_provided_labels(label_names, label_shapes, 'label', True)

graph_inputs = sym.list_inputs()

Expand Down Expand Up @@ -270,7 +282,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, out_label=None
# Handling graph input
# Skipping output_label node, as this node is not part of graph
# Refer "output_label" assignment above for more details.
if name in label_names and name not in graph_inputs:
if name in label_names:
continue
converted = MXNetGraph.convert_layer(
node,
Expand Down Expand Up @@ -308,22 +320,23 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, out_label=None
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
if idx == (len(mx_graph) - 1) or converted_node.name in output_names:
if converted_node.name in output_names:
label_shape = [i[1] for i in label_shapes if converted_node.name + "_label" == i[0]]
# If converted node doesnt have name, use it from output field
if not converted_node.name:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.output[0],
elem_type=in_type,
shape=label_shapes[0]
shape=label_shape[0]
)
)
else:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.name,
elem_type=in_type,
shape=label_shapes[0]
shape=label_shape[0]
)
)
if verbose:
Expand Down

0 comments on commit d20dfe1

Please sign in to comment.