Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MKLDNN]Improve quantizeV2 and dequantize latency #14641

Merged
merged 9 commits into from
Apr 18, 2019
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 38 additions & 24 deletions src/operator/quantization/dequantize-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,30 +68,6 @@ struct dequantize_zero_centered {
}
};

template<typename xpu>
void DequantizeCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
using mshadow::red::limits::MinValue;
using mshadow::red::limits::MaxValue;
Stream<xpu> *s = ctx.get_stream<xpu>();
if (inputs[0].type_flag_ == mshadow::kUint8) {
Kernel<dequantize_unsigned, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<float>(),
inputs[0].dptr<uint8_t>(), inputs[1].dptr<float>(), inputs[2].dptr<float>(),
MinValue<uint8_t>(), MaxValue<uint8_t>());
} else if (inputs[0].type_flag_ == mshadow::kInt8) {
Kernel<dequantize_zero_centered, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<float>(),
inputs[0].dptr<int8_t>(), inputs[1].dptr<float>(), inputs[2].dptr<float>(),
MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
} else {
LOG(FATAL) << "dequantize op only supports input type int8 or uint8";
}
}

inline bool DequantizeShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
Expand Down Expand Up @@ -119,6 +95,44 @@ inline bool DequantizeType(const nnvm::NodeAttrs& attrs,
return (*in_attrs)[0] != -1;
}

template <typename xpu>
class DequantizeOperator {
public:
explicit DequantizeOperator(const nnvm::NodeAttrs &attrs) : attrs_(attrs) {}
void Forward(const OpContext &ctx, const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req, const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mxnet_op;
using mshadow::red::limits::MaxValue;
using mshadow::red::limits::MinValue;
Stream<xpu> *s = ctx.get_stream<xpu>();
if (inputs[0].type_flag_ == mshadow::kUint8) {
Kernel<dequantize_unsigned, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<float>(),
inputs[0].dptr<uint8_t>(), inputs[1].dptr<float>(),
inputs[2].dptr<float>(), MinValue<uint8_t>(),
MaxValue<uint8_t>());
} else if (inputs[0].type_flag_ == mshadow::kInt8) {
Kernel<dequantize_zero_centered, xpu>::Launch(
s, outputs[0].Size(), outputs[0].dptr<float>(), inputs[0].dptr<int8_t>(),
inputs[1].dptr<float>(), inputs[2].dptr<float>(),
MinAbs(MaxValue<int8_t>(), MinValue<int8_t>()));
} else {
LOG(FATAL) << "dequantize op only supports input type int8 or uint8";
}
}

private:
nnvm::NodeAttrs attrs_;
};

template <typename xpu>
static void DequantizeForward(const OpStatePtr &state_ptr, const OpContext &ctx,
const std::vector<TBlob> &inputs, const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
auto &op = state_ptr.get_state<DequantizeOperator<xpu>>();
op.Forward(ctx, inputs, req, outputs);
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_QUANTIZATION_DEQUANTIZE_INL_H_
21 changes: 19 additions & 2 deletions src/operator/quantization/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ bool DequantizeStorageType(const nnvm::NodeAttrs& attrs,
return true;
}

static OpStatePtr CreateDequantizeState(const nnvm::NodeAttrs &attrs, Context ctx,
const std::vector<TShape> &in_shapes,
const std::vector<int> &in_types) {
OpStatePtr state;
if (ctx.dev_type == kGPU) {
state = OpStatePtr::Create<DequantizeOperator<gpu>>(attrs);
} else {
#if MXNET_USE_MKLDNN == 1
state = OpStatePtr::Create<SgMKLDNNDequantizeOperator>(attrs);
#else
state = OpStatePtr::Create<DequantizeOperator<cpu>>(attrs);
#endif
}
return state;
}

NNVM_REGISTER_OP(_contrib_dequantize)
.describe(R"code(Dequantize the input tensor into a float tensor.
min_range and max_range are scalar floats that specify the range for
Expand All @@ -74,11 +90,12 @@ by keep zero centered for the quantized value:
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FCreateOpState>("FCreateOpState", CreateDequantizeState)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNDequantizeCompute)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNDequantizeForward)
#endif
.set_attr<FCompute>("FCompute<cpu>", DequantizeCompute<cpu>)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", DequantizeForward<cpu>)
.add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `uint8`")
.add_argument("min_range", "NDArray-or-Symbol", "The minimum scalar value "
"possibly produced for the input in float32")
Expand Down
2 changes: 1 addition & 1 deletion src/operator/quantization/dequantize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_contrib_dequantize)
.set_attr<FCompute>("FCompute<gpu>", DequantizeCompute<gpu>);
.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DequantizeForward<gpu>);

} // namespace op
} // namespace mxnet
140 changes: 81 additions & 59 deletions src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,82 +26,104 @@
#ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
#define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
#if MXNET_USE_MKLDNN == 1
#include <string>
#include <algorithm>
#include <string>
#include <vector>
#include "../../nn/mkldnn/mkldnn_base-inl.h"

