Skip to content

Commit

Permalink
[MXNET-898] ONNX import/export: Sample_multinomial, ONNX export: Glob…
Browse files Browse the repository at this point in the history
…alLpPool, LpPool (apache#13500)

* ONNX import/export: Sample_multinomial

* ONNX export: GlobalLpPool, LpPool

* Handle default p_value

* Add tests for multinomial, lppool, globallppool

* add a comment about shape test
  • Loading branch information
vandanavk authored and haohuw committed Jun 23, 2019
1 parent f43828c commit 1cc25a8
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 40 deletions.
83 changes: 66 additions & 17 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def convert_pooling(node, **kwargs):
pool_type = attrs["pool_type"]
stride = eval(attrs["stride"]) if attrs.get("stride") else None
global_pool = get_boolean_attribute_value(attrs, "global_pool")
p_value = attrs.get('p_value', 'None')

pooling_convention = attrs.get('pooling_convention', 'valid')

Expand All @@ -598,26 +599,51 @@ def convert_pooling(node, **kwargs):

pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
pad_dims = pad_dims + pad_dims
pool_types = {"max": "MaxPool", "avg": "AveragePool"}
global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"}
pool_types = {"max": "MaxPool", "avg": "AveragePool", "lp": "LpPool"}
global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool",
"lp": "GlobalLpPool"}

if pool_type == 'lp' and p_value == 'None':
raise AttributeError('ONNX requires a p value for LpPool and GlobalLpPool')

if global_pool:
node = onnx.helper.make_node(
global_pool_types[pool_type],
input_nodes, # input
[name],
name=name
)
if pool_type == 'lp':
node = onnx.helper.make_node(
global_pool_types[pool_type],
input_nodes, # input
[name],
p=int(p_value),
name=name
)
else:
node = onnx.helper.make_node(
global_pool_types[pool_type],
input_nodes, # input
[name],
name=name
)
else:
node = onnx.helper.make_node(
pool_types[pool_type],
input_nodes, # input
[name],
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
name=name
)
if pool_type == 'lp':
node = onnx.helper.make_node(
pool_types[pool_type],
input_nodes, # input
[name],
p=int(p_value),
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
name=name
)
else:
node = onnx.helper.make_node(
pool_types[pool_type],
input_nodes, # input
[name],
kernel_shape=kernel,
pads=pad_dims,
strides=stride,
name=name
)

return [node]

Expand Down Expand Up @@ -1689,3 +1715,26 @@ def convert_logsoftmax(node, **kwargs):
name=name
)
return [node]


