diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index e954ee4523b7..8ed463fc4dd9 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2291,3 +2291,24 @@ def convert_sum(node, **kwargs): name=name ) return [node] + +@mx_op.register("shape_array") +def convert_shape(node, **kwargs): + """Map MXNet's shape_array operator attributes to onnx's Shape operator + and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = onnx.helper.make_node( + "Shape", + [input_node], + [name], + name=name, + ) + return [node] diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py index fc30dbdb06da..74c28729c8e0 100644 --- a/tests/python-pytest/onnx/export/onnx_backend_test.py +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -95,7 +95,8 @@ 'test_clip' 'test_cast', 'test_depthtospace', - 'test_instancenorm' + 'test_instancenorm', + 'test_shape' ] BASIC_MODEL_TESTS = [