From 3560dd548ec3f877bba93693bbbd19eb6ceb089a Mon Sep 17 00:00:00 2001 From: vandanavk Date: Fri, 12 Oct 2018 11:07:11 -0700 Subject: [PATCH] ONNX export: Instance Normalization --- .../contrib/onnx/mx2onnx/_op_translations.py | 26 +++++++++++++++++++ .../onnx/export/onnx_backend_test.py | 3 ++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 7cf856c767fa..e954ee4523b7 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -850,6 +850,32 @@ def convert_identity(node, **kwargs): ) return [node] +@mx_op.register("InstanceNorm") +def convert_instancenorm(node, **kwargs): + """Map MXNet's InstanceNorm operator attributes to onnx's InstanceNormalization operator + based on the input node's attributes and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_nodes = [] + for ip in inputs: + input_node_id = kwargs["index_lookup"][ip[0]] + input_nodes.append(proc_nodes[input_node_id].name) + + attrs = node.get("attrs", {}) + eps = float(attrs.get("eps", 0.001)) + + node = onnx.helper.make_node( + 'InstanceNormalization', + inputs=input_nodes, + outputs=[name], + name=name, + epsilon=eps) + + return [node] @mx_op.register("LeakyReLU") def convert_leakyrelu(node, **kwargs): diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py index 678435d92357..fc30dbdb06da 100644 --- a/tests/python-pytest/onnx/export/onnx_backend_test.py +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -94,7 +94,8 @@ 'test_operator_permute2', 'test_clip' 'test_cast', - 'test_depthtospace' + 'test_depthtospace', + 'test_instancenorm' ] BASIC_MODEL_TESTS = [