@mx_op.register("_sample_multinomial")
def convert_multinomial(node, **kwargs):
"""Map MXNet's multinomial operator attributes to onnx's
Multinomial operator and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)
dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(attrs.get("dtype", 'int32'))]
sample_size = convert_string_to_list(attrs.get("shape", '1'))
if len(sample_size) < 2:
sample_size = sample_size[-1]
else:
raise AttributeError("ONNX currently supports integer sample_size only")
node = onnx.helper.make_node(
"Multinomial",
input_nodes,
[name],
dtype=dtype,
sample_size=sample_size,
name=name,
)
return [node]
3 changes: 2 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# coding: utf-8_
# pylint: disable=invalid-name
"""Operator attributes conversion"""
from ._op_translations import identity, random_uniform, random_normal
from ._op_translations import identity, random_uniform, random_normal, sample_multinomial
from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, shape, gather, lp_pooling, size
Expand Down Expand Up @@ -48,6 +48,7 @@
'RandomNormal' : random_normal,
'RandomUniformLike' : random_uniform,
'RandomNormalLike' : random_normal,
'Multinomial' : sample_multinomial,
# Arithmetic Operators
'Add' : add,
'Sub' : subtract,
Expand Down
21 changes: 18 additions & 3 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ def random_normal(attrs, inputs, proto_obj):
new_attr = translation_utils._fix_attribute_names(new_attr, {'mean' : 'loc'})
return 'random_uniform', new_attr, inputs

def sample_multinomial(attrs, inputs, proto_obj):
"""Draw random samples from a multinomial distribution."""
try:
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
except ImportError:
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - /~https://github.com/onnx/onnx")
new_attrs = translation_utils._remove_attributes(attrs, ['seed'])
new_attrs = translation_utils._fix_attribute_names(new_attrs, {'sample_size': 'shape'})
new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))]
return 'sample_multinomial', new_attrs, inputs


# Arithmetic Operations
def add(attrs, inputs, proto_obj):
"""Adding two tensors"""
Expand Down Expand Up @@ -382,6 +395,7 @@ def global_lppooling(attrs, inputs, proto_obj):
'kernel': (1, 1),
'pool_type': 'lp',
'p_value': p_value})
new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
return 'Pooling', new_attrs, inputs

def linalg_gemm(attrs, inputs, proto_obj):
Expand Down Expand Up @@ -671,11 +685,12 @@ def lp_pooling(attrs, inputs, proto_obj):
new_attrs = translation_utils._fix_attribute_names(attrs,
{'kernel_shape': 'kernel',
'strides': 'stride',
'pads': 'pad',
'p_value': p_value
'pads': 'pad'
})
new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
new_attrs = translation_utils._add_extra_attributes(new_attrs,
{'pooling_convention': 'valid'
{'pooling_convention': 'valid',
'p_value': p_value
})
new_op = translation_utils._fix_pooling('lp', inputs, new_attrs)
return new_op, new_attrs, inputs
Expand Down
6 changes: 5 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _fix_pooling(pool_type, inputs, new_attr):
stride = new_attr.get('stride')
kernel = new_attr.get('kernel')
padding = new_attr.get('pad')
p_value = new_attr.get('p_value')

# Adding default stride.
if stride is None:
Expand Down Expand Up @@ -138,7 +139,10 @@ def _fix_pooling(pool_type, inputs, new_attr):
new_pad_op = symbol.pad(curr_sym, mode='constant', pad_width=pad_width)

# Apply pooling without pads.
new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel)
if pool_type == 'lp':
new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel, p_value=p_value)
else:
new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, stride=stride, kernel=kernel)
return new_pooling_op

def _fix_bias(op_name, attrs, num_inputs):
Expand Down
2 changes: 0 additions & 2 deletions tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
'test_softplus'
],
'import': ['test_gather',
'test_global_lppooling',
'test_softsign',
'test_reduce_',
'test_mean',
Expand All @@ -89,7 +88,6 @@
'test_averagepool_2d_precomputed_strides',
'test_averagepool_2d_strides',
'test_averagepool_3d',
'test_LpPool_',
'test_split_equal',
'test_hardmax'
],
Expand Down
83 changes: 67 additions & 16 deletions tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@ def get_rnd(shape, low=-1.0, high=1.0, dtype=np.float32):
return np.random.choice(a=[False, True], size=shape).astype(np.float32)


def _fix_attributes(attrs, attribute_mapping):
new_attrs = attrs
attr_modify = attribute_mapping.get('modify', {})
for k, v in attr_modify.items():
new_attrs[v] = new_attrs.pop(k, None)

attr_add = attribute_mapping.get('add', {})
for k, v in attr_add.items():
new_attrs[k] = v

attr_remove = attribute_mapping.get('remove', [])
for k in attr_remove:
if k in new_attrs:
del new_attrs[k]

return new_attrs


def forward_pass(sym, arg, aux, data_names, input_data):
""" Perform forward pass on given data
:param sym: Symbol
Expand Down Expand Up @@ -118,7 +136,7 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
return model

for test in test_cases:
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific = test
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test
with self.subTest(test_name):
names, input_tensors, inputsym = get_input_tensors(inputs)
test_op = mxnet_op(*inputsym, **attrs)
Expand All @@ -131,33 +149,66 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
onnx_name + ".onnx")
onnxmodel = load_model(onnxmodelfile)
else:
onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, attrs)
onnx_attrs = _fix_attributes(attrs, fix_attrs)
onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, onnx_attrs)

bkd_rep = backend.prepare(onnxmodel, operation='export')
output = bkd_rep.run(inputs)

npt.assert_almost_equal(output[0], mxnet_output)
if check_value:
npt.assert_almost_equal(output[0], mxnet_output)

if check_shape:
npt.assert_equal(output[0].shape, outputshape)


# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False)
# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False,
# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
# 'remove': [attr_name],
# 'add': {attr_name: value},
# check_value=True/False, check_shape=True/False)
test_cases = [
("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_and", mx.sym.broadcast_logical_and, "And",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_xor", mx.sym.broadcast_logical_xor, "Xor",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_or", mx.sym.broadcast_logical_or, "Or",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True, {}, True, False),
("test_spacetodepth", mx.sym.space_to_depth, "SpaceToDepth", [get_rnd((1, 1, 4, 6))],
{'block_size': 2}, False),
{'block_size': 2}, False, {}, True, False),
("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), get_rnd(1000)],
{'ignore_label': 0, 'use_ignore': False}, True),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4,3)), get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True)
{'ignore_label': 0, 'use_ignore': False}, True, {}, True, False),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True, {}, True, False),
("test_lppool1", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp'}, False,
{'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'},
'remove': ['pool_type']}, True, False),
("test_lppool2", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp'}, False,
{'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'},
'remove': ['pool_type']}, True, False),
("test_globallppool1", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp', 'global_pool': True}, False,
{'modify': {'p_value': 'p'},
'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False),
("test_globallppool2", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp', 'global_pool': True}, False,
{'modify': {'p_value': 'p'},
'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False),

# since results would be random, checking for shape alone
("test_multinomial", mx.sym.sample_multinomial, "Multinomial",
[np.array([0, 0.1, 0.2, 0.3, 0.4]).astype("float32")],
{'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True)
]

if __name__ == '__main__':
Expand Down

0 comments on commit 1cc25a8

Please sign in to comment.