From 04a7589f60d4c48a17998e82f067169f9dd5d9f0 Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Mon, 8 Apr 2019 12:54:28 +0800 Subject: [PATCH] stateful_quantize --- src/operator/quantization/dequantize-inl.h | 74 ++++-- src/operator/quantization/dequantize.cc | 7 +- src/operator/quantization/dequantize.cu | 2 +- .../mkldnn/mkldnn_dequantize-inl.h | 143 +++++++----- .../mkldnn/mkldnn_quantize_v2-inl.h | 212 ++++++++++-------- src/operator/quantization/quantize_v2-inl.h | 197 +++++++++------- src/operator/quantization/quantize_v2.cc | 7 +- src/operator/quantization/quantize_v2.cu | 2 +- 8 files changed, 376 insertions(+), 268 deletions(-) diff --git a/src/operator/quantization/dequantize-inl.h b/src/operator/quantization/dequantize-inl.h index dcda5a8b4bef..86dbeb13abec 100644 --- a/src/operator/quantization/dequantize-inl.h +++ b/src/operator/quantization/dequantize-inl.h @@ -68,30 +68,6 @@ struct dequantize_zero_centered { } }; -template -void DequantizeCompute(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mxnet_op; - using mshadow::red::limits::MinValue; - using mshadow::red::limits::MaxValue; - Stream *s = ctx.get_stream(); - if (inputs[0].type_flag_ == mshadow::kUint8) { - Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), - MinValue(), MaxValue()); - } else if (inputs[0].type_flag_ == mshadow::kInt8) { - Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), - MinAbs(MaxValue(), MinValue())); - } else { - LOG(FATAL) << "dequantize op only supports input type int8 or uint8"; - } -} - inline bool DequantizeShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs) { @@ -119,6 +95,56 @@ inline bool DequantizeType(const nnvm::NodeAttrs& attrs, return (*in_attrs)[0] != -1; } +template +class DequantizeOperator { + public: + DequantizeOperator(const nnvm::NodeAttrs &attrs) : attrs_(attrs) {} + void Forward(const OpContext &ctx, const std::vector &inputs, + const std::vector &req, const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + Stream *s = ctx.get_stream(); + if (inputs[0].type_flag_ == mshadow::kUint8) { + Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr(), + inputs[2].dptr(), MinValue(), + MaxValue()); + } else if (inputs[0].type_flag_ == mshadow::kInt8) { + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), + inputs[1].dptr(), inputs[2].dptr(), + MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "dequantize op only supports input type int8 or uint8"; + } + } + + private: + nnvm::NodeAttrs attrs_; +}; + +static OpStatePtr CreateDequantizeState(const nnvm::NodeAttrs &attrs, Context ctx, + const std::vector &in_shapes, + const std::vector &in_types) { + OpStatePtr state; + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(attrs); + } else { + state = OpStatePtr::Create>(attrs); + } + return state; +} + +template +static void DequantizeForward(const OpStatePtr &state_ptr, const OpContext &ctx, + const std::vector &inputs, const std::vector &req, + const std::vector &outputs) { + auto &op = state_ptr.get_state>(); + op.Forward(ctx, inputs, req, outputs); +} + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_QUANTIZATION_DEQUANTIZE_INL_H_ diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc index 7c84673095f0..ed3bbdcbb845 100644 --- a/src/operator/quantization/dequantize.cc +++ b/src/operator/quantization/dequantize.cc @@ -76,9 +76,12 @@ by keep zero centered for the quantized value: .set_attr("FGradient", MakeZeroGradNodes) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) -.set_attr("FComputeEx", MKLDNNDequantizeCompute) +.set_attr("FCreateOpState", CreateSgMKLDNNDequantizeState) +.set_attr("FStatefulComputeEx", SgMKLDNNDequantizeForward) +#else +.set_attr("FCreateOpState", CreateDequantizeState) +.set_attr("FStatefulCompute", DequantizeForward) #endif -.set_attr("FCompute", DequantizeCompute) .add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `uint8`") .add_argument("min_range", "NDArray-or-Symbol", "The minimum scalar value " "possibly produced for the input in float32") diff --git a/src/operator/quantization/dequantize.cu b/src/operator/quantization/dequantize.cu index ca5f91c5def9..41b6e7d20494 100644 --- a/src/operator/quantization/dequantize.cu +++ b/src/operator/quantization/dequantize.cu @@ -28,7 +28,7 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_contrib_dequantize) -.set_attr("FCompute", DequantizeCompute); +.set_attr("FStatefulCompute", DequantizeForward); } // namespace op } // namespace mxnet diff --git a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h index b66adf787fef..c142d3832004 100644 --- a/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h @@ -26,80 +26,105 @@ #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_ #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_ #if MXNET_USE_MKLDNN == 1 -#include #include +#include #include #include "../../nn/mkldnn/mkldnn_base-inl.h" namespace mxnet { namespace op { -template -static void MKLDNNDequantizeComputeKer(const std::vector &inputs, - const std::vector &outputs, - const std::vector &req) { - using namespace mshadow; - using namespace mxnet_op; - using red::limits::MaxValue; - using red::limits::MinValue; - float real_range = 0.0; - float quantized_range = 0.0; - if (inputs[0].dtype() == mshadow::kUint8) { - quantized_range = MaxAbs(MaxValue(), MinValue()); - real_range = MaxAbs(*inputs[1].data().dptr(), *inputs[2].data().dptr()); - } else if (inputs[0].dtype() == mshadow::kInt8) { - quantized_range = MinAbs(MaxValue(), MinValue()); - real_range = MaxAbs(*inputs[1].data().dptr(), *inputs[2].data().dptr()); - } else { - LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type"; - } - float scale = real_range / quantized_range; - primitive_attr attr; - const int mask = 0; - std::vector scales = {scale}; - attr.set_output_scales(mask, scales); - attr.set_int_output_round_mode(round_nearest); - mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); +class SgMKLDNNDequantizeOperator : public DequantizeOperator { + public: + explicit SgMKLDNNDequantizeOperator(const nnvm::NodeAttrs &attrs) + : DequantizeOperator(attrs), param_(nnvm::get(attrs.parsed)) {} - NDArray in_buffer = inputs[0]; - if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) - in_buffer = inputs[0].Reorder2Default(); + void Forward(const OpContext &ctx, const std::vector &inputs, + const std::vector &req, const std::vector &outputs); + + private: + bool initalized_{false}; + DequantizeParam param_; + float cached_data_min_{0.f}; + float cached_data_max_{0.f}; + std::shared_ptr i_mem_; + std::shared_ptr o_mem_; + std::shared_ptr fwd_pd_; +}; +void SgMKLDNNDequantizeOperator::Forward(const OpContext &ctx, const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + NDArray in_buffer = inputs[0]; + if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default(); auto i_mem = in_buffer.GetMKLDNNData(); - auto i_mpd = i_mem->get_primitive_desc(); - auto i_desc = i_mpd.desc(); - size_t i_ndim = in_buffer.shape().ndim(); - mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); - for (size_t i = 0; i < i_ndim; i++) { - i_dims[i] = static_cast(in_buffer.shape()[i]); - } - mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); - if (i_fmt == mkldnn::memory::format::nhwc) { - // For 4d tensor, nchw is the default format - i_fmt = mkldnn::memory::format::nchw; + float data_min = *inputs[1].data().dptr(); + float data_max = *inputs[2].data().dptr(); + + if (initalized_ && (cached_data_min_ != data_min || cached_data_max_ != data_max)) + initalized_ = false; + + if (!initalized_) { + cached_data_min_ = data_min; + cached_data_max_ = data_max; + float real_range = MaxAbs(cached_data_min_, cached_data_max_); + float quantized_range = 0.0; + if (inputs[0].dtype() == mshadow::kUint8) { + quantized_range = kUint8Range; + } else if (inputs[0].dtype() == mshadow::kInt8) { + quantized_range = kInt8Range; + real_range = MaxAbs(*inputs[1].data().dptr(), *inputs[2].data().dptr()); + } else { + LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type"; + } + float scale = real_range / quantized_range; + primitive_attr attr; + const int mask = 0; + std::vector scales = {scale}; + attr.set_output_scales(mask, scales); + attr.set_int_output_round_mode(round_nearest); + mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); + auto i_mpd = i_mem->get_primitive_desc(); + auto i_desc = i_mpd.desc(); + size_t i_ndim = in_buffer.shape().ndim(); + mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); + for (size_t i = 0; i < i_ndim; i++) { + i_dims[i] = static_cast(in_buffer.shape()[i]); + } + mkldnn::memory::format o_fmt = static_cast(i_desc.data.format); + if (o_fmt == mkldnn::memory::format::nhwc) { + // For 4d tensor, nchw is the default format + o_fmt = mkldnn::memory::format::nchw; + } + auto o_desc = + mkldnn::memory::desc(i_dims, (mkldnn::memory::data_type)data_type_enum::type, o_fmt); + auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); + auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); + i_mem_ = std::make_shared(i_mpd, nullptr); + o_mem_ = std::make_shared(o_mpd, nullptr); + fwd_pd_ = std::make_shared(reorder_pd, *i_mem_, *o_mem_); + initalized_ = true; } - auto o_desc = mkldnn::memory::desc(i_dims, - (mkldnn::memory::data_type)data_type_enum::type, - i_fmt); - auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); - auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); - auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second)); + auto o_mem = CreateMKLDNNMem(outputs[0], o_mem_->get_primitive_desc(), req[0]); + i_mem_->set_data_handle(i_mem->get_data_handle()); + o_mem_->set_data_handle(o_mem.second->get_data_handle()); + MKLDNNStream::Get()->RegisterPrim(*fwd_pd_); CommitOutput(outputs[0], o_mem); MKLDNNStream::Get()->Submit(); } -static void MKLDNNDequantizeCompute(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - if (inputs[0].dtype() == mshadow::kUint8) { - MKLDNNDequantizeComputeKer(inputs, outputs, req); - } else if (inputs[0].dtype() == mshadow::kInt8) { - MKLDNNDequantizeComputeKer(inputs, outputs, req); - } else { - LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as input type"; - } +static void SgMKLDNNDequantizeForward(const OpStatePtr &state_ptr, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + SgMKLDNNDequantizeOperator &op = state_ptr.get_state(); + op.Forward(ctx, inputs, req, outputs); +} + +static OpStatePtr CreateSgMKLDNNDequantizeState(const nnvm::NodeAttrs &attrs, Context ctx, + const std::vector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); } } // namespace op diff --git a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h index d6060e54a82c..117195584b68 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h +++ b/src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h @@ -34,99 +34,37 @@ namespace mxnet { namespace op { -template -static void MKLDNNQuantizeComputeKer(const std::vector& inputs, - const std::vector& outputs, - const QuantizeV2Param& param, - const std::vector& req) { - using namespace mshadow; - using namespace mxnet_op; - using red::limits::MaxValue; - using red::limits::MinValue; - SrcType real_range = 0.f; - DstType quantized_range = 0; - NDArray in_buffer = inputs[0]; - SrcType data_min = red::limits::MaxValue(); - SrcType data_max = red::limits::MinValue(); - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - data_min = param.min_calib_range.value(); - data_max = param.max_calib_range.value(); - } else { - // no calib info - in_buffer = inputs[0].Reorder2Default(); - auto in_ptr = in_buffer.data().dptr(); - auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - std::vector data_maxs(nthreads, data_max); - std::vector data_mins(nthreads, data_min); -#pragma omp parallel for num_threads(nthreads) - for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); i++) { - int tid = omp_get_thread_num(); - if (in_ptr[i] > data_maxs[tid]) data_maxs[tid] = in_ptr[i]; - if (in_ptr[i] < data_mins[tid]) data_mins[tid] = in_ptr[i]; - } - for (index_t i = 0; i < nthreads; i++) { - if (data_maxs[i] > data_max) data_max = data_maxs[i]; - if (data_mins[i] < data_min) data_min = data_mins[i]; - } - } +class SgMKLDNNQuantizeOperator : public QuantizeV2Operator { + public: + explicit SgMKLDNNQuantizeOperator(const nnvm::NodeAttrs &attrs) + : QuantizeV2Operator(attrs), param_(nnvm::get(attrs.parsed)) {} - auto out_type = GetOutputType(param); - if (out_type == mshadow::kUint8) { - real_range = std::max(0.f, data_max); - quantized_range = MaxValue(); - *outputs[1].data().dptr() = 0.f; - *outputs[2].data().dptr() = real_range; - } else if (out_type == mshadow::kInt8) { - real_range = MaxAbs(data_min, data_max); - quantized_range = MinAbs(MaxValue(), MinValue()); - *outputs[1].data().dptr() = -real_range; - *outputs[2].data().dptr() = real_range; - } else { - LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; - } - float scale = static_cast(quantized_range) / real_range; - - primitive_attr attr; - const int mask = 0; - std::vector scales = {scale}; - attr.set_output_scales(mask, scales); - attr.set_int_output_round_mode(round_nearest); - mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); - - if (in_buffer.IsView() && in_buffer.IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default(); - auto i_mem = in_buffer.GetMKLDNNData(); - auto i_mpd = i_mem->get_primitive_desc(); - auto i_desc = i_mpd.desc(); - mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); - if (i_fmt == mkldnn::memory::format::nchw || - i_fmt == mkldnn::memory::format::nChw8c || - i_fmt == mkldnn_nChw16c) { - i_fmt = mkldnn::memory::format::nhwc; - } - size_t i_ndim = in_buffer.shape().ndim(); - mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); - for (size_t i = 0; i < i_ndim; i++) { - i_dims[i] = static_cast(in_buffer.shape()[i]); - } - auto o_desc = - mkldnn::memory::desc(i_dims, (mkldnn::memory::data_type)data_type_enum::type, i_fmt); - auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); - auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); - auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second)); - CommitOutput(outputs[0], o_mem); - MKLDNNStream::Get()->Submit(); -} + void Forward(const OpContext &ctx, const std::vector &inputs, + const std::vector &req, const std::vector &outputs); + + private: + bool initalized_{false}; + QuantizeV2Param param_; + float cached_data_min_{0.f}; + float cached_data_max_{0.f}; + std::shared_ptr i_mem_; + std::shared_ptr o_mem_; + std::shared_ptr fwd_pd_; +}; -static void MKLDNNQuantizeV2Compute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - const QuantizeV2Param& param = nnvm::get(attrs.parsed); +void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + float quantized_range = 0.0; + NDArray in_buffer = inputs[0]; + float data_min = mshadow::red::limits::MaxValue(); + float data_max = mshadow::red::limits::MinValue(); + + // Pass through quantized data if (inputs[0].dtype() == mshadow::kUint8 || inputs[0].dtype() == mshadow::kInt8) { - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - *outputs[1].data().dptr() = param.min_calib_range.value(); - *outputs[2].data().dptr() = param.max_calib_range.value(); + if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { + *outputs[1].data().dptr() = param_.min_calib_range.value(); + *outputs[2].data().dptr() = param_.max_calib_range.value(); } else { if (inputs[0].dtype() == mshadow::kUint8) { *outputs[1].data().dptr() = 0; @@ -137,21 +75,107 @@ static void MKLDNNQuantizeV2Compute(const nnvm::NodeAttrs& attrs, const OpContex } } if (req[0] != kWriteInplace) { - const_cast(outputs[0]).CopyFrom(*inputs[0].GetMKLDNNData()); + const_cast(outputs[0]).CopyFrom(*inputs[0].GetMKLDNNData()); MKLDNNStream::Get()->Submit(); } } else { - auto out_type = GetOutputType(param); + if (in_buffer.IsView() && in_buffer.IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default(); + auto i_mem = in_buffer.GetMKLDNNData(); + + if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { + data_min = param_.min_calib_range.value(); + data_max = param_.max_calib_range.value(); + } else { + // no calib info + in_buffer = inputs[0].Reorder2Default(); + auto in_ptr = in_buffer.data().dptr(); + auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + std::vector data_maxs(nthreads, data_max); + std::vector data_mins(nthreads, data_min); +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); i++) { + int tid = omp_get_thread_num(); + if (in_ptr[i] > data_maxs[tid]) data_maxs[tid] = in_ptr[i]; + if (in_ptr[i] < data_mins[tid]) data_mins[tid] = in_ptr[i]; + } + for (index_t i = 0; i < nthreads; i++) { + if (data_maxs[i] > data_max) data_max = data_maxs[i]; + if (data_mins[i] < data_min) data_min = data_mins[i]; + } + } + + // Write output min/max + auto out_type = GetOutputType(param_); if (out_type == mshadow::kUint8) { - MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + quantized_range = kUint8Range; + *outputs[1].data().dptr() = data_min; + *outputs[2].data().dptr() = data_max; } else if (out_type == mshadow::kInt8) { - MKLDNNQuantizeComputeKer(inputs, outputs, param, req); + float real_range = MaxAbs(data_min, data_max); + quantized_range = kInt8Range; + *outputs[1].data().dptr() = -real_range; + *outputs[2].data().dptr() = real_range; } else { LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type"; } + + if (initalized_ && (cached_data_min_ != data_min || cached_data_max_ != data_max)) + initalized_ = false; + + if (!initalized_) { + cached_data_min_ = data_min; + cached_data_max_ = data_max; + float real_range = MaxAbs(data_min, data_max); + float scale = quantized_range / real_range; + primitive_attr attr; + const int mask = 0; + std::vector scales = {scale}; + attr.set_output_scales(mask, scales); + attr.set_int_output_round_mode(round_nearest); + mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine(); + auto i_mpd = i_mem->get_primitive_desc(); + auto i_desc = i_mpd.desc(); + mkldnn::memory::format i_fmt = static_cast(i_desc.data.format); + if (i_fmt == mkldnn::memory::format::nchw || i_fmt == mkldnn::memory::format::nChw8c || + i_fmt == mkldnn_nChw16c) { + i_fmt = mkldnn::memory::format::nhwc; + } + size_t i_ndim = in_buffer.shape().ndim(); + mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim); + for (size_t i = 0; i < i_ndim; i++) { + i_dims[i] = static_cast(in_buffer.shape()[i]); + } + auto o_desc = mkldnn::memory::desc(i_dims, get_mkldnn_type(out_type), i_fmt); + auto o_mpd = memory::primitive_desc(o_desc, cpu_engine); + auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr); + i_mem_ = std::make_shared(i_mpd, nullptr); + o_mem_ = std::make_shared(o_mpd, nullptr); + fwd_pd_ = std::make_shared(reorder_pd, *i_mem_, *o_mem_); + initalized_ = true; + } + auto o_mem = CreateMKLDNNMem(outputs[0], o_mem_->get_primitive_desc(), req[0]); + i_mem_->set_data_handle(i_mem->get_data_handle()); + o_mem_->set_data_handle(o_mem.second->get_data_handle()); + MKLDNNStream::Get()->RegisterPrim(*fwd_pd_); + CommitOutput(outputs[0], o_mem); + MKLDNNStream::Get()->Submit(); } } +static void SgMKLDNNQuantizeForward(const OpStatePtr &state_ptr, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + SgMKLDNNQuantizeOperator &op = state_ptr.get_state(); + op.Forward(ctx, inputs, req, outputs); +} + +static OpStatePtr CreateSgMKLDNNQuantizeState(const nnvm::NodeAttrs &attrs, Context ctx, + const std::vector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index 02ace6c39fac..31a10fd54ec0 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -125,95 +125,14 @@ struct quantize_v2_zero_centered { } }; -template -void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, - const std::vector &inputs, const std::vector &req, - const std::vector &outputs) { - using namespace mshadow; - using namespace mxnet_op; - typedef float SrcDType; - using mshadow::red::limits::MaxValue; - using mshadow::red::limits::MinValue; - Stream *s = ctx.get_stream(); - const QuantizeV2Param ¶m = nnvm::get(attrs.parsed); - auto out_type = GetOutputType(param); - if (out_type == mshadow::kUint8 && std::is_same::value) { - LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, " - "please switch to the context of CPU or int8 data type for GPU."; - } - - if (inputs[0].type_flag_ == mshadow::kUint8 || inputs[0].type_flag_ == mshadow::kInt8) { - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - *outputs[1].dptr() = param.min_calib_range.value(); - *outputs[2].dptr() = param.max_calib_range.value(); - } else { - if (inputs[0].type_flag_ == mshadow::kUint8) { - *outputs[1].dptr() = 0; - *outputs[2].dptr() = 255; - } else { - *outputs[1].dptr() = -127; - *outputs[2].dptr() = 127; - } - } - UnaryOp::IdentityCompute(attrs, ctx, {inputs[0]}, req, outputs); - } else { - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - if (out_type == mshadow::kUint8) { - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), - param.max_calib_range.value(), MinValue(), MaxValue()); - } else if (out_type == mshadow::kInt8) { // zero-centered quantization - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), - param.max_calib_range.value(), MinAbs(MaxValue(), MinValue())); - } else { - LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; - } - } else { // model is not calibrated - mxnet::TShape src_shape, dst_shape; - const size_t actual_float_size = sizeof(float); - const size_t temp_reduce_size = ConfigReduce( - s, inputs[0].shape_, mxnet::TShape({1}), &src_shape, &dst_shape); - Tensor temp_space = ctx.requested[0].get_space_typed( - Shape1(2 * actual_float_size + temp_reduce_size), s); - const int dev_id = ctx.run_ctx.ctx.dev_id; - TBlob in_min_t(reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, - dev_id); - TBlob in_max_t(reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, - dev_id); - Tensor workspace(temp_space.dptr_ + 2 * actual_float_size, - Shape1(temp_reduce_size), s); - broadcast::Reduce( - s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); - broadcast::Reduce( - s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); - if (out_type == mshadow::kUint8) { - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), - in_max_t.dptr(), MinValue(), MaxValue()); - } else if (out_type == mshadow::kInt8) { // zero-centered quantization - Kernel::Launch( - s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), - outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), - in_max_t.dptr(), MinAbs(MaxValue(), MinValue())); - } else { - LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; - } - } - } -} - -static inline bool QuantizeV2Shape(const nnvm::NodeAttrs &attrs, mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { +static inline bool QuantizeV2Shape(const nnvm::NodeAttrs &attrs, std::vector *in_attrs, + std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 3U); SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape{1}); - SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape{1}); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape{1}); + SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape{1}); return !shape_is_none(out_attrs->at(0)); } @@ -237,6 +156,114 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector return (*in_attrs)[0] != -1; } +template +class QuantizeV2Operator { + public: + QuantizeV2Operator(const nnvm::NodeAttrs &attrs) : attrs_(attrs) {} + + void Forward(const OpContext &ctx, const std::vector &inputs, + const std::vector &req, const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + typedef float SrcDType; + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + Stream *s = ctx.get_stream(); + const QuantizeV2Param ¶m = nnvm::get(attrs_.parsed); + auto out_type = GetOutputType(param); + if (out_type == mshadow::kUint8 && std::is_same::value) { + LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, " + "please switch to the context of CPU or int8 data type for GPU."; + } + + if (inputs[0].type_flag_ == mshadow::kUint8 || inputs[0].type_flag_ == mshadow::kInt8) { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + *outputs[1].dptr() = param.min_calib_range.value(); + *outputs[2].dptr() = param.max_calib_range.value(); + } else { + if (inputs[0].type_flag_ == mshadow::kUint8) { + *outputs[1].dptr() = 0; + *outputs[2].dptr() = 255; + } else { + *outputs[1].dptr() = -127; + *outputs[2].dptr() = 127; + } + } + UnaryOp::IdentityCompute(attrs_, ctx, {inputs[0]}, req, outputs); + } else { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + if (out_type == mshadow::kUint8) { + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), + param.max_calib_range.value(), MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), + param.max_calib_range.value(), MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } + } else { // model is not calibrated + mxnet::TShape src_shape, dst_shape; + const size_t actual_float_size = sizeof(float); + const size_t temp_reduce_size = ConfigReduce( + s, inputs[0].shape_, mxnet::TShape({1}), &src_shape, &dst_shape); + Tensor temp_space = ctx.requested[0].get_space_typed( + Shape1(2 * actual_float_size + temp_reduce_size), s); + const int dev_id = ctx.run_ctx.ctx.dev_id; + TBlob in_min_t(reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, + dev_id); + TBlob in_max_t(reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, + dev_id); + Tensor workspace(temp_space.dptr_ + 2 * actual_float_size, + Shape1(temp_reduce_size), s); + broadcast::Reduce( + s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + broadcast::Reduce( + s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + if (out_type == mshadow::kUint8) { + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), + in_max_t.dptr(), MinValue(), MaxValue()); + } else if (out_type == mshadow::kInt8) { // zero-centered quantization + Kernel::Launch( + s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), + outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(), + in_max_t.dptr(), MinAbs(MaxValue(), MinValue())); + } else { + LOG(FATAL) << "quantize op only supports int8 and uint8 as output type"; + } + } + } + } + + private: + nnvm::NodeAttrs attrs_; +}; + +static OpStatePtr CreateQuantizeV2State(const nnvm::NodeAttrs &attrs, Context ctx, + const std::vector &in_shapes, + const std::vector &in_types) { + OpStatePtr state; + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(attrs); + } else { + state = OpStatePtr::Create>(attrs); + } + return state; +} + +template +static void QuantizeV2Forward(const OpStatePtr &state_ptr, const OpContext &ctx, + const std::vector &inputs, const std::vector &req, + const std::vector &outputs) { + auto &op = state_ptr.get_state>(); + op.Forward(ctx, inputs, req, outputs); +} + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_V2_INL_H_ diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index 920100bc9f8b..95858826a4e0 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -88,9 +88,12 @@ If min_calib_range isn't presented, the output type will be int8. .set_attr("FGradient", MakeZeroGradNodes) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) -.set_attr("FComputeEx", MKLDNNQuantizeV2Compute) +.set_attr("FCreateOpState", CreateSgMKLDNNQuantizeState) +.set_attr("FStatefulComputeEx", SgMKLDNNQuantizeForward) +#else +.set_attr("FCreateOpState", CreateQuantizeV2State) +.set_attr("FStatefulCompute", QuantizeV2Forward) #endif -.set_attr("FCompute", QuantizeV2Compute) .set_attr("FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; }) diff --git a/src/operator/quantization/quantize_v2.cu b/src/operator/quantization/quantize_v2.cu index ab0cf9c5ad0e..0707f41ded94 100644 --- a/src/operator/quantization/quantize_v2.cu +++ b/src/operator/quantization/quantize_v2.cu @@ -28,7 +28,7 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_contrib_quantize_v2) -.set_attr("FCompute", QuantizeV2Compute); +.set_attr("FStatefulCompute", QuantizeV2Forward); } // namespace op } // namespace mxnet