diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index 2baa90392907..04f6e37dc922 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -259,6 +259,53 @@ def test_square(): npt.assert_almost_equal(result, numpy_op) + +@with_seed() +def test_fully_connected(): + def random_arrays(*shapes): + """Generate some random numpy arrays.""" + arrays = [np.random.randn(*s).astype("float32") + for s in shapes] + if len(arrays) == 1: + return arrays[0] + return arrays + + data_names = ['x', 'w', 'b'] + + dim_in, dim_out = (3, 4) + input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,)) + + ipsym = [] + data_shapes = [] + data_forward = [] + for idx in range(len(data_names)): + val = input_data[idx] + data_shapes.append((data_names[idx], np.shape(val))) + data_forward.append(mx.nd.array(val)) + ipsym.append(mx.sym.Variable(data_names[idx])) + + op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2], num_hidden=dim_out, name='FC') + + model = mx.mod.Module(op, data_names=data_names, label_names=None) + model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None) + + model.init_params() + + args, auxs = model.get_params() + params = {} + params.update(args) + params.update(auxs) + + converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape in data_shapes], np.float32, "fc.onnx") + + sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model) + result = forward_pass(sym, arg_params, aux_params, data_names, input_data) + + numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2] + + npt.assert_almost_equal(result, numpy_op) + + @with_seed() def test_comparison_ops(): """Test greater, lesser, equal"""