Skip to content

Commit

Permalink
Enhance gpu quantization (apache#14094)
Browse files Browse the repository at this point in the history
* enhance gpu quantization

* fix test and improve error message

* add check srctype to quantized_conv.cu

* improve infer type

* fix lint

* add dtype check in quantize

* revert check in python level and quantized_conv

* Revert "add dtype check in quantize"

This reverts commit ab68668.

* add dtype check in quantize

* fix quantize test case
  • Loading branch information
jitMatrix authored and TaoLv committed Mar 6, 2019
1 parent f2497aa commit 49d7fc6
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/operator/quantization/quantize-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ void QuantizeCompute(const nnvm::NodeAttrs& attrs,

const QuantizeParam& param = nnvm::get<QuantizeParam>(attrs.parsed);
if (param.out_type == mshadow::kUint8) {
if (std::is_same<xpu, gpu>::value) {
LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, "
"please switch to the context of CPU or int8 data type for GPU.";
}
Kernel<quantize_unsigned, xpu>::Launch(s, outputs[0].Size(),
outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(), outputs[2].dptr<float>(),
inputs[0].dptr<float>(), inputs[1].dptr<float>(), inputs[2].dptr<float>(),
Expand Down
4 changes: 4 additions & 0 deletions src/operator/quantization/quantize_v2-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
Stream<xpu> *s = ctx.get_stream<xpu>();
const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
auto out_type = GetOutputType(param);
if (out_type == mshadow::kUint8 && std::is_same<xpu, gpu>::value) {
LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, "
"please switch to the context of CPU or int8 data type for GPU.";
}

if (inputs[0].type_flag_ == mshadow::kUint8 || inputs[0].type_flag_ == mshadow::kInt8) {
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
Expand Down
10 changes: 10 additions & 0 deletions tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,16 @@ def get_fp32_sym_with_multiple_outputs(length=1):
@with_seed()
def test_quantize_model():
def check_quantize_model(qdtype):
if is_test_for_native_cpu():
print('skipped testing quantize_model for native cpu since it is not supported yet')
return
elif qdtype == 'int8' and is_test_for_mkldnn():
print('skipped testing quantize_model for mkldnn cpu int8 since it is not supported yet')
return
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing quantize_model for gpu uint8 since it is not supported yet')
return

def check_params(params, qparams, qsym=None):
if qsym is None:
assert len(params) == len(qparams)
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4894,11 +4894,11 @@ def test_quantization_op():
min0 = mx.nd.array([0.0])
max0 = mx.nd.array([1.0])
a = mx.nd.array([[0.1392, 0.5928], [0.6027, 0.8579]])
qa, min1, max1 = mx.nd.contrib.quantize(a, min0, max0, out_type='uint8')
qa, min1, max1 = mx.nd.contrib.quantize(a, min0, max0, out_type='int8')
a_ = mx.nd.contrib.dequantize(qa, min1, max1, out_type='float32')

qa_real = mx.nd.array([[35, 151], [154, 219]])
a_real = mx.nd.array([[0.13725491, 0.59215689], [0.60392159, 0.8588236]])
qa_real = mx.nd.array([[18, 75], [77, 109]])
a_real = mx.nd.array([[0.14173228, 0.5905512], [0.6062992, 0.8582677]])

assert same(qa.asnumpy(), qa_real.asnumpy())
assert same(a_.asnumpy(), a_real.asnumpy())
Expand Down

0 comments on commit 49d7fc6

Please sign in to comment.