Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add tests for multinomial, lppool, globallppool
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 28, 2018
1 parent 977065c commit 30bbd1a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 17 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def sample_multinomial(attrs, inputs, proto_obj):
+ "Instructions to install - /~https://github.com/onnx/onnx")
new_attrs = translation_utils._remove_attributes(attrs, ['seed'])
new_attrs = translation_utils._fix_attribute_names(new_attrs, {'sample_size': 'shape'})
new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(new_attrs['dtype'])]
new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))]
return 'sample_multinomial', new_attrs, inputs


Expand Down
81 changes: 65 additions & 16 deletions tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@ def get_rnd(shape, low=-1.0, high=1.0, dtype=np.float32):
return np.random.choice(a=[False, True], size=shape).astype(np.float32)


def _fix_attributes(attrs, attribute_mapping):
new_attrs = attrs
attr_modify = attribute_mapping.get('modify', {})
for k, v in attr_modify.items():
new_attrs[v] = new_attrs.pop(k, None)

attr_add = attribute_mapping.get('add', {})
for k, v in attr_add.items():
new_attrs[k] = v

attr_remove = attribute_mapping.get('remove', [])
for k in attr_remove:
if k in new_attrs:
del new_attrs[k]

return new_attrs


def forward_pass(sym, arg, aux, data_names, input_data):
""" Perform forward pass on given data
:param sym: Symbol
Expand Down Expand Up @@ -118,7 +136,7 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
return model

for test in test_cases:
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific = test
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test
with self.subTest(test_name):
names, input_tensors, inputsym = get_input_tensors(inputs)
test_op = mxnet_op(*inputsym, **attrs)
Expand All @@ -131,33 +149,64 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
onnx_name + ".onnx")
onnxmodel = load_model(onnxmodelfile)
else:
onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, attrs)
onnx_attrs = _fix_attributes(attrs, fix_attrs)
onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, onnx_attrs)

bkd_rep = backend.prepare(onnxmodel, operation='export')
output = bkd_rep.run(inputs)

npt.assert_almost_equal(output[0], mxnet_output)
if check_value:
npt.assert_almost_equal(output[0], mxnet_output)

if check_shape:
npt.assert_equal(output[0].shape, outputshape)


# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False)
# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False,
# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
# 'remove': [attr_name],
# 'add': {attr_name: value},
# check_value=True/False, check_shape=True/False)
test_cases = [
("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False),
("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True,
False),
("test_and", mx.sym.broadcast_logical_and, "And",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_xor", mx.sym.broadcast_logical_xor, "Xor",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_or", mx.sym.broadcast_logical_or, "Or",
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False),
("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True),
[get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, False),
("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True, {}, True, False),
("test_spacetodepth", mx.sym.space_to_depth, "SpaceToDepth", [get_rnd((1, 1, 4, 6))],
{'block_size': 2}, False),
{'block_size': 2}, False, {}, True, False),
("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), get_rnd(1000)],
{'ignore_label': 0, 'use_ignore': False}, True),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4,3)), get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True)
{'ignore_label': 0, 'use_ignore': False}, True, {}, True, False),
("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)],
{'num_hidden': 4, 'name': 'FC'}, True, {}, True, False),
("test_lppool1", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp'}, False,
{'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'},
'remove': ['pool_type']}, True, False),
("test_lppool2", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp'}, False,
{'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 'p_value': 'p'},
'remove': ['pool_type']}, True, False),
("test_globallppool1", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 'pool_type': 'lp', 'global_pool': True}, False,
{'modify': {'p_value': 'p'},
'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False),
("test_globallppool2", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 20))],
{'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 'pool_type': 'lp', 'global_pool': True}, False,
{'modify': {'p_value': 'p'},
'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, True, False),
("test_multinomial", mx.sym.sample_multinomial, "Multinomial",
[np.array([0, 0.1, 0.2, 0.3, 0.4]).astype("float32")],
{'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True)
]

if __name__ == '__main__':
Expand Down

0 comments on commit 30bbd1a

Please sign in to comment.