Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MKLDNN based Quantized FullyConnected Operator and its fusion #14128

Merged
merged 30 commits into from
Mar 8, 2019

Conversation

ciyongch
Copy link
Contributor

Description

This PR added MKL-DNN based quantized FullyConnected operator via FComputeEx API, and changed MKL IGEMM based quantized FullyConnected operator to FCompute API.
The PR also added the subgraph implementation for both FullyConnected and quantized FullyConnected to provide more operator fusion on graph level (it's easier to extend other element-wise operator fusion in the future), the following patterns are supported currently: FullyConnected + relu, quantized FullyConnected + requantize, and quantized FullyConnected + requantize + dequantize.

@pengzhao-intel @TaoLv @ZhennanQin @zheng-da

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@ciyongch ciyongch requested a review from szha as a code owner February 12, 2019 08:33
@ciyongch ciyongch changed the title Stateful inner product MKLDNN based Quantized FullyConnected Operator and its fusion Feb 12, 2019
@pengzhao-intel
Copy link
Contributor

@KellenSunderland @reminisce for the review :)

@ankkhedia
Copy link
Contributor

@ciyongch Thanks for the review!

@mxnet-label-bot add [pr-awaiting-review, MKLDNN]

@marcoabreu marcoabreu added MKLDNN pr-awaiting-review PR is waiting for code review labels Feb 12, 2019
@@ -245,6 +254,10 @@ def _init_weight(self, name, arr):
"""Abstract method to Initialize weight."""
raise NotImplementedError("Must override it")

def _init_quantized_weight(self, _, arr):
_arr = random.randint(-127, 127, dtype='int32').asnumpy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ummm, seems need extend randint to support dtype='int8'.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, int8 dtype is limitied to current randint.

src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h Outdated Show resolved Hide resolved
this->fwd = std::shared_ptr<mkldnn::inner_product_forward>(
new mkldnn::inner_product_forward(
fwd_pd, mkldnn::primitive::at(*this->data),
mkldnn::primitive::at(*this->weight), *this->out));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need else.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, there's nothing to do in the else clause if there's one, so I didn't put else here.
But I can change this piece of code in another way (which is equivalent to current one)

if (bias != null) {
  ....
} else {
  if (this->fwd_ == nullptr) {
    ....
  }
}

@@ -265,8 +313,11 @@ and max thresholds representing the threholds for quantizing the float32 output
.set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FComputeEx>("FComputeEx<cpu>",
QuantizedFullyConnectedForward<int8_t>)
.set_attr<FCompute>("FCompute<cpu>", QuantizedFullyConnectedForwardCPU)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrap this line into #if MSHADOW_USE_MKL == 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the current implementation could give more information about quantizedFullyConnected dependencies on CPU for users. If put the Macro in op's attributes, only give 'not implemented' information.

