Conversation
python/mxnet/contrib/quantization.py
Outdated
```diff
@@ -499,6 +499,9 @@ def quantize_model(sym, arg_params, aux_params,
     if quantized_dtype not in ('int8', 'uint8'):
         raise ValueError('unknown quantized_dtype %s received,'
                          ' expected `int8` or `uint8`' % quantized_dtype)
+    if quantized_dtype == 'uint8' and ctx != cpu():
+        raise ValueError('currently gpu does not support uint8 quantization,'
+                         ' please set quantized_dtype to int8')
```
How about something like the following?

"Currently, uint8 quantization is only supported by CPU, please switch to the context of CPU or int8 data type for GPU"
okay, changed:)
LGTM
Thanks for the quick fix :)
@reminisce Can you help take a look at this? Thanks :)
@mxnet-label-bot add [pr-awaiting-review, Quantization]
Thanks for the quick fix!
python/mxnet/contrib/quantization.py
Outdated
```diff
@@ -499,6 +499,9 @@ def quantize_model(sym, arg_params, aux_params,
     if quantized_dtype not in ('int8', 'uint8'):
         raise ValueError('unknown quantized_dtype %s received,'
                          ' expected `int8` or `uint8`' % quantized_dtype)
+    if quantized_dtype == 'uint8' and ctx != cpu():
+        raise ValueError('currently, uint8 quantization is only supported by CPU,'
+                         ' please switch to the context of CPU or int8 data type for GPU')
```
Better to add this error to the backend, like in the case for MKLDNN with int8, so that we don't have to add error handling to other frontends when they support quantization.
Currently, only the Python frontend supports quantization, and in fact the calibration process will not use backend-specific quantized operators. So I think it's fine to add the error message here for now.
In QuantizeCompute (quantize-inl.h) you can check whether std::is_same<xpu, gpu>::value holds, check param.out_type, and throw an exception.
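A minimal sketch of that suggestion, for concreteness — the FCompute signature and the QuantizeParam lookup are standard MXNet plumbing assumed from context, not copied from this PR:

```cpp
// Sketch only: the proposed device/dtype guard at the top of QuantizeCompute
// in quantize-inl.h; the actual quantization implementation below it is elided.
template <typename xpu>
void QuantizeCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
  const QuantizeParam& param = nnvm::get<QuantizeParam>(attrs.parsed);
  if (std::is_same<xpu, gpu>::value && param.out_type == mshadow::kUint8) {
    LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, "
               << "please switch to the context of CPU or int8 data type for GPU";
  }
  // ... existing quantization kernel launch ...
}
```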
I don't think this modification can work, since the infer-type error

mxnet.base.MXNetError: [02:07:55] /home/ubuntu/experimentals/1.4_release/src/operator/quantization/../tensor/matrix_op-inl.h:250: Check failed: src.type_flag_ == ret.type_flag_ (3 vs. 5)

(3 is mshadow's kUint8 flag, 5 is kInt8) will occur before QuantizeCompute, and we cannot get the ctx information during the infer stage. So I think it's better to interrupt this action during the calibration stage.
Isn't that called from the forward pass of quantized_conv? The quantize forward pass should execute before this.
Added a check on src_type in quantized_conv.cu, please review again.
@rajeshii Thanks for the quick turnaround. Could you please look into the comments by @anirudh2290?
```diff
@@ -76,6 +76,9 @@ class QuantizedCuDNNConvOp {
     if (param_.pad.ndim() == 0U) param_.pad = mshadow::Shape2(0, 0);
     N = 0, H = 2, W = 3, C = 1;
     src_type_ = mshadow::DataType<SrcType>::kCudnnFlag;
+    CHECK_EQ(src_type_, 5U)
```
What's the 5U here?
```python
        return
    elif qdtype == 'uint8' and is_test_for_gpu():
        print('skipped testing quantize_model for gpu uint8 since it is not supported yet')
        return
```
Please add an else clause.
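Presumably something like the following shape — the first skip branch and the helper names are placeholders, since only the uint8/GPU branch appears in the diff above:

```python
def test_quantize_model(qdtype):
    if some_other_unsupported_combination(qdtype):   # placeholder for the elided first branch
        print('skipped testing quantize_model for this configuration')
        return
    elif qdtype == 'uint8' and is_test_for_gpu():
        print('skipped testing quantize_model for gpu uint8 since it is not supported yet')
        return
    else:
        run_quantize_model_checks(qdtype)            # placeholder for the actual test body
```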
@anirudh2290 @TaoLv Is this PR good to go?
```diff
@@ -110,6 +110,9 @@ class QuantizedCuDNNConvOp {
     const TShape& fshape = filter.shape_;
     const TShape& oshape = out.shape_;
 
+    CHECK_EQ(data.type_flag_, mshadow::kInt8)
+        << "currently, uint8 quantization is only supported by CPU, "
+           "please switch to the context of CPU or int8 data type for GPU.";
```
Can we add it inside quantize-inl.h? That way it will return an error message even for networks without this op.
ok, added:)
This reverts commit ab68668.
Thanks for the fix!
@mxnet-label-bot update [Quantization, pr-awaiting-merge]
* enhance gpu quantization
* fix test and improve error message
* add check srctype to quantized_conv.cu
* improve infer type
* fix lint
* add dtype check in quantize
* revert check in python level and quantized_conv
* Revert "add dtype check in quantize" (this reverts commit ab68668)
* add dtype check in quantize
* fix quantize test case
Description
"Fixes #14092"
As #14092 mentioned, GPU only supports int8 quantization, not uint8, so this PR adds an error message to the quantize_model function.
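For illustration, with the frontend check from the first revision of this PR (the check was later moved into the backend during review), a uint8 request on GPU fails fast — model setup is elided here:

```python
import mxnet as mx
from mxnet.contrib.quantization import quantize_model

# sym, arg_params, aux_params come from a trained FP32 model (elided)
quantize_model(sym, arg_params, aux_params,
               ctx=mx.gpu(0), quantized_dtype='uint8')
# ValueError: currently, uint8 quantization is only supported by CPU,
# please switch to the context of CPU or int8 data type for GPU
```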
@reminisce