-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Support Quantized Fully Connected by INT8 GEMM #12922
Changes from 3 commits
49b189f
a2bfef4
91f1a9b
b8e8257
471a2dc
7b64226
1dbc106
babc764
d365b64
1010deb
818021d
b3df5a6
b3bf9f7
e537fc1
72b81d9
1f98f63
9171b1a
daf75e6
1ea0675
c87402e
28bf1c3
88562b9
54ee001
d2dde15
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_FULLY_CONNECTED_INL_H_ | ||
#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_FULLY_CONNECTED_INL_H_ | ||
|
||
#include <vector> | ||
#include "quantization_utils.h" | ||
#include "../nn/fully_connected-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
namespace quantized_fc { | ||
enum QuantilizedfcOpResource {kTempSpace}; | ||
} | ||
|
||
struct QuantizedSumInitKernelWithBias { | ||
// init sum data with bias for matrix b (n) | ||
MSHADOW_XINLINE static void Map(int i, int32_t *out, | ||
const int8_t *bias, const float *min_out, | ||
const float *max_out, const float *min_bias, | ||
const float *max_bias) { | ||
typedef int32_t T1; | ||
typedef int8_t T2; | ||
using mshadow::red::limits::MinValue; | ||
using mshadow::red::limits::MaxValue; | ||
float float_for_one_out_quant = | ||
MaxAbs(*min_out, *max_out) / static_cast<double>(MaxValue<T1>()); | ||
float float_for_one_bias_quant = | ||
MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>()); | ||
if (float_for_one_out_quant != 0) { | ||
out[i] = bias[i] * float_for_one_bias_quant / | ||
float_for_one_out_quant; | ||
} else { | ||
LOG(INFO) << "WARNING: QuantizedBiasAddKernel float_for_one_out_quant is 0 !"; | ||
out[i] = 0; | ||
} | ||
} | ||
}; | ||
template<typename SrcType> | ||
void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs, | ||
const OpContext &ctx, | ||
const std::vector<NDArray> &in_data, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<NDArray> &out_data) { | ||
#if MSHADOW_USE_MKL == 1 | ||
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed); | ||
using namespace mshadow; | ||
using namespace mxnet_op; | ||
size_t num_inputs = param.no_bias ? 2 : 3; | ||
CHECK_EQ(in_data.size(), num_inputs * 3); | ||
CHECK_EQ(out_data.size(), 3U); | ||
const NDArray& data = in_data[0]; | ||
const NDArray& weight = in_data[1]; | ||
const NDArray& out = out_data[0]; | ||
TShape dshape = data.shape(); | ||
TShape wshape = weight.shape(); | ||
TShape oshape = out.shape(); | ||
auto output_temp = out.data().dptr<int32_t>(); | ||
auto weight_temp = weight.data().dptr<SrcType>(); | ||
auto data_temp = data.data().dptr<SrcType>(); | ||
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); | ||
const float alpha = 1.0f; | ||
const float beta = 1.0f; | ||
const CBLAS_OFFSET offsetc = CblasFixOffset; | ||
const MKL_INT8 oa = 0; | ||
const MKL_INT8 ob = 0; | ||
MKL_INT32 oc = 0; | ||
const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim()); | ||
Stream<cpu> *s = ctx.get_stream<cpu>(); | ||
// cblas_gemm_s8u8s32 required first matrix must be uint8 | ||
// shift data from int8(from -128 to 127) to uint8 (from 0 to 255) | ||
int shift = 128; | ||
Tensor<cpu, 1, uint8_t> shiftdata = | ||
ctx.requested[quantized_fc::kTempSpace].get_space_typed<cpu, 1, uint8_t>( | ||
Shape1(m * k), s); | ||
#pragma omp parallel for num_threads(omp_threads) | ||
for (int i = 0; i < m * k; ++i) { | ||
shiftdata.dptr_[i] = data_temp[i] + shift; | ||
} | ||
|
||
Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1, | ||
out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(), | ||
in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(), | ||
in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>()); | ||
if (!param.no_bias) { | ||
const NDArray& bias = in_data[2]; | ||
Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.data().dptr<int32_t>(), | ||
bias.data().dptr<int8_t>(), out_data[1].data().dptr<float>(), | ||
out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(), | ||
in_data[8].data().dptr<float>()); | ||
} else { | ||
#pragma omp parallel for num_threads(omp_threads) | ||
for (int i = 0; i < m * n; ++i) { | ||
output_temp[i] = 0; | ||
} | ||
} | ||
#pragma omp parallel for num_threads(omp_threads) | ||
for (int i = 0; i < n; ++i) { | ||
for (int j = 0; j < k; ++j) { | ||
output_temp[i] -= shift * weight_temp[i * k + j]; | ||
} | ||
} | ||
#pragma omp parallel for num_threads(omp_threads) | ||
for (int i = n; i < m * n; ++i) { | ||
output_temp[i] = output_temp[i % n]; | ||
} | ||
cblas_gemm_s8u8s32(CblasRowMajor, | ||
CblasNoTrans, | ||
CblasTrans, | ||
offsetc, | ||
m, | ||
n, | ||
k, | ||
alpha, | ||
shiftdata.dptr_, | ||
k, | ||
oa, | ||
weight.data().dptr<SrcType>(), | ||
k, | ||
ob, | ||
beta, | ||
out.data().dptr<int32_t>(), | ||
n, | ||
&oc); | ||
#else | ||
LOG(FATAL) << "s8u8s32 is only supported by MKL BLAS library"; | ||
#endif | ||
} | ||
|
||
NNVM_REGISTER_OP(_contrib_quantized_fully_connected) | ||
.set_attr<FComputeEx>("FComputeEx<cpu>", | ||
MKLDNNQuantizedFullyConnectedForward<int8_t>) | ||
.set_attr<FResourceRequest>("FResourceRequest", | ||
[](const NodeAttrs& attrs) { | ||
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; | ||
}); | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_FULLY_CONNECTED_INL_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
* \author Ziheng Jiang, Jun Wu | ||
*/ | ||
#include "../nn/fully_connected-inl.h" | ||
#include "./quantized_fully_connected-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
@@ -79,6 +80,20 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs, | |
return true; | ||
} | ||
|
||
bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs, | ||
const int dev_mask, | ||
DispatchMode* dispatch_mode, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
*dispatch_mode = DispatchMode::kFCompute; | ||
if (dev_mask == mshadow::cpu::kDevMask) { | ||
*dispatch_mode = DispatchMode::kFComputeEx; | ||
} | ||
for (size_t i = 0; i < out_attrs->size(); i++) | ||
(*out_attrs)[i] = kDefaultStorage; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just delete this line. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
return true; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please consider using range for loops for readability. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think @larroy meant to use:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
} | ||
|
||
NNVM_REGISTER_OP(_contrib_quantized_fully_connected) | ||
.describe(R"code(Fully Connected operator for input, weight and bias data type of int8, | ||
and accumulates in type int32 for the output. For each argument, two more arguments of type | ||
|
@@ -112,6 +127,7 @@ and max thresholds representing the threholds for quantizing the float32 output | |
}) | ||
.set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape) | ||
.set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType) | ||
.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType) | ||
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) | ||
.add_argument("data", "NDArray-or-Symbol", "Input data.") | ||
.add_argument("weight", "NDArray-or-Symbol", "weight.") | ||
|
@@ -135,6 +151,5 @@ NNVM_REGISTER_OP(FullyConnected) | |
} | ||
return node; | ||
}); | ||
|
||
} // namespace op | ||
} // namespace mxnet |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -270,7 +270,7 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p | |
def test_quantized_fc(): | ||
def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): | ||
if mx.current_context().device_type != 'gpu': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should be able to run this test on CPU in CI. Could we test to see if 'MKL' is in the env var 'BUILD_TAG' and run the test if it is. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @KellenSunderland good suggestion! Currently, the CI doesn't include Intel MKL library as BLAS library and @azai91 is working on adding it so that we can have a better coverage, such as batch_gemm, quantization FC, etc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @pengzhao-intel Oh sorry, didn't realize that was the case. If the tests won't pass without full mkl installed and it's not there let's add this in a later PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @pengzhao-intel do you mean the full MKL? We already use MKLML on CI. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @lebeg yes, I mean full MKL. The MKLML doesn't have the INT8 GEMM now :) |
||
print('skipped testing quantized_fc on cpu since it is not supported yet') | ||
print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library') | ||
return | ||
elif qdtype == 'uint8' and is_test_for_gpu(): | ||
print('skipped testing quantized_fc for gpu uint8 since it is not supported yet') | ||
|
@@ -283,16 +283,16 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): | |
fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null') | ||
if qdtype == 'uint8': | ||
data_low = 0.0 | ||
data_high = 127.0 | ||
data_high = 63.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason of changing this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change data range from (-127,127) to (-63, 63) to avoid potential overflow when using igemm in some hardware platform |
||
else: | ||
data_low = -127.0 | ||
data_high = 127.0 | ||
data_low = -63.0 | ||
data_high = 63.0 | ||
fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high, | ||
shape=data_shape).astype('int32') | ||
fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, | ||
fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high, | ||
shape=arg_shapes[1]).astype('int32') | ||
if not no_bias: | ||
fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, | ||
fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high, | ||
shape=arg_shapes[2]).astype('int32') | ||
output = fc_fp32_exe.forward()[0] | ||
|
||
|
@@ -335,6 +335,10 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): | |
check_quantized_fc((32, 111, 2, 2), 100, True, qdtype) | ||
check_quantized_fc((32, 512, 2, 2), 100, False, qdtype) | ||
check_quantized_fc((32, 111, 2, 2), 100, False, qdtype) | ||
check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype) | ||
check_quantized_fc((256, 111, 2, 2), 800, False, qdtype) | ||
check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype) | ||
check_quantized_fc((256, 111, 2, 2), 800, True, qdtype) | ||
|
||
@with_seed() | ||
def test_quantized_flatten(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggest move all the implementation to .cc file since it's only for CPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed