From 68e0b432a3eecd8a834061b2e93f1930a2c63c2d Mon Sep 17 00:00:00 2001 From: vandanavk Date: Fri, 12 Oct 2018 11:07:11 -0700 Subject: [PATCH] ONNX export: Instance Normalization --- .../contrib/onnx/mx2onnx/_op_translations.py | 35 +++++++++++++++++++ .../onnx/export/onnx_backend_test.py | 3 +- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 7cf856c767fa..78b36a0f3b5b 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -850,6 +850,41 @@ def convert_identity(node, **kwargs): ) return [node] +@mx_op.register("InstanceNorm") +def convert_instancenorm(node, **kwargs): + """Map MXNet's InstanceNorm operator attributes to onnx's InstanceNormalization operator + based on the input node's attributes and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + gamma_node_id = kwargs["index_lookup"][inputs[1][0]] + beta_node_id = kwargs["index_lookup"][inputs[2][0]] + + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_id] + gamma_node = proc_nodes[gamma_node_id] + beta_node = proc_nodes[beta_node_id] + + input_name = input_node.name + gamma_name = gamma_node.name + beta_name = beta_node.name + + attrs = node.get("attrs", {}) + eps = float(attrs.get("eps", 0.001)) + + node = onnx.helper.make_node( + 'InstanceNormalization', + inputs=[input_name, gamma_name, beta_name], + outputs=[name], + name=name, + epsilon=eps) + + return [node] @mx_op.register("LeakyReLU") def convert_leakyrelu(node, **kwargs): diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py index 678435d92357..fc30dbdb06da 100644 --- a/tests/python-pytest/onnx/export/onnx_backend_test.py +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -94,7 +94,8 @@ 'test_operator_permute2', 'test_clip' 'test_cast', - 'test_depthtospace' + 'test_depthtospace', + 'test_instancenorm' ] BASIC_MODEL_TESTS = [