Skip to content

Commit

Permalink
fix some parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
yghstill committed Mar 24, 2022
1 parent ff4ba1a commit 0857203
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
7 changes: 3 additions & 4 deletions python/paddle/fluid/contrib/slim/quantization/adaround.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,15 @@ def run_adaround(data_loader,
exe,
scope,
place,
quantized_op_output_name_dict,
quantized_op_pairs,
weight_op_pairs,
scale_dict,
num_iterations=1000,
lr=0.001,
fast_mode=True):
fetch_op_name = fetch_list[0].name
final_weight_tensor_quant_dict = {}
for weight_var_name, quant_op_out_name in quantized_op_output_name_dict.items(
):
for weight_var_name, quant_op_out_name in quantized_op_pairs.items():
_logger.info('Start adaround op: {}'.format(weight_var_name))
weight_op_type = weight_op_pairs[weight_var_name]
# get scale and weight tensor
Expand Down Expand Up @@ -305,6 +304,6 @@ def run_adaround(data_loader,
del adaround

# update adarounded calibrated weights
for weight_var_name in quantized_op_output_name_dict.keys():
for weight_var_name in quantized_op_pairs.keys():
set_variable_data(scope, place, weight_var_name,
final_weight_tensor_quant_dict[weight_var_name])
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def __init__(self,
hist_percent=0.99999,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
round_type='round',
train_iterations=1000,
learning_rate=0.001,
is_full_quantize=False,
bias_correction=False,
Expand Down Expand Up @@ -182,11 +181,9 @@ def __init__(self,
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
round_type(str, optional): The method of converting the quantized weights
value from float to int. Currently supports ['round', 'adaround'] methods.
value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which rounds to the nearest whole number.
train_iterations(flota, optional): The number of training iter, used to
calibrate the adaptive rounding method, when round_type='adaround'.
learning_rate(flota, optional): The learning rate of adaround method.
learning_rate(float, optional): The learning rate of adaround method.
is_full_quantize(bool, optional): If set is_full_quantize as True,
apply quantization to all supported quantizable op type. If set
is_full_quantized as False, only apply quantization to the op type
Expand Down Expand Up @@ -265,7 +262,6 @@ def __init__(self,
]
assert round_type in ['adaround', 'round']
self._round_type = round_type
self._train_iterations = train_iterations
self._learning_rate = learning_rate
self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = \
Expand Down Expand Up @@ -446,10 +442,10 @@ def _adaround_apply(self):
self._executor,
self._scope,
self._place,
self._quantized_op_output_name_dict,
self._quantized_op_pairs,
self._weight_op_pairs,
scale_dict,
num_iterations=self._train_iterations,
num_iterations=self._batch_nums,
lr=self._learning_rate)

def save_quantized_model(self,
Expand Down Expand Up @@ -534,7 +530,7 @@ def _collect_target_varnames(self):
'''
# TODO(juncaipeng), consider the name_scope of skip_quant
_logger.info("Collect quantized variable names ...")
self._quantized_op_output_name_dict = {}
self._quantized_op_pairs = {}

def collect_var_name(var_name_list, persistable_var_names, op_type):
for var_name in var_name_list:
Expand Down Expand Up @@ -564,7 +560,7 @@ def collect_var_name(var_name_list, persistable_var_names, op_type):
for out_var_name in _get_op_output_var_names(op):
for in_var_name in _get_op_input_var_names(op):
if in_var_name in persistable_var_names:
self._quantized_op_output_name_dict[
self._quantized_op_pairs[
in_var_name] = out_var_name
# For other op, only sample output scale
elif op_type in self._out_scale_op_list:
Expand Down Expand Up @@ -984,7 +980,7 @@ def analysis_and_save_info(op_node, out_var_name):
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_hist")

elif self._algo in ["avg", "abs_max", "mse"]:
elif self._algo in ["avg", "abs_max", "mse", "emd"]:
save_info(op_node, out_var_name, self._quantized_threshold,
"out_threshold", "post_" + str(self._algo))
save_info(
Expand Down

0 comments on commit 0857203

Please sign in to comment.