Skip to content

Commit

Permalink
Fix entropy for uint8 (apache#14150)
Browse files Browse the repository at this point in the history
* Fix entropy for uint8

* Add test

* Update test_quantization.py
  • Loading branch information
ZhennanQin authored and haohuw committed Jun 23, 2019
1 parent db1fd5e commit 4960d5a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
12 changes: 8 additions & 4 deletions python/mxnet/contrib/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _smooth_distribution(p, eps=0.0001):


# pylint: disable=line-too-long
def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
def _get_optimal_threshold(arr, quantized_dtype, num_bins=8001, num_quantized_bins=255):
"""Given a dataset, find the optimal threshold for quantizing it.
The reference distribution is `q`, and the candidate distribution is `p`.
`q` is a truncated version of the original distribution.
Expand All @@ -285,6 +285,10 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
max_val = np.max(arr)
th = max(abs(min_val), abs(max_val))

if min_val >= 0 and quantized_dtype in ['auto', 'uint8']:
# We need to move negative bins to positive bins to fit uint8 range.
num_quantized_bins = num_quantized_bins * 2 + 1

hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
zero_bin_idx = num_bins // 2
num_half_quantized_bins = num_quantized_bins // 2
Expand Down Expand Up @@ -348,7 +352,7 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
# pylint: enable=line-too-long


def _get_optimal_thresholds(nd_dict, num_bins=8001, num_quantized_bins=255, logger=None):
def _get_optimal_thresholds(nd_dict, quantized_dtype, num_bins=8001, num_quantized_bins=255, logger=None):
"""Given a ndarray dict, find the optimal threshold for quantizing each value of the key."""
if stats is None:
raise ImportError('scipy.stats is required for running entropy mode of calculating'
Expand All @@ -364,7 +368,7 @@ def _get_optimal_thresholds(nd_dict, num_bins=8001, num_quantized_bins=255, logg
for name in layer_names:
assert name in nd_dict
min_val, max_val, min_divergence, opt_th = \
_get_optimal_threshold(nd_dict[name], num_bins=num_bins,
_get_optimal_threshold(nd_dict[name], quantized_dtype, num_bins=num_bins,
num_quantized_bins=num_quantized_bins)
del nd_dict[name] # release the memory of ndarray
if min_val < 0:
Expand Down Expand Up @@ -521,7 +525,7 @@ def quantize_model(sym, arg_params, aux_params,
logger=logger)
logger.info('Collected layer outputs from FP32 model using %d examples' % num_examples)
logger.info('Calculating optimal thresholds for quantization')
th_dict = _get_optimal_thresholds(nd_dict, logger=logger)
th_dict = _get_optimal_thresholds(nd_dict, quantized_dtype, logger=logger)
elif calib_mode == 'naive':
th_dict, num_examples = _collect_layer_output_min_max(
mod, calib_data, include_layer=calib_layer, max_num_examples=num_calib_examples,
Expand Down
20 changes: 11 additions & 9 deletions tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,10 +713,11 @@ def test_optimal_threshold_adversarial_case():
# The worst case for the optimal_threshold function is when the values are concentrated
# at one edge: [0, 0, ..., 1000]. (histogram)
# We want to make sure that the optimal threshold in this case is the max.
arr = np.array([2]*1000)
res = mx.contrib.quant._get_optimal_threshold(arr, num_quantized_bins=5)
# The threshold should be 2.
assert res[3] - 2 < 1e-5
arr = np.array([2] * 1000)
for dtype in ['uint8', 'int8', 'auto']:
res = mx.contrib.quant._get_optimal_threshold(arr, dtype, num_quantized_bins=5)
# The threshold should be 2.
assert res[3] - 2 < 1e-5


@with_seed()
Expand All @@ -728,11 +729,12 @@ def get_threshold(nd):
max_nd = mx.nd.max(nd)
return mx.nd.maximum(mx.nd.abs(min_nd), mx.nd.abs(max_nd)).asnumpy()

nd_dict = {'layer1': mx.nd.uniform(low=-10.532, high=11.3432, shape=(8, 3, 23, 23), dtype=np.float64)}
expected_threshold = get_threshold(nd_dict['layer1'])
th_dict = mx.contrib.quant._get_optimal_thresholds(nd_dict)
assert 'layer1' in th_dict
assert_almost_equal(np.array([th_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4)
for dtype in ['uint8', 'int8', 'auto']:
nd_dict = {'layer1': mx.nd.uniform(low=-10.532, high=11.3432, shape=(8, 3, 23, 23), dtype=np.float64)}
expected_threshold = get_threshold(nd_dict['layer1'])
th_dict = mx.contrib.quant._get_optimal_thresholds(nd_dict, dtype)
assert 'layer1' in th_dict
assert_almost_equal(np.array([th_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4)


if __name__ == "__main__":
Expand Down

0 comments on commit 4960d5a

Please sign in to comment.