diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index facdcfedcbca..86767a667128 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1613,3 +1613,35 @@ def convert_broadcast_equal(node, **kwargs): and return the created node. """ return create_basic_op_node('Equal', node, kwargs) + + +@mx_op.register("broadcast_logical_and") +def convert_broadcast_logical_and(node, **kwargs): + """Map MXNet's broadcast logical and operator attributes to onnx's Add operator + and return the created node. + """ + return create_basic_op_node('And', node, kwargs) + + +@mx_op.register("broadcast_logical_or") +def convert_broadcast_logical_or(node, **kwargs): + """Map MXNet's broadcast logical or operator attributes to onnx's Or operator + and return the created node. + """ + return create_basic_op_node('Or', node, kwargs) + + +@mx_op.register("broadcast_logical_xor") +def convert_broadcast_logical_xor(node, **kwargs): + """Map MXNet's broadcast logical xor operator attributes to onnx's Xor operator + and return the created node. + """ + return create_basic_op_node('Xor', node, kwargs) + + +@mx_op.register("logical_not") +def convert_logical_not(node, **kwargs): + """Map MXNet's logical not operator attributes to onnx's Not operator + and return the created node. + """ + return create_basic_op_node('Not', node, kwargs) diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index 964d0e760cae..6b858f05e24f 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -268,6 +268,45 @@ def test_ops(op_name, inputs, input_tensors, numpy_op): test_ops("Equal", input_data, input_tensor, np.equal(input_data[0], input_data[1]).astype(np.float32)) + +def get_int_inputs(interval, shape): + """Helper to get integer input of given shape and range""" + assert len(interval) == len(shape) + inputs = [] + input_tensors = [] + for idx in range(len(interval)): + low, high = interval[idx] + inputs.append(np.random.randint(low, high, size=shape[idx]).astype("float32")) + input_tensors.append(helper.make_tensor_value_info("input"+str(idx+1), + TensorProto.FLOAT, shape=shape[idx])) + return inputs, input_tensors + + +@with_seed() +def test_logical_ops(): + """Test for logical and, or, not, xor operators""" + def test_ops(op_name, inputs, input_tensors, numpy_op): + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))] + nodes = [helper.make_node(op_name, ["input"+str(i+1) for i in range(len(inputs))], ["output"])] + graph = helper.make_graph(nodes, + op_name + "_test", + input_tensors, + outputs) + model = helper.make_model(graph) + bkd_rep = backend.prepare(model) + output = bkd_rep.run(inputs) + npt.assert_almost_equal(output[0], numpy_op) + input_data, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)]) + test_ops("And", input_data, input_tensor, + np.logical_and(input_data[0], input_data[1]).astype(np.float32)) + test_ops("Or", input_data, input_tensor, + np.logical_or(input_data[0], input_data[1]).astype(np.float32)) + test_ops("Xor", input_data, input_tensor, + np.logical_xor(input_data[0], input_data[1]).astype(np.float32)) + test_ops("Not", [input_data[0]], [input_tensor[0]], + np.logical_not(input_data[0]).astype(np.float32)) + + def _assert_sym_equal(lhs, rhs): assert lhs.list_inputs() == rhs.list_inputs() # input names must be identical assert len(lhs.list_outputs()) == len(rhs.list_outputs()) # number of outputs must be identical diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py index aed68ffa114c..f41fe92352db 100644 --- a/tests/python-pytest/onnx/import/test_cases.py +++ b/tests/python-pytest/onnx/import/test_cases.py @@ -55,7 +55,6 @@ 'test_argmax', 'test_argmin', 'test_min', - 'test_logical_', # enabling partial test cases for matmul 'test_matmul_3d', 'test_matmul_4d',