namespace mxnet {
namespace op {

template<typename SrcType, typename DstType>
static void MKLDNNDequantizeComputeKer(const std::vector<NDArray> &inputs,
const std::vector<NDArray> &outputs,
const std::vector<OpReqType> &req) {
using namespace mshadow;
using namespace mxnet_op;
using red::limits::MaxValue;
using red::limits::MinValue;
float real_range = 0.0;
float quantized_range = 0.0;
if (inputs[0].dtype() == mshadow::kUint8) {
quantized_range = MaxAbs(MaxValue<SrcType>(), MinValue<SrcType>());
real_range = MaxAbs(*inputs[1].data().dptr<DstType>(), *inputs[2].data().dptr<DstType>());
} else if (inputs[0].dtype() == mshadow::kInt8) {
quantized_range = MinAbs(MaxValue<SrcType>(), MinValue<SrcType>());
real_range = MaxAbs(*inputs[1].data().dptr<DstType>(), *inputs[2].data().dptr<DstType>());
} else {
LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type";
}
float scale = real_range / quantized_range;
primitive_attr attr;
const int mask = 0;
std::vector<float> scales = {scale};
attr.set_output_scales(mask, scales);
attr.set_int_output_round_mode(round_nearest);
mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();

NDArray in_buffer = inputs[0];
if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
in_buffer = inputs[0].Reorder2Default();
class SgMKLDNNDequantizeOperator {
public:
explicit SgMKLDNNDequantizeOperator(const nnvm::NodeAttrs &attrs)
: param_(nnvm::get<DequantizeParam>(attrs.parsed)) {}

void Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req, const std::vector<NDArray> &outputs);

private:
bool initalized_{false};
ZhennanQin marked this conversation as resolved.
Show resolved Hide resolved
DequantizeParam param_;
float cached_data_min_{0.f};
float cached_data_max_{0.f};
std::shared_ptr<mkldnn::memory> i_mem_;
std::shared_ptr<mkldnn::memory> o_mem_;
std::shared_ptr<mkldnn::reorder> fwd_pd_;
};

void SgMKLDNNDequantizeOperator::Forward(const OpContext &ctx, const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
NDArray in_buffer = inputs[0];
if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default();
auto i_mem = in_buffer.GetMKLDNNData();
auto i_mpd = i_mem->get_primitive_desc();
auto i_desc = i_mpd.desc();
size_t i_ndim = in_buffer.shape().ndim();
mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
for (size_t i = 0; i < i_ndim; i++) {
i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
}
mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
if (i_fmt == mkldnn::memory::format::nhwc) {
// For 4d tensor, nchw is the default format
i_fmt = mkldnn::memory::format::nchw;
float data_min = *inputs[1].data().dptr<float>();
float data_max = *inputs[2].data().dptr<float>();

if (initalized_ && (cached_data_min_ != data_min || cached_data_max_ != data_max))
ZhennanQin marked this conversation as resolved.
Show resolved Hide resolved
initalized_ = false;

if (!initalized_) {
cached_data_min_ = data_min;
cached_data_max_ = data_max;
float real_range = MaxAbs(cached_data_min_, cached_data_max_);
float quantized_range = 0.0;
if (inputs[0].dtype() == mshadow::kUint8) {
quantized_range = kUint8Range;
} else if (inputs[0].dtype() == mshadow::kInt8) {
quantized_range = kInt8Range;
real_range = MaxAbs(*inputs[1].data().dptr<float>(), *inputs[2].data().dptr<float>());
} else {
LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type";
}
float scale = real_range / quantized_range;
primitive_attr attr;
const int mask = 0;
std::vector<float> scales = {scale};
attr.set_output_scales(mask, scales);
attr.set_int_output_round_mode(round_nearest);
mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
auto i_mpd = i_mem->get_primitive_desc();
auto i_desc = i_mpd.desc();
size_t i_ndim = in_buffer.shape().ndim();
mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
for (size_t i = 0; i < i_ndim; i++) {
i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
}
mkldnn::memory::format o_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
if (o_fmt == mkldnn::memory::format::nhwc) {
// For 4d tensor, nchw is the default format
o_fmt = mkldnn::memory::format::nchw;
}
auto o_desc =
mkldnn::memory::desc(i_dims, (mkldnn::memory::data_type)data_type_enum<float>::type, o_fmt);
auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
i_mem_ = std::make_shared<mkldnn::memory>(i_mpd, nullptr);
o_mem_ = std::make_shared<mkldnn::memory>(o_mpd, nullptr);
fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *i_mem_, *o_mem_);
initalized_ = true;
}
auto o_desc = mkldnn::memory::desc(i_dims,
(mkldnn::memory::data_type)data_type_enum<DstType>::type,
i_fmt);
auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
auto o_mem = CreateMKLDNNMem(outputs[0], o_mem_->get_primitive_desc(), req[0]);
i_mem_->set_data_handle(i_mem->get_data_handle());
o_mem_->set_data_handle(o_mem.second->get_data_handle());
MKLDNNStream::Get()->RegisterPrim(*fwd_pd_);
CommitOutput(outputs[0], o_mem);
MKLDNNStream::Get()->Submit();
}

static void MKLDNNDequantizeCompute(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
if (inputs[0].dtype() == mshadow::kUint8) {
MKLDNNDequantizeComputeKer<uint8_t, float>(inputs, outputs, req);
} else if (inputs[0].dtype() == mshadow::kInt8) {
MKLDNNDequantizeComputeKer<int8_t, float>(inputs, outputs, req);
} else {
LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as input type";
}
static void SgMKLDNNDequantizeForward(const OpStatePtr &state_ptr, const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
SgMKLDNNDequantizeOperator &op = state_ptr.get_state<SgMKLDNNDequantizeOperator>();
op.Forward(ctx, inputs, req, outputs);
}



} // namespace op
} // namespace mxnet

Expand Down
Loading