@@ -0,0 +1,434 @@
/*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filename: mkldnn_fc.cc -> mkldnn_fully_connected.cc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's already a mkldnn_fully_connected.cc in src/operator/nn/mkldnn/, so I choosed mkldnn_fc.cc for subgraph part to make this one different.

src/operator/subgraph/mkldnn/mkldnn_fc.cc Outdated Show resolved Hide resolved
bool disable_fc_relu;
};

MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_FC, SgMKLDNNFCProperty);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems we have several property name now, do we have any document to list and explain them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's some updates to current subgraph API changes in PR #14113, suggest to update the property name and rules after it's merged.

@pengzhao-intel
Copy link
Contributor

@ZhennanQin could you help to review this PR?


struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
bool quantized;
bool fuse_requantize;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is fuse_requantize necessary? Why don't directly check if min_calib_range and max_calib_range have value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This param was added to on pair with fuse_dequantize, but indeed this can be replaced by checking min/max_calib_range.
I will remove it.

struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
bool quantized;
bool fuse_requantize;
bool fuse_dequantize;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about rename fuse_dequantize to float_output? End-user may doesn't care the fusion details, but want to get some straight forward meaning they cares.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but this param is only touch by subgraph but not end-user, instead end-user may care about the env of 'MXNET_DISABLE_MKLDNN_QFC_FUSE_DEQUANTIZE'. Will change this name.

if (full_param.mkldnn_param.with_relu) {
float scale = 1.0f;
float alpha = 0.0f;
float beta = 1.0f;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add const

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

@pengzhao-intel
Copy link
Contributor

@reminisce @zheng-da @anirudh2290 could you help to take a review?

MKLDNNFullyconSignature key(param);
key.AddSign(is_train);
key.AddSign(data);
key.AddSign(weight);
Copy link
Contributor

@ZhennanQin ZhennanQin Mar 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems quantized FC will call this function as well. Then output_scale should be a part of hashed key. Better to hash whole mkldnn_param.

Copy link
Contributor Author

@ciyongch ciyongch Mar 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point :)
Currently, this function is only called by normal FullyConnected or Quantized FC but not subgraph FC. While the output_scale or the whole mkldnn_param was only used by subgraph FC and they're useless here actually.

it = ins_ret.first;
MKLDNNFCFullParam full_param;
full_param.default_param = param;
full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not correct, should pass down real mkldnn_param.

Copy link
Contributor Author

@ciyongch ciyongch Mar 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this function is only called by normal FullyConnected and Quantized FC, while mkldnn_param was not used by these two Ops, so only default_param passed down from caller is enough.

#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (size_t i = 0; i < bias_size; ++i) {
quantized_bias_ptr[i] = bias_ptr[i] * bias_int32_rescale;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommend to use mkldnn reorder instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will mkldnn reorder always better than this way (reorder may introduce overhead)? since bias is usually not a big array.

@@ -72,7 +79,7 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_type->size(), num_inputs * 3);
CHECK_EQ(out_type->size(), 3U);

for (size_t i = 0; i < num_inputs; ++i) {
for (size_t i = 1; i < num_inputs; ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why skip i = 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input[0] will supports both INT8 and UINT8 here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then please add check for input[0] on INT8 or UINT8.

@ciyongch ciyongch force-pushed the stateful_inner_product branch from 006cb49 to b7d8324 Compare March 2, 2019 09:32
@pengzhao-intel
Copy link
Contributor

@TaoLv @ZhennanQin please take a review again. If there are no other comments, please help to approve and we need to merge the PR soon.

Copy link
Contributor

@pengzhao-intel pengzhao-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

The code has been verified by several internal cases and got the expected accuracy and performance.


/*!
* \file mkldnn_fully_connected-inl.h
* \brief
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

incomplete doc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean missing author info here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, along with the missing description

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will update the missing description for all the new files.


DMLC_DECLARE_PARAMETER(MKLDNNFCParam) {
DMLC_DECLARE_FIELD(quantized).set_default(false)
.describe("enable quantization");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use consistent standard for user-facing documentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem, will update the description.

} catch (mkldnn::error &e) {
if (e.status == mkldnn_unimplemented &&
full_param.mkldnn_param.quantized) {
LOG(ERROR) << "AVX512-BW support or MKLDNN v0.18 is required for INT8 fully_connected.";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between LOG(ERROR) or LOG(FATAL)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LOG(ERROR) works same as LOG(INFO), while LOG(FATAL) will throw an error with an error info and stop running.
LOG(ERROR) was used here to give the hint and then throw the original error later.

Copy link
Contributor

@ZhennanQin ZhennanQin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ciyongch
Copy link
Contributor Author

ciyongch commented Mar 4, 2019

Thanks for your review @szha @TaoLv @ZhennanQin @pengzhao-intel :)
I've updated the codes according to your comments, please help to check if there's any other comments.

Copy link
Member

@TaoLv TaoLv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now. Thank you for the contribution.

@ciyongch ciyongch force-pushed the stateful_inner_product branch from 95167fc to 28e1242 Compare March 5, 2019 02:54
@ciyongch
Copy link
Contributor Author

ciyongch commented Mar 5, 2019

Rebase to latest code base.

@ciyongch
Copy link
Contributor Author

ciyongch commented Mar 5, 2019

@pengzhao-intel @TaoLv
The code is updated to latest, all the comments are addressed, please help to check and merge if not other comments :)

@TaoLv
Copy link
Member

TaoLv commented Mar 6, 2019

@szha please confirm your concerns are fully addressed. Thanks.

@TaoLv
Copy link
Member

TaoLv commented Mar 7, 2019

@ciyongch please take a look at the conflicts. I would like to have this PR merged in 24 hours if there is no other comments or conflicts.

@ciyongch
Copy link
Contributor Author

ciyongch commented Mar 7, 2019

@TaoLv no problem.

@ciyongch ciyongch force-pushed the stateful_inner_product branch from 9b1fd6e to b8edfb5 Compare March 7, 2019 05:31
@TaoLv
Copy link
Member

TaoLv commented Mar 8, 2019

Thank you for the contribution @ciyongch. Merging now.

@TaoLv TaoLv merged commit 8668db7 into apache:master Mar 8, 2019
@ciyongch ciyongch deleted the stateful_inner_product branch March 13, 2019 02:26
vdantu pushed a commit to vdantu/incubator-mxnet that referenced this pull request Mar 31, 2019
…#14128)

* add MKL-DNN quantized innerproduct

* initial qfc with mkldnn

* Add MKL-DNN quantized_fully_connected

* refactor params order for fullyconnected

* update quantized_fully_connected unittest, force data to uint8 type temporary

* change mkl based quantized fully_connected to FCompute

* add check data type for mkldnn quantized_fc

* add fuse requantize and dequantize for mkldnn quantized fullyconnected

* add env setting for enable/disable fuse requantize/dequantize for quantize fullyconnected

* fix requantize scaling error

* add fallback when input data is int8

* fix mkl quantized fullyconnected index error

* update quantized fc test cases

* add subgraph node for mkldnn fullyconnected

* fix compiling and lint error

* clean and refactor code

* enable quantized_fc for imagenet

* cleanup code

* Fix StorageType error for non-mkldnn path

* fix pylint

* reverse BUILD_TAG for MKL IGEMM ut, remove IGEMM qfc check

* rename variables and refactor codes according to comments

* add subgraph qfc tests and fix shape error

* remove fuse_requantize and change fuse_dequantize to enable_float_output.

* change to use mxnet::Tuple and update tests

* update description in file header

* update input0 type check for quantized FullyConnected

* fix conflit of mkl/test_subgraph.py

* retrigger CI

* retrigger CI due to hang
nswamy pushed a commit that referenced this pull request Apr 5, 2019
* add MKL-DNN quantized innerproduct

* initial qfc with mkldnn

* Add MKL-DNN quantized_fully_connected

* refactor params order for fullyconnected

* update quantized_fully_connected unittest, force data to uint8 type temporary

* change mkl based quantized fully_connected to FCompute

* add check data type for mkldnn quantized_fc

* add fuse requantize and dequantize for mkldnn quantized fullyconnected

* add env setting for enable/disable fuse requantize/dequantize for quantize fullyconnected

* fix requantize scaling error

* add fallback when input data is int8

* fix mkl quantized fullyconnected index error

* update quantized fc test cases

* add subgraph node for mkldnn fullyconnected

* fix compiling and lint error

* clean and refactor code

* enable quantized_fc for imagenet

* cleanup code

* Fix StorageType error for non-mkldnn path

* fix pylint

* reverse BUILD_TAG for MKL IGEMM ut, remove IGEMM qfc check

* rename variables and refactor codes according to comments

* add subgraph qfc tests and fix shape error

* remove fuse_requantize and change fuse_dequantize to enable_float_output.

* change to use mxnet::Tuple and update tests

* update description in file header

* update input0 type check for quantized FullyConnected

* fix conflit of mkl/test_subgraph.py

* retrigger CI

* retrigger CI due to hang
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
…#14128)

* add MKL-DNN quantized innerproduct

* initial qfc with mkldnn

* Add MKL-DNN quantized_fully_connected

* refactor params order for fullyconnected

* update quantized_fully_connected unittest, force data to uint8 type temporary

* change mkl based quantized fully_connected to FCompute

* add check data type for mkldnn quantized_fc

* add fuse requantize and dequantize for mkldnn quantized fullyconnected

* add env setting for enable/disable fuse requantize/dequantize for quantize fullyconnected

* fix requantize scaling error

* add fallback when input data is int8

* fix mkl quantized fullyconnected index error

* update quantized fc test cases

* add subgraph node for mkldnn fullyconnected

* fix compiling and lint error

* clean and refactor code

* enable quantized_fc for imagenet

* cleanup code

* Fix StorageType error for non-mkldnn path

* fix pylint

* reverse BUILD_TAG for MKL IGEMM ut, remove IGEMM qfc check

* rename variables and refactor codes according to comments

* add subgraph qfc tests and fix shape error

* remove fuse_requantize and change fuse_dequantize to enable_float_output.

* change to use mxnet::Tuple and update tests

* update description in file header

* update input0 type check for quantized FullyConnected

* fix conflit of mkl/test_subgraph.py

* retrigger CI

* retrigger CI due to hang
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
MKLDNN pr-awaiting-review PR is waiting for code review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants