Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Use appropriate bind method based on batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Oct 23, 2018
1 parent fce5154 commit 88455db
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
38 changes: 23 additions & 15 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,31 +118,39 @@ def forward_pass(inputs, sym, arg_params, aux_params, output_label):
if graph_input not in arg_params and graph_input not in aux_params
and graph_input != output_label]

data_forward = []
data_shapes = []
# Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
for idx, input_name in enumerate(data_names):
data_shapes.append((input_name, inputs[idx].shape))
val = inputs[idx]
data_shapes.append((input_name, val.shape))
data_forward.append(nd.array(val))

# create module, passing cpu context
ctx = context.cpu()
test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)

# initializing parameters for calculating result of each individual node
if arg_params is None and aux_params is None:
test_mod.init_params()
else:
test_mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)
# Module's bind method requires all inputs to have the same batch size,
# so use Module when every input shares one batch size
if len(set([data_shape[1][0] for data_shape in data_shapes])) == 1:
test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)

data_forward = []
for idx, input_name in enumerate(data_names):
val = inputs[idx]
data_forward.append(nd.array(val))
# initializing parameters for calculating result of each individual node
if arg_params is None and aux_params is None:
test_mod.init_params()
else:
test_mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)

test_mod.forward(io.DataBatch(data_forward))
result = test_mod.get_outputs()[0].asnumpy()
test_mod.forward(io.DataBatch(data_forward))
result = test_mod.get_outputs()[0].asnumpy()

return result.shape
return result.shape
# fall back to Symbol's bind method when inputs have different batch sizes
else:
exec1 = sym.bind(ctx, args=dict(zip(data_names, data_forward)))
exec1.forward(is_train=False)
result = exec1.outputs[0].asnumpy()
return result.shape


@staticmethod
Expand Down
27 changes: 18 additions & 9 deletions tests/python-pytest/onnx/export/backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,22 @@ def run(self, inputs, **kwargs):
for idx, input_name in enumerate(data_names):
data_shapes.append((input_name, inputs[idx].shape))

mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes,
label_shapes=None)
mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
# Module's bind method requires all inputs to have the same batch size,
# so use Module when every input shares one batch size
if len(set([data_shape[1][0] for data_shape in data_shapes])) == 1:
mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes,
label_shapes=None)
mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)

# run inference
mod.forward(mx.io.DataBatch(data_forward))
result = mod.get_outputs()[0].asnumpy()
return [result]
# run inference
mod.forward(mx.io.DataBatch(data_forward))
result = mod.get_outputs()[0].asnumpy()
return [result]
# fall back to Symbol's bind method when inputs have different batch sizes
else:
exec1 = self.symbol.bind(ctx, args=dict(zip(data_names, data_forward)))
exec1.forward(is_train=False)
result = exec1.outputs[0].asnumpy()
return [result]

0 comments on commit 88455db

Please sign in to comment.