From e9011d3e9cb09fe355867e0b81fbdf9a413de78e Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Mon, 21 Aug 2023 15:18:51 +0800 Subject: [PATCH] Support Load scales from quant model (#1790) --- paddleslim/quant/advanced/piecewise_search.py | 4 ++-- paddleslim/quant/observers/abs_max.py | 3 +++ paddleslim/quant/observers/abs_max_weight.py | 3 ++- paddleslim/quant/observers/avg.py | 3 +++ paddleslim/quant/observers/emd.py | 3 +++ paddleslim/quant/observers/mse.py | 3 +++ 6 files changed, 16 insertions(+), 3 deletions(-) diff --git a/paddleslim/quant/advanced/piecewise_search.py b/paddleslim/quant/advanced/piecewise_search.py index 6b17286b742194..55678409b43bb8 100644 --- a/paddleslim/quant/advanced/piecewise_search.py +++ b/paddleslim/quant/advanced/piecewise_search.py @@ -153,8 +153,8 @@ def search(self, layer_name, sampled_input, act_abs_max, weight): else: smooth_scale_out += final_smooth_scale - if cur_loss < global_loss: - global_loss = cur_loss + if calibration_loss < global_loss: + global_loss = calibration_loss best_scale = smooth_scale_out if self.search_piece: print('Find Better K-Piece {}'.format(k_piece)) diff --git a/paddleslim/quant/observers/abs_max.py b/paddleslim/quant/observers/abs_max.py index 10cdf2983563a7..a8266df1e57f57 100644 --- a/paddleslim/quant/observers/abs_max.py +++ b/paddleslim/quant/observers/abs_max.py @@ -69,6 +69,9 @@ def cal_min_max(self, inputs): def cal_thresholds(self): """ Compute thresholds for MAX function. """ + if self._scale is not None: + self._zero_point = 0 + return self._scale, self._zero_point = self.cal_scales_zero_points() def min_value(self) -> float: diff --git a/paddleslim/quant/observers/abs_max_weight.py b/paddleslim/quant/observers/abs_max_weight.py index 1381cd23d8cdd3..a22fc496446db9 100644 --- a/paddleslim/quant/observers/abs_max_weight.py +++ b/paddleslim/quant/observers/abs_max_weight.py @@ -81,7 +81,8 @@ def max_value(self) -> float: def cal_thresholds(self): """ Compute thresholds for MAX function. """ - self._scale = self._max + if self._scale is None: + self._scale = self._max self._zero_point = paddle.zeros_like(self._scale) def scales(self): diff --git a/paddleslim/quant/observers/avg.py b/paddleslim/quant/observers/avg.py index 14b6ba81de5188..199a2aa0e6205c 100644 --- a/paddleslim/quant/observers/avg.py +++ b/paddleslim/quant/observers/avg.py @@ -70,6 +70,9 @@ def cal_min_max(self, inputs): def cal_thresholds(self): """ Compute thresholds for MAX function. """ + if self._scale is not None: + self._zero_point = 0 + return self._min, self._max = self._avg_min, paddle.mean( paddle.to_tensor(self._avg_list)) self._scale, self._zero_point = self.cal_scales_zero_points() diff --git a/paddleslim/quant/observers/emd.py b/paddleslim/quant/observers/emd.py index 02bea81a588158..8dea968e70cc86 100644 --- a/paddleslim/quant/observers/emd.py +++ b/paddleslim/quant/observers/emd.py @@ -85,6 +85,9 @@ def cal_min_max(self, inputs): def cal_thresholds(self): """ Compute thresholds for MAX function. """ + if self._scale is not None: + self._zero_point = 0 + return self._min, self._max = self._emd_min, self._emd_max self._scale, self._zero_point = self.cal_scales_zero_points() diff --git a/paddleslim/quant/observers/mse.py b/paddleslim/quant/observers/mse.py index 6deab94ad833cf..74433576e2c979 100644 --- a/paddleslim/quant/observers/mse.py +++ b/paddleslim/quant/observers/mse.py @@ -82,6 +82,9 @@ def cal_min_max(self, inputs): def cal_thresholds(self): """ Compute thresholds for MAX function. """ + if self._scale is not None: + self._zero_point = 0 + return self._min, self._max = self._mse_min, self._mse_max self._scale, self._zero_point = self.cal_scales_zero_points()