Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

ONNX export: Instance normalization, Shape #12920

Merged
merged 3 commits into from
Dec 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,23 @@ def convert_identity(node, **kwargs):
"""
return create_basic_op_node('Identity', node, kwargs)

@mx_op.register("InstanceNorm")
def convert_instancenorm(node, **kwargs):
"""Map MXNet's InstanceNorm operator attributes to onnx's InstanceNormalization operator
based on the input node's attributes and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

eps = float(attrs.get("eps", 0.001))

node = onnx.helper.make_node(
'InstanceNormalization',
inputs=input_nodes,
outputs=[name],
name=name,
epsilon=eps)

return [node]

@mx_op.register("LeakyReLU")
def convert_leakyrelu(node, **kwargs):
Expand Down Expand Up @@ -1546,6 +1563,15 @@ def convert_sum(node, **kwargs):
)
return [node]


@mx_op.register("shape_array")
def convert_shape(node, **kwargs):
"""Map MXNet's shape_array operator attributes to onnx's Shape operator
and return the created node.
"""
return create_basic_op_node('Shape', node, kwargs)


@mx_op.register("hard_sigmoid")
def convert_hardsigmoid(node, **kwargs):
"""Map MXNet's hard_sigmoid operator attributes to onnx's HardSigmoid operator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
# under the License.

# coding: utf-8
"""backend rep for onnx test infrastructure"""
"""MXNet backend rep for onnx test infrastructure"""
try:
from onnx.backend.base import BackendRep
except ImportError:
raise ImportError("Onnx and protobuf need to be installed")
raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+ " install - /~https://github.com/onnx/onnx#installation")
import mxnet as mx

# Using these functions for onnx test infrastructure.
# Implemented by following onnx docs guide:
# /~https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
# /~https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
# MXNetBackendRep object will be returned by MXNetBackend's prepare method which is used to
# execute a model repeatedly.
# Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and
Expand Down Expand Up @@ -54,9 +55,6 @@ def run(self, inputs, **kwargs):
params : numpy array
result obtained after running the inference on mxnet
"""
data_forward = []
for val in inputs:
data_forward.append(mx.nd.array(val))
# create module, passing cpu context
if self.device == 'CPU':
ctx = mx.cpu()
Expand All @@ -68,17 +66,19 @@ def run(self, inputs, **kwargs):
data_names = [graph_input for graph_input in self.symbol.list_inputs()
if graph_input not in self.arg_params and graph_input not in self.aux_params]

data_shapes = []
data_forward = []
for idx, input_name in enumerate(data_names):
data_shapes.append((input_name, inputs[idx].shape))
val = inputs[idx]
data_forward.append(mx.nd.array(val))

mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes,
label_shapes=None)
mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
if self.arg_params:
for idx, input_name in enumerate(self.arg_params):
val = self.arg_params[input_name]
data_names.append(input_name)
data_forward.append(mx.nd.array(val))

# run inference
mod.forward(mx.io.DataBatch(data_forward))
result = mod.get_outputs()[0].asnumpy()
args = dict(zip(data_names, data_forward))
exe = self.symbol.bind(ctx, args=args, aux_states=self.aux_params)
exe.forward(is_train=False)
result = exe.outputs[0].asnumpy()
return [result]
4 changes: 4 additions & 0 deletions tests/python-pytest/onnx/export/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

# coding: utf-8
"""backend wrapper for onnx test infrastructure"""
import os
import sys
import numpy as np
from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
from mxnet.contrib.onnx.mx2onnx.export_onnx import MXNetGraph
Expand All @@ -25,6 +27,8 @@
from onnx.backend.base import Backend
except ImportError:
raise ImportError("Onnx and protobuf need to be installed")
CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(CURR_PATH, '../'))
from backend_rep import MXNetBackendRep

# Using these functions for onnx test infrastructure.
Expand Down
4 changes: 3 additions & 1 deletion tests/python-pytest/onnx/export/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@
'test_clip'
'test_cast',
'test_depthtospace',
'test_hardsigmoid'
'test_hardsigmoid',
'test_instancenorm',
'test_shape'
]

BASIC_MODEL_TESTS = [
Expand Down
6 changes: 5 additions & 1 deletion tests/python-pytest/onnx/import/mxnet_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@

# coding: utf-8
"""MXNet backend wrapper for onnx test infrastructure"""
import os
import sys
from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
try:
from onnx import helper, TensorProto
from onnx.backend.base import Backend
except ImportError:
raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+ " install - /~https://github.com/onnx/onnx#installation")
from mxnet_backend_rep import MXNetBackendRep
CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(CURR_PATH, '../'))
from backend_rep import MXNetBackendRep

# MXNetBackend class will take an ONNX model with inputs, perform a computation,
# and then return the output.
Expand Down
98 changes: 0 additions & 98 deletions tests/python-pytest/onnx/import/mxnet_backend_rep.py

This file was deleted.