Skip to content

Commit

Permalink
Specifying input data types
Browse files Browse the repository at this point in the history
  • Loading branch information
eee4017 authored and Frank Lin (Engrg-Hardware 1) committed May 19, 2023
1 parent ab770f8 commit f15b894
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
1 change: 1 addition & 0 deletions test/ir/inference/auto_scan_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,7 @@ def random_to_skip():
threshold,
) in self.sample_predictor_configs(prog_config):
for input_type in self.get_avalible_input_type():
prog_config = prog_config.set_input_type(input_type)
if os.path.exists(self.cache_dir):
shutil.rmtree(self.cache_dir)

Expand Down
31 changes: 24 additions & 7 deletions test/ir/inference/test_multihead_matmul_fuse_pass_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def generate_mul_input():
def generate_elewise_input():
return np.random.random([1, 12, 128, 128]).astype(np.float32)

def generate_weight(shape):
return np.random.random(shape).astype(np.float32)

mul_0 = OpConfig(
"mul",
inputs={"X": ["mul_x"], "Y": ["mul_0_w"]},
Expand Down Expand Up @@ -195,13 +198,27 @@ def generate_elewise_input():
),
},
weights={
"mul_0_w": TensorConfig(shape=[768, 768]),
"mul_1_w": TensorConfig(shape=[768, 768]),
"mul_2_w": TensorConfig(shape=[768, 768]),
"mul_3_w": TensorConfig(shape=[768, 768]),
"ele_0_w": TensorConfig(shape=[768]),
"ele_1_w": TensorConfig(shape=[768]),
"ele_2_w": TensorConfig(shape=[768]),
"mul_0_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_1_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_2_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_3_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"ele_0_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_1_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_2_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
},
outputs=[ops[-1].outputs["Out"][0]],
)
Expand Down

0 comments on commit f15b894

Please sign in to comment.