Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX import/export: Use sym bind method
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Oct 26, 2018
1 parent ee5f699 commit 9685aa6
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 56 deletions.
30 changes: 12 additions & 18 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,30 +118,24 @@ def forward_pass(inputs, sym, arg_params, aux_params, output_label):
if graph_input not in arg_params and graph_input not in aux_params
and graph_input != output_label]

data_shapes = []
# Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
for idx, input_name in enumerate(data_names):
data_shapes.append((input_name, inputs[idx].shape))

# create module, passing cpu context
ctx = context.cpu()
test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)

# initializing parameters for calculating result of each individual node
if arg_params is None and aux_params is None:
test_mod.init_params()
else:
test_mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)

data_forward = []
for idx, input_name in enumerate(data_names):
val = inputs[idx]
data_forward.append(nd.array(val))

test_mod.forward(io.DataBatch(data_forward))
result = test_mod.get_outputs()[0].asnumpy()
if arg_params:
for idx, input_name in enumerate(arg_params):
val = arg_params[input_name]
data_names.append(input_name)
data_forward.append(nd.array(val))

# create module, passing cpu context
ctx = context.cpu()

args = dict(zip(data_names, data_forward))
exe = sym.bind(ctx, args=args, aux_states=aux_params)
exe.forward(is_train=False)
result = exe.outputs[0].asnumpy()
return result.shape


Expand Down
26 changes: 13 additions & 13 deletions tests/python-pytest/onnx/export/backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def run(self, inputs, **kwargs):
params : numpy array
result obtained after running the inference on mxnet
"""
data_forward = []
for val in inputs:
data_forward.append(mx.nd.array(val))

# create module, passing cpu context
if self.device == 'CPU':
ctx = mx.cpu()
Expand All @@ -68,17 +66,19 @@ def run(self, inputs, **kwargs):
data_names = [graph_input for graph_input in self.symbol.list_inputs()
if graph_input not in self.arg_params and graph_input not in self.aux_params]

data_shapes = []
data_forward = []
for idx, input_name in enumerate(data_names):
data_shapes.append((input_name, inputs[idx].shape))
val = inputs[idx]
data_forward.append(mx.nd.array(val))

mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes,
label_shapes=None)
mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
if self.arg_params:
for idx, input_name in enumerate(self.arg_params):
val = self.arg_params[input_name]
data_names.append(input_name)
data_forward.append(mx.nd.array(val))

# run inference
mod.forward(mx.io.DataBatch(data_forward))
result = mod.get_outputs()[0].asnumpy()
args = dict(zip(data_names, data_forward))
exe = self.symbol.bind(ctx, args=args, aux_states=self.aux_params)
exe.forward(is_train=False)
result = exe.outputs[0].asnumpy()
return [result]
35 changes: 10 additions & 25 deletions tests/python-pytest/onnx/import/mxnet_backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,15 @@ def run(self, inputs, **kwargs):
data_names = [graph_input for graph_input in self.symbol.list_inputs()
if graph_input not in self.arg_params and graph_input not in self.aux_params]

data_shapes = []
for idx, input_name in enumerate(data_names):
data_shapes.append((input_name, inputs[idx].shape))
if self.arg_params:
for idx, input_name in enumerate(self.arg_params):
val = self.arg_params[input_name]
data_names.append(input_name)
data_forward.append(mx.nd.array(val))

# module bind method requires all data to have same batch size,
# using module if all data have same batch size
if len(set([data_shape[1][0] for data_shape in data_shapes])) == 1:
mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes,
label_shapes=None)
mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)

# run inference
mod.forward(mx.io.DataBatch(data_forward))
result = mod.get_outputs()[0].asnumpy()
# split operator inference returns 1 less dimension
if self.symbol.name.startswith('split'):
return [i.asnumpy() for i in mod.get_outputs()]
return [result]
# using symbol bind method if data have different batch size
else:
exec1 = self.symbol.bind(ctx, args=dict(zip(data_names, data_forward)))
exec1.forward(is_train=False)
result = exec1.outputs[0].asnumpy()
return [result]
args = dict(zip(data_names, data_forward))
exe = self.symbol.bind(ctx, args=args, aux_states=self.aux_params)
exe.forward(is_train=False)
result = exe.outputs[0].asnumpy()
return [result]

0 comments on commit 9685aa6

Please sign in to comment.