diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 9508f1e649fe..be497e5e8359 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -236,6 +236,7 @@ List of Contributors * [Zhennan Qin](/~https://github.com/ZhennanQin) * [Zhiyuan Huang](/~https://github.com/huangzhiyuan) * [Zak Jost](/~https://github.com/zjost) +* [Shoubhik Bhattacharya](/~https://github.com/shoubhik) * [Zach Kimberg](/~https://github.com/zachgk) Label Bot diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc index dd433e41f69c..e8e2cd90b86c 100644 --- a/src/operator/quantization/dequantize.cc +++ b/src/operator/quantization/dequantize.cc @@ -84,6 +84,10 @@ by keep zero centered for the quantized value: .set_attr_parser(ParamParser) .set_num_inputs(3) .set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "min_range", "max_range"}; + }) .set_attr("FInferShape", DequantizeShape) .set_attr("FInferType", DequantizeType) .set_attr("FInferStorageType", DequantizeStorageType) diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc index 4807226e464c..43682383b0d6 100644 --- a/src/operator/quantization/requantize.cc +++ b/src/operator/quantization/requantize.cc @@ -61,6 +61,10 @@ inference accuracy. 
.set_attr_parser(ParamParser) .set_num_inputs(3) .set_num_outputs(3) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "min_range", "max_range"}; + }) .set_attr("FInferShape", QuantizeShape) .set_attr("FInferType", RequantizeType) .set_attr("FInferStorageType", RequantizeStorageType) diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index 2761e77fb0c1..3c8cc4234e54 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -63,19 +63,45 @@ def test_quantize_float32_to_int8(): @with_seed() def test_dequantize_int8_to_float32(): + + def get_test_data(real_range, qdata_np): + qdata = mx.nd.array(qdata_np, dtype=np.int8) + min_range = mx.nd.array([-real_range], dtype=np.float32) + max_range = mx.nd.array([real_range], dtype=np.float32) + return qdata, min_range, max_range + + def baseline_dequantization(qdata, real_range, qdata_np): + quantized_range = 127.0 + scale = real_range / quantized_range + data_np = qdata_np * scale + return data_np + + def test_nd_array_dequantization(qdata, min_range, max_range, expected_result): + data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32') + assert data.dtype == np.float32 + assert_almost_equal(data.asnumpy(), expected_result) + + def test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result): + sym_data = mx.sym.Variable('data') + sym_min_range = mx.sym.Variable('min_range') + sym_max_range = mx.sym.Variable('max_range') + dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range, + sym_max_range, out_type='float32') + out = dequant.bind(ctx=mx.current_context(), + args={'data':qdata, 'min_range':min_range, 'max_range':max_range}) + data = out.forward()[0] + assert data.dtype == np.float32 + assert_almost_equal(data.asnumpy(), expected_result) + + real_range = 402.3347 shape = rand_shape_nd(4) qdata_np = 
np.random.uniform(low=-127, high=127, size=shape).astype(dtype=np.int8) - qdata = mx.nd.array(qdata_np, dtype=np.int8) - real_range = 402.3347 - min_range = mx.nd.array([-real_range], dtype=np.float32) - max_range = mx.nd.array([real_range], dtype=np.float32) - data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32') - quantized_range = 127.0 - scale = real_range / quantized_range - assert data.dtype == np.float32 - data_np = qdata_np * scale - assert_almost_equal(data.asnumpy(), data_np) - + qdata, min_range, max_range = get_test_data(real_range, qdata_np) + expected_result = baseline_dequantization(qdata, real_range, qdata_np) + # test nd array implementation. + test_nd_array_dequantization(qdata, min_range, max_range, expected_result) + # test symbolic api implementation. + test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result) @with_seed() def test_requantize_int32_to_int8(): @@ -124,7 +150,41 @@ def check_requantize(shape, min_calib_range=None, max_calib_range=None): assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol = 1) assert_almost_equal(min_output.asnumpy(), np.array([min_output_np])) assert_almost_equal(max_output.asnumpy(), np.array([max_output_np])) + + def check_requantize_with_symbol(shape, min_calib_range=None, max_calib_range=None): + qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, shape=shape).astype('int32') + min_range = mx.nd.array([-1010.0]) + max_range = mx.nd.array([1020.0]) + sym_data = mx.sym.Variable('data') + sym_min_range = mx.sym.Variable('min_range') + sym_max_range = mx.sym.Variable('max_range') + if min_calib_range is None or max_calib_range is None: + requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range) + out = requant.bind(ctx=mx.current_context(), + args={'data':qdata, 'min_range':min_range, + 'max_range':max_range}) + qdata_int8, min_output, max_output = out.forward() + else: + requant = mx.sym.contrib.requantize(sym_data, sym_min_range, 
sym_max_range, + min_calib_range, max_calib_range) + out = requant.bind(ctx=mx.current_context(), args={'data':qdata, 'min_range':min_range, + 'max_range':max_range}) + qdata_int8, min_output, max_output = out.forward() + + qdata_int8_np, min_output_np, max_output_np = requantize_baseline(qdata.asnumpy(), min_range.asscalar(), + max_range.asscalar(), + min_calib_range=min_calib_range, + max_calib_range=max_calib_range) + assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np) + assert_almost_equal(min_output.asnumpy(), np.array([min_output_np])) + assert_almost_equal(max_output.asnumpy(), np.array([max_output_np])) + # test with symbol API. + check_requantize_with_symbol((3, 4, 10, 10)) + check_requantize_with_symbol((32, 3, 23, 23)) + check_requantize_with_symbol((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0) + check_requantize_with_symbol((32, 3, 23, 23), min_calib_range=-134.349, max_calib_range=523.43) + # Test with nd array API check_requantize((3, 4, 10, 10)) check_requantize((32, 3, 23, 23)) check_requantize((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0)