Commit

[MKLDNN]Refactor requantize to speed up execution (apache#14608)
* Refactor requantize

* fix ci

* Fix CI

* Fix ci
ZhennanQin authored and haohuw committed Jun 23, 2019
1 parent a615e1d commit 56c1308
Showing 14 changed files with 189 additions and 129 deletions.
2 changes: 1 addition & 1 deletion src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -108,7 +108,7 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector<N
   }
 
   // Write output min/max
-  auto out_type = GetOutputType(param_);
+  auto out_type = GetQuantizeOutputType(param_);
   if (out_type == mshadow::kUint8) {
     quantized_range = kUint8Range;
     *outputs[1].data().dptr<float>() = data_min;
2 changes: 1 addition & 1 deletion src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -72,7 +72,7 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
   MKLDNNStream::Get()->Submit();
   Stream<cpu> *s = ctx.get_stream<cpu>();
   const size_t num_inputs = param.no_bias ? 2 : 3;
-  mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+  mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(s, 1,
       out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),
       in_data[num_inputs].data().dptr<float>(),
       in_data[num_inputs+1].data().dptr<float>(),
@@ -80,7 +80,7 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
   }
 
   Stream<cpu> *s = ctx.get_stream<cpu>();
-  mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+  mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(s, 1,
       min_output_ptr, max_output_ptr, &min_data, &max_data, &min_weight, &max_weight);
 
   bool is_train = false;
91 changes: 46 additions & 45 deletions src/operator/quantization/mkldnn/mkldnn_requantize-inl.h
@@ -34,6 +34,7 @@
 namespace mxnet {
 namespace op {
 
+template <typename DstType>
 static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs,
                                        const OpContext& ctx,
                                        const std::vector<NDArray>& inputs,
@@ -45,7 +46,6 @@ static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs,
   using red::limits::MaxValue;
   using red::limits::MinValue;
   typedef int32_t SrcDType;
-  typedef int8_t DstDType;
   // check shapes
   size_t i_dim = inputs[0].shape().ndim();
   size_t o_dim = outputs[0].shape().ndim();
@@ -56,12 +56,21 @@ static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs,
                                   *inputs[2].data().dptr<float>());
   float first_scale = first_real_range / first_quantized_range;
   float second_real_range = real_range;
-  float second_quantized_range = MinAbs(MaxValue<DstDType>(),
-                                        MinValue<DstDType>());
+  float second_quantized_range = 0.f;
+  if (std::is_same<DstType, int8_t>::value) {
+    second_quantized_range = MinAbs(MaxValue<DstType>(), MinValue<DstType>());
+    *outputs[1].data().dptr<float>() = -second_real_range;
+    *outputs[2].data().dptr<float>() = second_real_range;
+  } else if (std::is_same<DstType, uint8_t>::value) {
+    second_quantized_range = MaxValue<DstType>();
+    *outputs[1].data().dptr<float>() = 0.f;
+    *outputs[2].data().dptr<float>() = second_real_range;
+  } else {
+    LOG(FATAL) << "Unsupported requantize output type";
+  }
   float second_scale = second_quantized_range / second_real_range;
   float scale = first_scale * second_scale;
-  *outputs[1].data().dptr<float>() = -second_real_range;
-  *outputs[2].data().dptr<float>() = second_real_range;
+
   primitive_attr attr;
   const int mask = 0;
   std::vector<float> scales = {scale};
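
The kernel above is now templated on the destination type, so the output range and the reorder scale depend on whether int8 or uint8 is requested. A minimal standalone sketch of that scale selection, with made-up ranges (illustration only; in the operator the input range comes from the int32 tensor's recorded min/max and `real_range`):

#include <cstdint>
#include <cstdio>
#include <type_traits>

// Sketch of the per-output-type scale computed by the templated kernel.
template <typename DstType>
float RequantizeScale(float in_real_range, float out_real_range) {
  const float first_quantized_range = 2147483647.f;  // |int32| range of the input
  const float first_scale = in_real_range / first_quantized_range;
  float second_quantized_range = 0.f;
  if (std::is_same<DstType, int8_t>::value) {
    second_quantized_range = 127.f;   // min(|INT8_MIN|, INT8_MAX): symmetric int8 output
  } else {
    second_quantized_range = 255.f;   // UINT8_MAX: non-negative uint8 output
  }
  const float second_scale = second_quantized_range / out_real_range;
  return first_scale * second_scale;  // single scale handed to the reorder primitive
}

int main() {
  // Made-up values: int32 data spanning +/-100, requantized output spanning +/-10.
  std::printf("int8  scale: %g\n", RequantizeScale<int8_t>(100.f, 10.f));
  std::printf("uint8 scale: %g\n", RequantizeScale<uint8_t>(100.f, 10.f));
  return 0;
}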
@@ -82,7 +91,7 @@ static void MKLDNNRequantizeForwardKer(const nnvm::NodeAttrs& attrs,
     i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
   }
   auto o_desc = mkldnn::memory::desc(i_dims,
-                                     (mkldnn::memory::data_type)data_type_enum<DstDType>::type,
+                                     (mkldnn::memory::data_type)data_type_enum<DstType>::type,
                                      i_fmt);
   auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
   auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
@@ -99,55 +108,47 @@ static void MKLDNNRequantizeForward(const nnvm::NodeAttrs& attrs,
                                     const std::vector<NDArray>& outputs) {
   using namespace mshadow;
   using namespace mxnet_op;
+  using red::limits::MaxValue;
+  using red::limits::MinValue;
   typedef int32_t SrcDType;
-  typedef int8_t DstDType;
   Stream<cpu> *s = ctx.get_stream<cpu>();
   const RequantizeParam& param = nnvm::get<RequantizeParam>(attrs.parsed);
   float real_range;
   // Model is calibrated
   if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
     real_range =
           MaxAbs(param.min_calib_range.value(), param.max_calib_range.value());
-    MKLDNNRequantizeForwardKer(attrs, ctx, inputs, req, outputs, real_range);
   // Model is not calibrated
   } else {
-    mxnet::TShape src_shape, dst_shape;
-    const size_t actual_float_size = sizeof(float);
-    const size_t actual_quantized_size = sizeof(SrcDType);
-    const size_t temp_reduce_size = ConfigReduce<cpu, SrcDType>(s,
-        inputs[0].shape(), mxnet::TShape(1, 1), &src_shape, &dst_shape);
-    Tensor<cpu, 1, char> temp_space =
-        ctx.requested[0].get_space_typed<cpu, 1, char>(
-        Shape1(2*actual_float_size+2*actual_quantized_size+temp_reduce_size), s);
-    Tensor<cpu, 1, float> actual_min_float(
-        reinterpret_cast<float*>(temp_space.dptr_), Shape1(1), s);
-    Tensor<cpu, 1, float> actual_max_float(
-        reinterpret_cast<float*>(temp_space.dptr_) + 1, Shape1(1), s);
-    const int dev_id = ctx.run_ctx.ctx.dev_id;
-    TBlob actual_min_quantized(reinterpret_cast<SrcDType*>(
-        temp_space.dptr_ + 8), Shape1(1), cpu::kDevMask, dev_id);
-    TBlob actual_max_quantized(reinterpret_cast<SrcDType*>(
-        temp_space.dptr_ + 8) + 1, Shape1(1), cpu::kDevMask, dev_id);
-    Tensor<cpu, 1, char> workspace(
-        temp_space.dptr_+2*actual_float_size+2*actual_quantized_size,
-        Shape1(temp_reduce_size), s);
-    broadcast::Reduce<red::minimum, 2, SrcDType, mshadow::op::identity>(
-        s, actual_min_quantized.reshape(dst_shape), kWriteTo,
-        workspace, inputs[0].Reorder2Default().data().reshape(src_shape));
-    Kernel<QuantizedToFloatStruct, cpu>::Launch(s, 1,
-        actual_min_float.dptr_, actual_min_quantized.dptr<SrcDType>(),
-        inputs[1].Reorder2Default().data().dptr<float>(),
-        inputs[2].Reorder2Default().data().dptr<float>());
-    broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
-        s, actual_max_quantized.reshape(dst_shape), kWriteTo,
-        workspace, inputs[0].Reorder2Default().data().reshape(src_shape));
-    Kernel<QuantizedToFloatStruct, cpu>::Launch(s, 1,
-        actual_max_float.dptr_, actual_max_quantized.dptr<SrcDType>(),
-        inputs[1].Reorder2Default().data().dptr<float>(),
-        inputs[2].Reorder2Default().data().dptr<float>());
-
-    real_range = MaxAbs(*actual_min_float.dptr_, *actual_max_float.dptr_);
-    MKLDNNRequantizeForwardKer(attrs, ctx, inputs, req, outputs, real_range);
+    NDArray in_buffer = inputs[0].Reorder2Default();
+    auto in_ptr = in_buffer.data().dptr<SrcDType>();
+    auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+    SrcDType data_min = MaxValue<SrcDType>();
+    SrcDType data_max = MinValue<SrcDType>();
+    std::vector<SrcDType> data_maxs(nthreads, data_max);
+    std::vector<SrcDType> data_mins(nthreads, data_min);
+#pragma omp parallel for num_threads(nthreads)
+    for (index_t i = 0; i < static_cast<index_t>(in_buffer.shape().Size()); i++) {
+      int tid = omp_get_thread_num();
+      if (in_ptr[i] > data_maxs[tid]) data_maxs[tid] = in_ptr[i];
+      if (in_ptr[i] < data_mins[tid]) data_mins[tid] = in_ptr[i];
+    }
+    for (index_t i = 0; i < nthreads; i++) {
+      if (data_maxs[i] > data_max) data_max = data_maxs[i];
+      if (data_mins[i] < data_min) data_min = data_mins[i];
+    }
+    float src_range = MinAbs(MinValue<SrcDType>(), MaxValue<SrcDType>());
+    SrcDType data_range = MaxAbs(data_min, data_max);
+    float data_scale = MaxAbs(*inputs[1].data().dptr<float>(), *inputs[2].data().dptr<float>());
+    real_range = data_range * data_scale / src_range;
   }
+  auto out_type = GetQuantizeOutputType(param);
+  if (out_type == mshadow::kUint8) {
+    MKLDNNRequantizeForwardKer<uint8_t>(attrs, ctx, inputs, req, outputs, real_range);
+  } else if (out_type == mshadow::kInt8) {
+    MKLDNNRequantizeForwardKer<int8_t>(attrs, ctx, inputs, req, outputs, real_range);
+  } else {
+    LOG(FATAL) << "mkldnn requantize op only supports int8 and uint8 as output type";
+  }
 }
 
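
In the uncalibrated path, the reduce-kernel plus temporary-workspace machinery is replaced by a plain per-thread min/max scan over the int32 input. A self-contained sketch of that reduction pattern (illustrative only; a plain std::vector stands in for MXNet's NDArray, and the data/scale values are made up):

#include <omp.h>
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Per-thread min/max followed by a serial merge over the thread slots --
// the same pattern the rewritten MKLDNNRequantizeForward uses to find the
// data range without allocating a reduce workspace.
int main() {
  std::vector<int32_t> data(1 << 20);
  for (size_t i = 0; i < data.size(); ++i)
    data[i] = static_cast<int32_t>(i % 50000) - 25000;  // made-up input values

  const int nthreads = omp_get_max_threads();
  std::vector<int32_t> mins(nthreads, INT32_MAX);
  std::vector<int32_t> maxs(nthreads, INT32_MIN);

  #pragma omp parallel for num_threads(nthreads)
  for (int64_t i = 0; i < static_cast<int64_t>(data.size()); ++i) {
    const int tid = omp_get_thread_num();
    mins[tid] = std::min(mins[tid], data[i]);
    maxs[tid] = std::max(maxs[tid], data[i]);
  }

  int32_t data_min = INT32_MAX;
  int32_t data_max = INT32_MIN;
  for (int t = 0; t < nthreads; ++t) {
    data_min = std::min(data_min, mins[t]);
    data_max = std::max(data_max, maxs[t]);
  }

  // Scale the observed int32 range back into real values, as the op does:
  // real_range = data_range * data_scale / src_range.
  const float src_range = 2147483647.f;  // min(|INT32_MIN|, INT32_MAX)
  const float data_scale = 100.f;        // assumed max(|min_in|, |max_in|) of the input
  const int32_t abs_min = data_min < 0 ? -data_min : data_min;
  const int32_t abs_max = data_max < 0 ? -data_max : data_max;
  const float data_range = static_cast<float>(std::max(abs_min, abs_max));
  const float real_range = data_range * data_scale / src_range;
  std::printf("min=%d max=%d real_range=%g\n", data_min, data_max, real_range);
  return 0;
}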
78 changes: 53 additions & 25 deletions src/operator/quantization/quantization_utils.h
@@ -127,39 +127,31 @@ MSHADOW_XINLINE void RequantizeManyInNewRange(size_t count, T2* output, const T1
  * \brief Get the scaling factor for converting type T to float.
  */
 template<typename T>
-MSHADOW_XINLINE float FloatForOneQuantizedLevel(float range_min, float range_max) {
+MSHADOW_XINLINE float FloatForOneQuantizedLevel(float range_min, float range_max, bool all_sign) {
   using mshadow::red::limits::MinValue;
   using mshadow::red::limits::MaxValue;
-  const int64_t highest = static_cast<int64_t>(MaxValue<T>());
-  const int64_t lowest = static_cast<int64_t>(MinValue<T>());
-  const float float_for_one_quantized_level =
-    (range_max - range_min) / (highest - lowest);
-  return float_for_one_quantized_level;
+  float range_data = MaxAbs(range_min, range_max);
+  float range_T = all_sign ? MinAbs(MinValue<T>(), MaxValue<T>()) : MaxValue<T>();
+  return range_data / range_T;
 }
 
 template <typename TA, typename TB, typename TC>
-MSHADOW_XINLINE void QuantizationRangeForMultiplication(float min_a, float max_a,
-                                                        float min_b, float max_b,
-                                                        float* min_c, float* max_c) {
-  using mshadow::red::limits::MinValue;
+MSHADOW_XINLINE void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b,
                                                         float max_b, float *min_c, float *max_c,
                                                         bool all_sign) {
   using mshadow::red::limits::MaxValue;
-  const float a_float_for_one_quant_level =
-    FloatForOneQuantizedLevel<TA>(min_a, max_a);
-  const float b_float_for_one_quant_level =
-    FloatForOneQuantizedLevel<TB>(min_b, max_b);
-
-  const int64_t c_highest =
-    static_cast<int64_t>(MaxValue<TC>());
-  const int64_t c_lowest =
-    static_cast<int64_t>(MinValue<TC>());
+  using mshadow::red::limits::MinValue;
+  const float a_float_for_one_quant_level = FloatForOneQuantizedLevel<TA>(min_a, max_a, all_sign);
+  const float b_float_for_one_quant_level = FloatForOneQuantizedLevel<TB>(min_b, max_b, all_sign);
+  const float range_c =
+      MinAbs(static_cast<int64_t>(MinValue<TC>()), static_cast<int64_t>(MaxValue<TC>()));
   const float c_float_for_one_quant_level =
-    a_float_for_one_quant_level * b_float_for_one_quant_level;
-
-  *min_c = c_float_for_one_quant_level * c_lowest;
-  *max_c = c_float_for_one_quant_level * c_highest;
+      a_float_for_one_quant_level * b_float_for_one_quant_level;
+  *max_c = c_float_for_one_quant_level * range_c;
+  *min_c = -*max_c;
 }
 
-struct QuantizationRangeForMultiplicationStruct {
+struct QuantizationRangeForS8S8MultiplicationStruct {
   MSHADOW_XINLINE static void Map(int i,
                                   float *min_c,
                                   float *max_c,
@@ -168,7 +160,20 @@ struct QuantizationRangeForMultiplicationStruct {
                                   const float *min_b,
                                   const float *max_b) {
     QuantizationRangeForMultiplication<int8_t, int8_t, int32_t>(
-      min_a[i], max_a[i], min_b[i], max_b[i], min_c, max_c);
+      min_a[i], max_a[i], min_b[i], max_b[i], min_c, max_c, true);
   }
 };
 
+struct QuantizationRangeForS8U8MultiplicationStruct {
+  MSHADOW_XINLINE static void Map(int i,
+                                  float *min_c,
+                                  float *max_c,
+                                  const float *min_a,
+                                  const float *max_a,
+                                  const float *min_b,
+                                  const float *max_b) {
+    QuantizationRangeForMultiplication<int8_t, uint8_t, int32_t>(
+      min_a[i], max_a[i], min_b[i], max_b[i], min_c, max_c, false);
+  }
+};
+
@@ -186,6 +191,29 @@ inline size_t ConfigReduce(mshadow::Stream<xpu>* s,
   return broadcast::ReduceWorkspaceSize<NDim, DType>(s, *dst_shape, kWriteTo, *src_shape);
 }
 
+enum QuantizeOutType { kAuto = 0, kInt8, kUint8 };
+
+template<typename Param>
+static mshadow::TypeFlag GetQuantizeOutputType(const Param &param) {
+  auto out_type = mshadow::kInt8;
+  if (param.out_type == QuantizeOutType::kAuto) {
+    if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+      if (param.min_calib_range.value() >= 0.0) {
+        out_type = mshadow::kUint8;
+      } else {
+        out_type = mshadow::kInt8;
+      }
+    }
+  } else if (param.out_type == QuantizeOutType::kInt8) {
+    out_type = mshadow::kInt8;
+  } else if (param.out_type == QuantizeOutType::kUint8) {
+    out_type = mshadow::kUint8;
+  } else {
+    LOG(FATAL) << "Unsupported out_type in params: " <<param.out_type;
+  }
+  return out_type;
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_QUANTIZATION_QUANTIZATION_UTILS_H_
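
For reference, the new FloatForOneQuantizedLevel / QuantizationRangeForMultiplication math introduced in this file can be restated as a standalone sketch with one made-up s8 x s8 -> s32 case (illustration only, plain C++ in place of the mshadow helpers):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Real value per quantized level: max(|min|, |max|) of the real range divided
// by the magnitude of the integer range (127 for signed int8, 255 for uint8).
template <typename T>
float FloatForOneLevel(float range_min, float range_max, bool all_sign) {
  const float range_data = std::max(std::fabs(range_min), std::fabs(range_max));
  const float lowest  = std::fabs(static_cast<float>(std::numeric_limits<T>::lowest()));
  const float highest = static_cast<float>(std::numeric_limits<T>::max());
  const float range_t = all_sign ? std::min(lowest, highest) : highest;
  return range_data / range_t;
}

// Output range of an s8 x s8 -> s32 product, matching what the renamed
// QuantizationRangeForS8S8MultiplicationStruct computes (all_sign = true).
void RangeForS8S8(float min_a, float max_a, float min_b, float max_b,
                  float* min_c, float* max_c) {
  const float a_level = FloatForOneLevel<int8_t>(min_a, max_a, true);
  const float b_level = FloatForOneLevel<int8_t>(min_b, max_b, true);
  const float range_c = 2147483647.f;  // min(|INT32_MIN|, INT32_MAX)
  *max_c = a_level * b_level * range_c;
  *min_c = -*max_c;                    // the output range is forced to be symmetric
}

int main() {
  float min_c = 0.f, max_c = 0.f;
  // Made-up calibration: data in [-0.5, 0.5], weights in [-0.2, 0.2].
  RangeForS8S8(-0.5f, 0.5f, -0.2f, 0.2f, &min_c, &max_c);
  std::printf("min_c=%g max_c=%g\n", min_c, max_c);  // roughly +/-13314
  return 0;
}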
3 changes: 2 additions & 1 deletion src/operator/quantization/quantize_graph_pass.cc
@@ -248,6 +248,7 @@ Graph QuantizeGraph(Graph &&src) {
         NodePtr requantize_node = Node::Create();
         requantize_node->attrs.op = Op::Get("_contrib_requantize");
         requantize_node->attrs.name = "requantize_" + node->attrs.name;
+        requantize_node->attrs.dict["out_type"] = quantized_dtype;
         if (requantize_node->op()->attr_parser != nullptr) {
           requantize_node->op()->attr_parser(&(requantize_node->attrs));
         }
@@ -398,7 +399,7 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) {
       node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second);
       node->op()->attr_parser(&(node->attrs));
       const QuantizeV2Param& param = nnvm::get<QuantizeV2Param>(node->attrs.parsed);
-      if (param.out_type == QuantizeV2Param::OutType::kUint8 &&
+      if (param.out_type == QuantizeOutType::kUint8 &&
           param.min_calib_range.value() < 0.0f) {
         LOG(WARNING) << "Calibration statistics indicates that node `" << node->attrs.name
                      << "` has negative input, consider use `auto` or `int8` as out_type";
33 changes: 6 additions & 27 deletions src/operator/quantization/quantize_v2-inl.h
@@ -38,16 +38,15 @@ namespace mxnet {
 namespace op {
 
 struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
-  enum OutType { kAuto = 0, kInt8, kUint8 };
   int out_type;
   dmlc::optional<float> min_calib_range;
   dmlc::optional<float> max_calib_range;
   DMLC_DECLARE_PARAMETER(QuantizeV2Param) {
     DMLC_DECLARE_FIELD(out_type)
-      .add_enum("auto", kAuto)
-      .add_enum("int8", kInt8)
-      .add_enum("uint8", kUint8)
-      .set_default(kInt8)
+      .add_enum("auto", QuantizeOutType::kAuto)
+      .add_enum("int8", QuantizeOutType::kInt8)
+      .add_enum("uint8", QuantizeOutType::kUint8)
+      .set_default(QuantizeOutType::kInt8)
       .describe("Output data type. `auto` can be specified to automatically determine output type "
                 "according to min_calib_range.");
     DMLC_DECLARE_FIELD(min_calib_range)
@@ -61,26 +60,6 @@ struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
   }
 };
 
-static mshadow::TypeFlag GetOutputType(const QuantizeV2Param &param) {
-  auto out_type = mshadow::kInt8;
-  if (param.out_type == QuantizeV2Param::OutType::kAuto) {
-    if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
-      if (param.min_calib_range.value() >= 0.0) {
-        out_type = mshadow::kUint8;
-      } else {
-        out_type = mshadow::kInt8;
-      }
-    }
-  } else if (param.out_type == QuantizeV2Param::OutType::kInt8) {
-    out_type = mshadow::kInt8;
-  } else if (param.out_type == QuantizeV2Param::OutType::kUint8) {
-    out_type = mshadow::kUint8;
-  } else {
-    LOG(FATAL) << "Unsupported out_type in params: " <<param.out_type;
-  }
-  return out_type;
-}
-
 // quantize float to uint8_t
 struct quantize_v2_unsigned {
   template <typename DstDType, typename SrcDType>
@@ -143,7 +122,7 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector<int>
   const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
   CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 ||
         in_attrs->at(0) == mshadow::kInt8);
-  auto out_type = GetOutputType(param);
+  auto out_type = GetQuantizeOutputType(param);
   if (out_type == mshadow::kUint8) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
   } else if (out_type == mshadow::kInt8) {
@@ -170,7 +149,7 @@ class QuantizeV2Operator {
     using mshadow::red::limits::MinValue;
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs_.parsed);
-    auto out_type = GetOutputType(param);
+    auto out_type = GetQuantizeOutputType(param);
     if (out_type == mshadow::kUint8 && std::is_same<xpu, gpu>::value) {
       LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, "
                     "please switch to the context of CPU or int8 data type for GPU.";
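
GetOutputType is removed from this header and all callers now go through the shared GetQuantizeOutputType, so the selection rule itself is unchanged. A small stand-in mirror of the `auto` rule, for reference (illustrative only; the real helper reads the DMLC parameter struct and LOG(FATAL)s on an unknown value):

#include <cstdio>

enum OutType { kAuto = 0, kInt8, kUint8 };  // mirrors QuantizeOutType

// `auto` picks uint8 only when calibration says the input never goes negative;
// otherwise int8. Explicit int8/uint8 requests are passed through.
const char* ResolveOutType(int requested, bool calibrated, float min_calib) {
  if (requested == kAuto) {
    if (calibrated && min_calib >= 0.0f) return "uint8";
    return "int8";
  }
  return requested == kUint8 ? "uint8" : "int8";
}

int main() {
  std::printf("%s\n", ResolveOutType(kAuto, true, 0.0f));    // uint8
  std::printf("%s\n", ResolveOutType(kAuto, true, -1.2f));   // int8
  std::printf("%s\n", ResolveOutType(kAuto, false, 0.0f));   // int8: no calibration info
  std::printf("%s\n", ResolveOutType(kUint8, false, 0.0f));  // uint8, as requested
  return 0;
}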
2 changes: 1 addition & 1 deletion src/operator/quantization/quantized_conv.cu
@@ -174,7 +174,7 @@ class QuantizedCuDNNConvOp {
     // of in_data[0] and in_data[1]. Need to rescale the min/max range of out_data
     // based on the min/max ranges of in_data[0] and in_data[1].
     const size_t num_inputs = param_.no_bias ? 2 : 3;
-    mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, gpu>::Launch(s, 1,
+    mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, gpu>::Launch(s, 1,
       out_data[1].dptr<float>(), out_data[2].dptr<float>(),
       in_data[num_inputs].dptr<float>(), in_data[num_inputs+1].dptr<float>(),
       in_data[num_inputs+2].dptr<float>(), in_data[num_inputs+3].dptr<float>());
2 changes: 1 addition & 1 deletion src/operator/quantization/quantized_fully_connected.cc
@@ -233,7 +233,7 @@ void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs,
   Tensor<cpu, 1, float> max_weight =
     in_data[num_inputs + quantized_fullc::kWeightMax].get<cpu, 1, float>(s);
 
-  Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1, min_output.dptr_,
+  Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(s, 1, min_output.dptr_,
     max_output.dptr_, min_data.dptr_, max_data.dptr_, min_weight.dptr_, max_weight.dptr_);
   if (!param.no_bias) {
     Tensor<cpu, 1, int8_t> bias = in_data[fullc::kBias].get_with_shape<cpu, 1, int8_t>(
2 changes: 1 addition & 1 deletion src/operator/quantization/quantized_fully_connected.cu
@@ -109,7 +109,7 @@ void QuantizedFullyConnectedForwardGPU(const nnvm::NodeAttrs& attrs,
                            cmp_type,
                            CUBLAS_GEMM_DFALT));
 
-  Kernel<QuantizationRangeForMultiplicationStruct, gpu>::Launch(s, 1,
+  Kernel<QuantizationRangeForS8S8MultiplicationStruct, gpu>::Launch(s, 1,
     outputs[1].dptr<float>(), outputs[2].dptr<float>(),
     inputs[num_inputs].dptr<float>(), inputs[num_inputs+1].dptr<float>(),
     inputs[num_inputs+2].dptr<float>(), inputs[num_inputs+3].dptr<float>());