From d8de0f48871a380050b59c415227dc1313bc665b Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 27 Dec 2018 13:21:53 -0800 Subject: [PATCH] tests for maxroipool, randomnormal, randomuniform --- tests/python-pytest/onnx/test_node.py | 97 +++++++++++++++++++++------ 1 file changed, 76 insertions(+), 21 deletions(-) diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py index 07ae866b96cf..ee4d0ecd6d6e 100644 --- a/tests/python-pytest/onnx/test_node.py +++ b/tests/python-pytest/onnx/test_node.py @@ -27,13 +27,11 @@ import os import unittest import logging -import tarfile from collections import namedtuple import numpy as np import numpy.testing as npt from onnx import numpy_helper, helper, load_model from onnx import TensorProto -from mxnet.test_utils import download from mxnet.contrib import onnx as onnx_mxnet import mxnet as mx import backend @@ -56,6 +54,24 @@ def get_rnd(shape, low=-1.0, high=1.0, dtype=np.float32): return np.random.choice(a=[False, True], size=shape).astype(np.float32) +def _fix_attributes(attrs, attribute_mapping): + new_attrs = attrs + attr_modify = attribute_mapping.get('modify', {}) + for k, v in attr_modify.items(): + new_attrs[v] = new_attrs.pop(k, None) + + attr_add = attribute_mapping.get('add', {}) + for k, v in attr_add.items(): + new_attrs[k] = v + + attr_remove = attribute_mapping.get('remove', []) + for k in attr_remove: + if k in new_attrs: + del new_attrs[k] + + return new_attrs + + def forward_pass(sym, arg, aux, data_names, input_data): """ Perform forward pass on given data :param sym: Symbol @@ -118,12 +134,21 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att return model for test in test_cases: - test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific = test + test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test with self.subTest(test_name): names, input_tensors, inputsym = get_input_tensors(inputs) - test_op = mxnet_op(*inputsym, **attrs) - mxnet_output = forward_pass(test_op, None, None, names, inputs) - outputshape = np.shape(mxnet_output) + if inputs: + test_op = mxnet_op(*inputsym, **attrs) + mxnet_output = forward_pass(test_op, None, None, names, inputs) + outputshape = np.shape(mxnet_output) + else: + test_op = mxnet_op(**attrs) + shape = attrs.get('shape', (1,)) + x = mx.nd.zeros(shape, dtype='float32') + xgrad = mx.nd.zeros(shape, dtype='float32') + exe = test_op.bind(ctx=mx.cpu(), args={'x': x}, args_grad={'x': xgrad}) + mxnet_output = exe.forward(is_train=False)[0].asnumpy() + outputshape = np.shape(mxnet_output) if mxnet_specific: onnxmodelfile = onnx_mxnet.export_model(test_op, {}, [np.shape(ip) for ip in inputs], @@ -131,33 +156,63 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att onnx_name + ".onnx") onnxmodel = load_model(onnxmodelfile) else: - onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, attrs) + onnx_attrs = _fix_attributes(attrs, fix_attrs) + onnxmodel = get_onnx_graph(test_name, names, input_tensors, onnx_name, outputshape, onnx_attrs) bkd_rep = backend.prepare(onnxmodel, operation='export') output = bkd_rep.run(inputs) - npt.assert_almost_equal(output[0], mxnet_output) + if check_value: + npt.assert_almost_equal(output[0], mxnet_output) + + if check_shape: + npt.assert_equal(output[0].shape, outputshape) -# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False) +# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False, +# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name}, +# 'remove': [attr_name], +# 'add': {attr_name: value}, +# check_value=True/False, check_shape=True/False) test_cases = [ - ("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False), - ("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False), - ("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False), + ("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True, + False), + ("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True, + False), + ("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), get_rnd((1, 5))], {}, False, {}, True, + False), ("test_and", mx.sym.broadcast_logical_and, "And", - [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False), + [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, + False), ("test_xor", mx.sym.broadcast_logical_xor, "Xor", - [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False), + [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, + False), ("test_or", mx.sym.broadcast_logical_or, "Or", - [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False), - ("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False), - ("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True), + [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, + False), + ("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), dtype=np.bool_)], {}, False, {}, True, + False), + ("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], {}, True, {}, True, + False), ("test_spacetodepth", mx.sym.space_to_depth, "SpaceToDepth", [get_rnd((1, 1, 4, 6))], - {'block_size': 2}, False), + {'block_size': 2}, False, {}, True, + False), ("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), get_rnd(1000)], - {'ignore_label': 0, 'use_ignore': False}, True), - ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4,3)), get_rnd((4, 3)), get_rnd(4)], - {'num_hidden': 4, 'name': 'FC'}, True) + {'ignore_label': 0, 'use_ignore': False}, True, {}, True, + False), + ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)], + {'num_hidden': 4, 'name': 'FC'}, True, {}, True, + False), + ("test_roipool", mx.sym.ROIPooling, "MaxRoiPool", + [[[get_rnd(shape=(8, 6), low=1, high=100, dtype=np.int32)]], [[0, 0, 0, 4, 4]]], + {'pooled_size': (2, 2), 'spatial_scale': 0.7}, False, + {'modify': {'pooled_size': 'pooled_shape'}}, True, False), + + # since results would be random, checking for shape alone + ("test_random_normal", mx.sym.random_normal, "RandomNormal", [], + {'shape': (2, 2), 'loc': 0, 'scale': 1}, False, {'modify': {'loc': 'mean'}}, False, True), + ("test_random_uniform", mx.sym.random_uniform, "RandomUniform", [], + {'shape': (2, 2), 'low': 0.5, 'high': 1.0}, False, {}, False, True) ] if __name__ == '__main__':