diff --git a/lite/tests/unittest_py/op/test_greater_op.py b/lite/tests/unittest_py/op/test_greater_op.py index ea2c4dff8a0..f3644877cc7 100644 --- a/lite/tests/unittest_py/op/test_greater_op.py +++ b/lite/tests/unittest_py/op/test_greater_op.py @@ -41,10 +41,14 @@ def __init__(self, *args, **kwargs): def is_program_valid(self, program_config: ProgramConfig, predictor_config: CxxConfig) -> bool: + in_dtype = program_config.inputs["data_x"].dtype + + if "int32" == in_dtype: + return False return True def sample_program_configs(self, draw): - in_shape = draw( + in_shape_x = draw( st.lists( st.integers( min_value=3, max_value=10), min_size=3, max_size=4)) @@ -53,14 +57,10 @@ def sample_program_configs(self, draw): process_type = draw( st.sampled_from(["type_int64", "type_float", "type_int32"])) - ############### ToDo #################### - assume(process_type != "type_int32") - ######################################### - if axis == -1: - in_shape_y = in_shape + in_shape_y = in_shape_x else: - in_shape_y = in_shape[axis:] + in_shape_y = in_shape_x[axis:] def generate_data(type, size_list): if type == "type_int32": @@ -73,7 +73,7 @@ def generate_data(type, size_list): return np.random.random(size=size_list).astype(np.float32) def generate_input_x(): - return generate_data(process_type, in_shape) + return generate_data(process_type, in_shape_x) def generate_input_y(): return generate_data(process_type, in_shape_y) @@ -84,16 +84,16 @@ def generate_input_y(): "Y": ["data_y"]}, outputs={"Out": ["output_data"], }, attrs={"axis": axis, - "force_cpu": True}) - build_ops.outputs_dtype = {"output_data": np.bool_} + "force_cpu": True}, + outputs_dtype={"output_data": np.bool_}) cast_out = OpConfig( type="cast", inputs={"X": ["output_data"], }, outputs={"Out": ["cast_data_out"], }, attrs={"in_dtype": int(0), - "out_dtype": int(2)}) - cast_out.outputs_dtype = {"cast_data_out": np.int32} + "out_dtype": int(2)}, + outputs_dtype={"cast_data_out": np.int32}) program_config = ProgramConfig( ops=[build_ops, cast_out],