From 8eb1ba5afa006a5b185b1a64f060f77f531acf29 Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Thu, 3 Jan 2019 11:19:13 -0800 Subject: [PATCH] test added --- tests/python-pytest/onnx/test_models.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/python-pytest/onnx/test_models.py b/tests/python-pytest/onnx/test_models.py index 12bc2711190d..f85786141d6e 100644 --- a/tests/python-pytest/onnx/test_models.py +++ b/tests/python-pytest/onnx/test_models.py @@ -51,6 +51,7 @@ 'https://s3.amazonaws.com/download.onnx/models/opset_8/inception_v2.tar.gz' } +test_model_path = "https://s3.amazonaws.com/onnx-mxnet/test_model.onnx" def get_test_files(name): """Extract tar file and returns model path and input, output data""" @@ -152,6 +153,16 @@ def get_model_results(modelpath): logging.info(model_name + " conversion successful") + def test_nodims_import(self): + # Download test model without dims mentioned in params + test_model = download(test_model_path, dirname=CURR_PATH.__str__()) + input_data = np.array([0.2, 0.5]) + nd_data = mx.nd.array(input_data).expand_dims(0) + sym, arg_params, aux_params = onnx_mxnet.import_model(test_model) + model_metadata = onnx_mxnet.get_model_metadata(test_model) + input_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')] + output_data = forward_pass(sym, arg_params, aux_params, input_names, nd_data) + assert(output_data.shape == (1,1)) # test_case = ("model name", input shape, output shape) test_cases = [