MKLDNN based Quantized FullyConnected Operator and its fusion #14128
Conversation
@KellenSunderland @reminisce for the review :)
@ciyongch Thanks for the review! @mxnet-label-bot add [pr-awaiting-review, MKLDNN]
@@ -245,6 +254,10 @@ def _init_weight(self, name, arr):
        """Abstract method to Initialize weight."""
        raise NotImplementedError("Must override it")

    def _init_quantized_weight(self, _, arr):
        _arr = random.randint(-127, 127, dtype='int32').asnumpy()
Ummm, it seems we need to extend `randint` to support `dtype='int8'`.
Yes, the `int8` dtype is a limitation of the current `randint`.
this->fwd = std::shared_ptr<mkldnn::inner_product_forward>(
    new mkldnn::inner_product_forward(
        fwd_pd, mkldnn::primitive::at(*this->data),
        mkldnn::primitive::at(*this->weight), *this->out));
}
This needs an `else` clause.
Actually, there's nothing to do in the `else` clause if there were one, so I didn't put an `else` here.
But I can change this piece of code in another way (equivalent to the current one):

if (bias != nullptr) {
  ....
} else {
  if (this->fwd_ == nullptr) {
    ....
  }
}
@@ -265,8 +313,11 @@ and max thresholds representing the threholds for quantizing the float32 output
.set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FComputeEx>("FComputeEx<cpu>",
    QuantizedFullyConnectedForward<int8_t>)
.set_attr<FCompute>("FCompute<cpu>", QuantizedFullyConnectedForwardCPU)
Wrap this line in `#if MSHADOW_USE_MKL == 1`?
I think the current implementation gives users more information about the quantized FullyConnected dependencies on CPU. If we put the macro in the op's attributes, users would only get a 'not implemented' message.
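For reference, a minimal sketch of the reviewer's suggestion, assuming the attribute setter shown in the diff above (the op name and the surrounding registration code are assumptions, abridged here):

```cpp
// Sketch: only register the CPU kernel when MKL is available, so the op falls
// back to a generic "not implemented" error on non-MKL builds.  The author
// instead keeps the registration unconditional and reports the missing MKL
// dependency from inside the kernel, which is more informative for users.
#if MSHADOW_USE_MKL == 1
NNVM_REGISTER_OP(_contrib_quantized_fully_connected)  // op name assumed
.set_attr<FCompute>("FCompute<cpu>", QuantizedFullyConnectedForwardCPU);
#endif
```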
@@ -0,0 +1,434 @@
/*
filename: mkldnn_fc.cc -> mkldnn_fully_connected.cc?
There's already a `mkldnn_fully_connected.cc` in `src/operator/nn/mkldnn/`, so I chose `mkldnn_fc.cc` for the subgraph part to keep this one distinct.
bool disable_fc_relu;
};

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_FC, SgMKLDNNFCProperty);
It seems we have several property names now; do we have any document that lists and explains them?
There are some updates to the subgraph API in PR #14113; I suggest updating the property name and rules after it's merged.
@ZhennanQin could you help to review this PR?

struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
  bool quantized;
  bool fuse_requantize;
Is `fuse_requantize` necessary? Why not directly check whether `min_calib_range` and `max_calib_range` have values?
This param was added to pair with `fuse_dequantize`, but indeed it can be replaced by checking `min/max_calib_range`. I will remove it.
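A hedged sketch of that simplification, assuming the calibration ranges are `dmlc::optional<float>` members of `MKLDNNFCParam`:

```cpp
// Inside MKLDNNFCParam (sketch): the ranges are only set when a requantize
// node has been fused into the FC, so their presence replaces fuse_requantize.
dmlc::optional<float> min_calib_range;
dmlc::optional<float> max_calib_range;

// At primitive-creation time:
if (mkldnn_param.min_calib_range.has_value() &&
    mkldnn_param.max_calib_range.has_value()) {
  // requantize was fused: derive the int8 output scale from the calibrated range
}
```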
struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
  bool quantized;
  bool fuse_requantize;
  bool fuse_dequantize;
How about renaming `fuse_dequantize` to `float_output`? End users may not care about the fusion details, but they do want a name with a straightforward meaning.
Sure, but this param is only touched by the subgraph pass, not by end users; end users care instead about the environment variable 'MXNET_DISABLE_MKLDNN_QFC_FUSE_DEQUANTIZE'. Will change this name.
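Per the commit log further down, the flag eventually became `enable_float_output`; a small sketch of the declaration, following the `DMLC_DECLARE_FIELD` pattern quoted later in this review (the default value and description wording are assumptions):

```cpp
DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
.describe("Whether to enable float32 output");
```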
if (full_param.mkldnn_param.with_relu) {
  float scale = 1.0f;
  float alpha = 0.0f;
  float beta = 1.0f;
Add const
fixed.
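For context, a hedged sketch of what the `with_relu` branch typically feeds into with the MKL-DNN 0.x API, with the constants now `const` as requested (`attr` is assumed to be the `primitive_attr` later passed to the inner-product primitive descriptor):

```cpp
mkldnn::primitive_attr attr;
if (full_param.mkldnn_param.with_relu) {
  const float scale = 1.0f;   // no extra scaling of the fused output
  const float alpha = 0.0f;   // negative slope of 0 gives a plain ReLU
  const float beta  = 1.0f;   // unused by eltwise_relu
  mkldnn::post_ops ops;
  ops.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, alpha, beta);
  attr.set_post_ops(ops);     // fuse the ReLU into the inner-product primitive
}
```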
@reminisce @zheng-da @anirudh2290 could you help take a review?
MKLDNNFullyconSignature key(param);
key.AddSign(is_train);
key.AddSign(data);
key.AddSign(weight);
It seems quantized FC will call this function as well, so output_scale should be part of the hashed key. Better to hash the whole mkldnn_param.
Good point :)
Currently, this function is only called by normal FullyConnected and quantized FC, not by subgraph FC, while `output_scale` and the whole `mkldnn_param` are only used by subgraph FC, so they're actually unused here.
it = ins_ret.first;
MKLDNNFCFullParam full_param;
full_param.default_param = param;
full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
Not correct; the real mkldnn_param should be passed down.
Since this function is only called by normal FullyConnected and quantized FC, and `mkldnn_param` is not used by these two ops, passing only the `default_param` down from the caller is enough.
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (size_t i = 0; i < bias_size; ++i) {
  quantized_bias_ptr[i] = bias_ptr[i] * bias_int32_rescale;
}
I recommend using an mkldnn reorder instead.
Will an mkldnn reorder always be better than this approach (reorder may introduce overhead)? The bias is usually not a big array.
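For comparison, a hedged sketch of the reorder-based alternative being suggested, using the MKL-DNN 0.x output-scale attribute; `bias_f32_mem` and `bias_s32_mem` are hypothetical `mkldnn::memory` objects wrapping the source and destination bias buffers:

```cpp
// Quantize the fp32 bias to int32 via a reorder with an output scale,
// instead of the hand-written OpenMP loop above.
mkldnn::primitive_attr reorder_attr;
reorder_attr.set_output_scales(0, {bias_int32_rescale});  // mask 0: one scale for all elements
auto reorder_pd = mkldnn::reorder::primitive_desc(
    bias_f32_mem.get_primitive_desc(), bias_s32_mem.get_primitive_desc(), reorder_attr);
MKLDNNStream::Get()->RegisterPrim(
    mkldnn::reorder(reorder_pd, bias_f32_mem, bias_s32_mem));
```

Whether this beats the simple loop for a small one-dimensional bias is exactly the overhead question raised above.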
@@ -72,7 +79,7 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_type->size(), num_inputs * 3);
CHECK_EQ(out_type->size(), 3U);

for (size_t i = 0; i < num_inputs; ++i) {
for (size_t i = 1; i < num_inputs; ++i) {
Why skip `i = 0`?
input[0] supports both INT8 and UINT8 here.
Then please add a check that input[0] is INT8 or UINT8.
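A sketch of the requested check inside `QuantizedFullyConnectedType`, assuming the mshadow type enums and the `TYPE_ASSIGN_CHECK` helper used elsewhere in MXNet's quantized operators:

```cpp
// The data input may be either int8 or uint8; the remaining quantized
// inputs (weight and, if present, bias) must be int8.
CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8)
    << "QuantizedFullyConnected only supports int8/uint8 input, while "
    << in_type->at(0) << " is given.";
for (size_t i = 1; i < num_inputs; ++i) {
  TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kInt8);
}
```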
Force-pushed from 006cb49 to b7d8324 (Compare)
@TaoLv @ZhennanQin please review again. If there are no other comments, please help approve; we need to merge the PR soon.
LGTM.
The code has been verified by several internal cases and achieved the expected accuracy and performance.
/*!
 * \file mkldnn_fully_connected-inl.h
 * \brief
incomplete doc?
Do you mean the author info is missing here?
Right, along with the missing description.
Sure, I will add the missing descriptions for all the new files.
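For example, the completed header could read roughly as follows (the brief text is an assumption of the wording later added; see "update description in file header" in the commit log below):

```cpp
/*!
 * \file mkldnn_fully_connected-inl.h
 * \brief Common functions used by MKLDNN (quantized) FullyConnected operator
 */
```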
DMLC_DECLARE_PARAMETER(MKLDNNFCParam) {
  DMLC_DECLARE_FIELD(quantized).set_default(false)
  .describe("enable quantization");
Let's use a consistent standard for user-facing documentation.
No problem, will update the description.
} catch (mkldnn::error &e) {
  if (e.status == mkldnn_unimplemented &&
      full_param.mkldnn_param.quantized) {
    LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 fully_connected.";
What's the difference between `LOG(ERROR)` and `LOG(FATAL)`?
`LOG(ERROR)` works the same as `LOG(INFO)`, while `LOG(FATAL)` throws an error with an error message and stops execution.
`LOG(ERROR)` was used here to give the hint and then throw the original error later.
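A sketch of the pattern described above; the re-throw is an assumption based on this reply, since the diff only shows the catch and the log line:

```cpp
try {
  // ... create the MKL-DNN inner_product primitive descriptor (abridged) ...
} catch (mkldnn::error &e) {
  if (e.status == mkldnn_unimplemented && full_param.mkldnn_param.quantized) {
    // Non-fatal hint for the user; LOG(FATAL) would abort right here instead.
    LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 fully_connected.";
  }
  throw;  // re-raise the original mkldnn::error so the real failure still propagates
}
```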
LGTM
Thanks for your review @szha @TaoLv @ZhennanQin @pengzhao-intel :)
LGTM now. Thank you for the contribution.
Force-pushed from 95167fc to 28e1242 (Compare)
Rebased onto the latest code base.
@pengzhao-intel @TaoLv
@szha please confirm your concerns are fully addressed. Thanks.
@ciyongch please take a look at the conflicts. I would like to have this PR merged within 24 hours if there are no other comments or conflicts.
@TaoLv no problem.
Force-pushed from 9b1fd6e to b8edfb5 (Compare)
Thank you for the contribution @ciyongch. Merging now.
…#14128)
* add MKL-DNN quantized innerproduct
* initial qfc with mkldnn
* Add MKL-DNN quantized_fully_connected
* refactor params order for fullyconnected
* update quantized_fully_connected unittest, force data to uint8 type temporary
* change mkl based quantized fully_connected to FCompute
* add check data type for mkldnn quantized_fc
* add fuse requantize and dequantize for mkldnn quantized fullyconnected
* add env setting for enable/disable fuse requantize/dequantize for quantize fullyconnected
* fix requantize scaling error
* add fallback when input data is int8
* fix mkl quantized fullyconnected index error
* update quantized fc test cases
* add subgraph node for mkldnn fullyconnected
* fix compiling and lint error
* clean and refactor code
* enable quantized_fc for imagenet
* cleanup code
* Fix StorageType error for non-mkldnn path
* fix pylint
* reverse BUILD_TAG for MKL IGEMM ut, remove IGEMM qfc check
* rename variables and refactor codes according to comments
* add subgraph qfc tests and fix shape error
* remove fuse_requantize and change fuse_dequantize to enable_float_output.
* change to use mxnet::Tuple and update tests
* update description in file header
* update input0 type check for quantized FullyConnected
* fix conflit of mkl/test_subgraph.py
* retrigger CI
* retrigger CI due to hang
Description
This PR added an MKL-DNN based quantized FullyConnected operator via the FComputeEx API, and changed the MKL IGEMM based quantized FullyConnected operator to the FCompute API.
The PR also added a subgraph implementation for both FullyConnected and quantized FullyConnected to provide more operator fusion at the graph level (making it easier to extend to other element-wise operator fusions in the future). The following patterns are currently supported:
- FullyConnected + relu
- quantized FullyConnected + requantize
- quantized FullyConnected + requantize + dequantize

@pengzhao-intel @TaoLv @ZhennanQin @zheng-da