Skip to content

Commit

Permalink
modeify some code test=document_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
daming5432 committed Dec 24, 2021
1 parent 9fc99b3 commit 630d7a1
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions lite/tests/unittest_py/op/test_greater_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ def __init__(self, *args, **kwargs):
def is_program_valid(self,
program_config: ProgramConfig,
predictor_config: CxxConfig) -> bool:
in_dtype = program_config.inputs["data_x"].dtype

if "int32" == in_dtype:
return False
return True

def sample_program_configs(self, draw):
in_shape = draw(
in_shape_x = draw(
st.lists(
st.integers(
min_value=3, max_value=10), min_size=3, max_size=4))
Expand All @@ -53,14 +57,10 @@ def sample_program_configs(self, draw):
process_type = draw(
st.sampled_from(["type_int64", "type_float", "type_int32"]))

############### ToDo ####################
assume(process_type != "type_int32")
#########################################

if axis == -1:
in_shape_y = in_shape
in_shape_y = in_shape_x
else:
in_shape_y = in_shape[axis:]
in_shape_y = in_shape_x[axis:]

def generate_data(type, size_list):
if type == "type_int32":
Expand All @@ -73,7 +73,7 @@ def generate_data(type, size_list):
return np.random.random(size=size_list).astype(np.float32)

def generate_input_x():
return generate_data(process_type, in_shape)
return generate_data(process_type, in_shape_x)

def generate_input_y():
return generate_data(process_type, in_shape_y)
Expand All @@ -84,16 +84,16 @@ def generate_input_y():
"Y": ["data_y"]},
outputs={"Out": ["output_data"], },
attrs={"axis": axis,
"force_cpu": True})
build_ops.outputs_dtype = {"output_data": np.bool_}
"force_cpu": True},
outputs_dtype={"output_data": np.bool_})

cast_out = OpConfig(
type="cast",
inputs={"X": ["output_data"], },
outputs={"Out": ["cast_data_out"], },
attrs={"in_dtype": int(0),
"out_dtype": int(2)})
cast_out.outputs_dtype = {"cast_data_out": np.int32}
"out_dtype": int(2)},
outputs_dtype={"cast_data_out": np.int32})

program_config = ProgramConfig(
ops=[build_ops, cast_out],
Expand Down

0 comments on commit 630d7a1

Please sign in to comment.