Skip to content

Commit

Permalink
fix for params with no dims in onnx (apache#13413)
Browse files Browse the repository at this point in the history
* fix for params with no dims

* fix

* fix

* retrigger build

* test added

* retrigger CI

* retrigger ci
  • Loading branch information
Roshrini authored and haohuw committed Jun 23, 2019
1 parent a8a2117 commit 689cc98
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=invalid-name,too-many-locals,no-self-use
""" Support import export formats."""
from __future__ import absolute_import as _abs
import numpy as np
from .... import symbol
from .... import ndarray as nd
from ....base import string_types
Expand Down Expand Up @@ -87,7 +88,7 @@ def from_onnx(self, graph):
params : dict
A dict of name: nd.array pairs, used as pretrained weights
"""
#get input, output shapes
# get input, output shapes
self.model_metadata = self.get_graph_metadata(graph)
# parse network inputs, aka parameters
for init_tensor in graph.initializer:
Expand Down Expand Up @@ -196,7 +197,11 @@ def _parse_array(self, tensor_proto):
except ImportError:
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - /~https://github.com/onnx/onnx")
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
if len(tuple(tensor_proto.dims)) > 0:
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
else:
# If onnx's params are scalar values without dims mentioned.
np_array = np.array([to_array(tensor_proto)])
return nd.array(np_array)

def _parse_attr(self, attr_proto):
Expand Down
11 changes: 11 additions & 0 deletions tests/python-pytest/onnx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
'https://s3.amazonaws.com/download.onnx/models/opset_8/inception_v2.tar.gz'
}

test_model_path = "https://s3.amazonaws.com/onnx-mxnet/test_model.onnx"

def get_test_files(name):
"""Extract tar file and returns model path and input, output data"""
Expand Down Expand Up @@ -152,6 +153,16 @@ def get_model_results(modelpath):

logging.info(model_name + " conversion successful")

def test_nodims_import(self):
# Download test model without dims mentioned in params
test_model = download(test_model_path, dirname=CURR_PATH.__str__())
input_data = np.array([0.2, 0.5])
nd_data = mx.nd.array(input_data).expand_dims(0)
sym, arg_params, aux_params = onnx_mxnet.import_model(test_model)
model_metadata = onnx_mxnet.get_model_metadata(test_model)
input_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]
output_data = forward_pass(sym, arg_params, aux_params, input_names, nd_data)
assert(output_data.shape == (1,1))

# test_case = ("model name", input shape, output shape)
test_cases = [
Expand Down

0 comments on commit 689cc98

Please sign in to comment.