From 8fe111b1cd0c17789fba19d8653584cf02f4e5e4 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Tue, 5 Jul 2022 17:49:52 +0800 Subject: [PATCH] change threshold for ptq hpo (#1254) --- paddleslim/auto_compression/auto_strategy.py | 4 ++-- paddleslim/quant/post_quant_hpo.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/paddleslim/auto_compression/auto_strategy.py b/paddleslim/auto_compression/auto_strategy.py index 2826601ae..451f60075 100644 --- a/paddleslim/auto_compression/auto_strategy.py +++ b/paddleslim/auto_compression/auto_strategy.py @@ -77,8 +77,8 @@ MAGIC_SPARSE_RATIO = 0.75 ### TODO: 0.02 threshold maybe not suitable, need to check ### NOTE: reduce magic data to choose quantization aware training. -MAGIC_MAX_EMD_DISTANCE = 0.0002 #0.02 -MAGIC_MIN_EMD_DISTANCE = 0.0001 #0.01 +MAGIC_MAX_EMD_DISTANCE = 0.00002 #0.02 +MAGIC_MIN_EMD_DISTANCE = 0.00001 #0.01 DEFAULT_TRANSFORMER_STRATEGY = 'prune_0.25_int8' DEFAULT_STRATEGY = 'origin_int8' diff --git a/paddleslim/quant/post_quant_hpo.py b/paddleslim/quant/post_quant_hpo.py index 9f9275f8a..adb5ffa5c 100755 --- a/paddleslim/quant/post_quant_hpo.py +++ b/paddleslim/quant/post_quant_hpo.py @@ -144,7 +144,12 @@ def standardization(data): """standardization numpy array""" mu = np.mean(data, axis=0) sigma = np.std(data, axis=0) - sigma = 1e-13 if sigma == 0. else sigma + if isinstance(sigma, list) or isinstance(sigma, np.ndarray): + for idx, sig in enumerate(sigma): + if sig == 0.: + sigma[idx] = 1e-13 + else: + sigma = 1e-13 if sigma == 0. else sigma return (data - mu) / sigma @@ -241,18 +246,15 @@ def eval_quant_model(): if have_invalid_num(out_float) or have_invalid_num(out_quant): continue - try: - out_float = standardization(out_float) - out_quant = standardization(out_quant) - except: - continue - out_float_list.append(out_float) - out_quant_list.append(out_quant) + out_float_list.append(list(out_float)) + out_quant_list.append(list(out_quant)) valid_data_num += 1 if valid_data_num >= max_eval_data_num: break + out_float_list = standardization(out_float_list) + out_quant_list = standardization(out_quant_list) emd_sum = cal_emd_lose(out_float_list, out_quant_list, out_len_sum / float(valid_data_num)) _logger.info("output diff: {}".format(emd_sum))