Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yghstill committed Mar 25, 2022
1 parent 0857203 commit d256668
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def _sample_mse(self):
def _sample_emd(self):
if self._quantized_threshold == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
Expand All @@ -664,7 +664,7 @@ def _sample_emd(self):
self._quantized_threshold[var_name] = abs_max_value
_logger.info("EMD searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def test_post_training_mse(self):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "emd"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
Expand All @@ -265,10 +266,10 @@ def test_post_training_mse(self):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, batch_size, infer_iterations,
quant_iterations)
self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations)


class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ class TestPostTrainingEMDForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_avg_mobilenetv1(self):
model = "MobileNet-V1"
algo = "emd"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
Expand All @@ -511,9 +512,9 @@ def test_post_training_avg_mobilenetv1(self):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)


if __name__ == '__main__':
Expand Down

1 comment on commit d256668

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request has passed all required CI checks. You may now ask the reviewer(s) to approve and merge. 🎉

Please sign in to comment.