Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Roshrini committed Jan 25, 2019
1 parent 4bfb299 commit a2d5cc5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,7 +1909,7 @@ def convert_roipooling(node, **kwargs):
return [node]


@mx_op.register("Tile")
@mx_op.register("tile")
def convert_tile(node, **kwargs):
"""Map MXNet's Tile operator attributes to onnx's Tile
operator and return the created node.
Expand Down
19 changes: 18 additions & 1 deletion tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from collections import namedtuple
import numpy as np
import numpy.testing as npt
from onnx import numpy_helper, helper, load_model
from onnx import checker, numpy_helper, helper, load_model
from onnx import TensorProto
from mxnet.test_utils import download
from mxnet.contrib import onnx as onnx_mxnet
Expand Down Expand Up @@ -206,6 +206,18 @@ def test_imports(self):
mxnet_out = bkd_rep.run(inputs)
npt.assert_almost_equal(np_out, mxnet_out, decimal=4)

def test_exports(self):
input_shape = (2,1,3,1)
for test in export_test_cases:
test_name, onnx_name, mx_op, attrs = test
input_sym = mx.sym.var('data')
outsym = mx_op(input_sym, **attrs)
converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32,
onnx_file_path=outsym.name + ".onnx")
model = load_model(converted_model)
checker.check_model(model)


# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False,
# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
# 'remove': [attr_name],
Expand Down Expand Up @@ -274,5 +286,10 @@ def test_imports(self):
("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1})
]

export_test_cases = [
("test_expand", "Expand", mx.sym.broadcast_to, {'shape': (2,1,3,1)}),
("test_tile", "Tile", mx.sym.tile, {'reps': (2,3)})
]

if __name__ == '__main__':
unittest.main()

0 comments on commit a2d5cc5

Please sign in to comment.