From 73f95dc25eab5c6b0766ee02c174ca9c47e18926 Mon Sep 17 00:00:00 2001 From: Ciyong Chen Date: Sat, 2 Mar 2019 15:49:18 +0800 Subject: [PATCH] remove fuse_requantize and change fuse_dequantize to enable_float_output. --- .../nn/mkldnn/mkldnn_fully_connected-inl.h | 9 ++--- .../nn/mkldnn/mkldnn_fully_connected.cc | 11 +++--- src/operator/subgraph/mkldnn/mkldnn_fc.cc | 35 +++++++++---------- .../mkldnn_fc_post_quantize_property.cc | 17 +++++---- 4 files changed, 33 insertions(+), 39 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h index e805bb1daec3..a800c1cffda8 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h @@ -38,8 +38,7 @@ namespace op { struct MKLDNNFCParam: public dmlc::Parameter { bool quantized; - bool fuse_requantize; - bool fuse_dequantize; + bool enable_float_output; bool with_relu; dmlc::optional min_calib_range; // min float value calculated from calibration dataset dmlc::optional max_calib_range; // max float value calculated from calibration dataset @@ -47,10 +46,8 @@ struct MKLDNNFCParam: public dmlc::Parameter { DMLC_DECLARE_PARAMETER(MKLDNNFCParam) { DMLC_DECLARE_FIELD(quantized).set_default(false) .describe("enable quantization"); - DMLC_DECLARE_FIELD(fuse_requantize).set_default(false) - .describe("Whether to fuse requantize"); - DMLC_DECLARE_FIELD(fuse_dequantize).set_default(false) - .describe("Whether to fuse dequantize"); + DMLC_DECLARE_FIELD(enable_float_output).set_default(false) + .describe("Whether to enable float32 output"); DMLC_DECLARE_FIELD(with_relu).set_default(false) .describe("Add post relu"); DMLC_DECLARE_FIELD(min_calib_range) diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index 876f60e0a733..b515cbce407f 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -44,16 +44,17 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl( mkldnn::primitive_attr attr; mkldnn::post_ops ops; if (full_param.mkldnn_param.with_relu) { - float scale = 1.0f; - float alpha = 0.0f; - float beta = 1.0f; + const float scale = 1.0f; + const float alpha = 0.0f; + const float beta = 1.0f; ops.append_eltwise(scale, eltwise_relu, alpha, beta); } attr.set_post_ops(ops); if (full_param.mkldnn_param.quantized) { - if (full_param.mkldnn_param.fuse_requantize || - full_param.mkldnn_param.fuse_dequantize) { + if ((full_param.mkldnn_param.min_calib_range.has_value() && + full_param.mkldnn_param.max_calib_range.has_value()) || + full_param.mkldnn_param.enable_float_output) { int mask = 0; std::vector scales = {0.0}; if (full_param.requantize_scales.size()) { diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc index e66cbbcdd762..cfb5df32793c 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc @@ -97,7 +97,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx, min_bias = in_data[base_num_inputs + 4].data().dptr()[0]; max_bias = in_data[base_num_inputs + 5].data().dptr()[0]; } - if (!mkldnn_param.fuse_dequantize) { + if (!mkldnn_param.enable_float_output) { total_num_outputs = base_num_outputs * 3; min_output_ptr = out_data[1].data().dptr(); max_output_ptr = out_data[2].data().dptr(); @@ -160,22 +160,17 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx, } } - if (mkldnn_param.fuse_dequantize) { + if (mkldnn_param.enable_float_output) { full_param_.output_scales[0] = 1.0 / data_scale / weight_scale; full_param_.requantize_scales.resize(0); - } else if (mkldnn_param.fuse_requantize) { + } else if (mkldnn_param.min_calib_range.has_value() && + mkldnn_param.max_calib_range.has_value()) { full_param_.output_scales.resize(0); - if (mkldnn_param.min_calib_range.has_value() && - mkldnn_param.max_calib_range.has_value()) { - *min_output_ptr = mkldnn_param.min_calib_range.value(); - *max_output_ptr = mkldnn_param.max_calib_range.value(); + *min_output_ptr = mkldnn_param.min_calib_range.value(); + *max_output_ptr = mkldnn_param.max_calib_range.value(); - full_param_.requantize_scales[0] = quantized_out_range / - MaxAbs(*min_output_ptr, *max_output_ptr) / data_scale / weight_scale; - } else { - LOG(FATAL)<< - "Failed to fuse requantize due to no min_calib_range and max_calib_range found."; - } + full_param_.requantize_scales[0] = quantized_out_range / + MaxAbs(*min_output_ptr, *max_output_ptr) / data_scale / weight_scale; } else { Stream *s = ctx.get_stream(); mxnet_op::Kernel::Launch(s, 1, @@ -246,7 +241,7 @@ static std::vector SgMKLDNNFCListInputNames(const NodeAttrs &attrs) static std::vector SgMKLDNNFCListOutputNames(const NodeAttrs &attrs) { auto const &full_param = nnvm::get(attrs.parsed); if (full_param.mkldnn_param.quantized) { - if (full_param.mkldnn_param.fuse_dequantize) + if (full_param.mkldnn_param.enable_float_output) return std::vector{"output"}; else return std::vector{"output", "min_output", "max_output"}; @@ -288,7 +283,7 @@ static bool SgMKLDNNFCInferShape(const nnvm::NodeAttrs &attrs, } out_shapes->at(0) = base_out_shapes[0]; - if (!full_param.mkldnn_param.fuse_dequantize) { + if (!full_param.mkldnn_param.enable_float_output) { SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1)); SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1)); } @@ -316,10 +311,11 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs, } } - if (full_param.mkldnn_param.fuse_dequantize) { + if (full_param.mkldnn_param.enable_float_output) { TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); } else { - if (full_param.mkldnn_param.fuse_requantize) { + if (full_param.mkldnn_param.min_calib_range.has_value() && + full_param.mkldnn_param.max_calib_range.has_value()) { if (full_param.mkldnn_param.with_relu) { TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8); } else { @@ -359,7 +355,7 @@ static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs, } out_attrs->at(0) = base_out_attrs[0]; - if (!full_param.mkldnn_param.fuse_dequantize) { + if (!full_param.mkldnn_param.enable_float_output) { type_assign(&out_attrs->at(1), mxnet::kDefaultStorage); type_assign(&out_attrs->at(2), mxnet::kDefaultStorage); } @@ -412,7 +408,8 @@ NNVM_REGISTER_OP(_sg_mkldnn_fully_connected) }) .set_num_outputs([](const NodeAttrs& attrs) { auto const &full_param = nnvm::get(attrs.parsed); - return (full_param.mkldnn_param.quantized && !full_param.mkldnn_param.fuse_dequantize) ? 3 : 1; + return (full_param.mkldnn_param.quantized && + !full_param.mkldnn_param.enable_float_output) ? 3 : 1; }) .set_attr_parser(SgMKLDNNFCParamParser) .set_attr("FListInputNames", SgMKLDNNFCListInputNames) diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc index 29a5a293b9f2..05d4654f6484 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc @@ -41,15 +41,15 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { private: bool disable_all; - bool disable_fuse_dequantize; + bool disable_float_output; SelectStatus status; std::vector matched_list; public: explicit SgMKLDNNFCPostQuantizeSelector(const bool dis_all, - const bool dis_fuse_dequantize) + const bool dis_float_output) : disable_all(dis_all), - disable_fuse_dequantize(dis_fuse_dequantize) {} + disable_float_output(dis_float_output) {} bool Select(const nnvm::Node &n) override { if ((!disable_all) && n.op() == Op::Get(QUANTIZED_FC_NAME)) { @@ -94,7 +94,7 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { } } case kRequantize: - if ((!disable_fuse_dequantize) && (new_node.op() == Op::Get("_contrib_dequantize"))) { + if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) { matched_list.push_back(&new_node); status = kSuccess; return true; @@ -128,7 +128,7 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty { SgMKLDNNFCPostQuantizeProperty() { disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_POST_OPT", false); disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL", false); - disable_fuse_dequantize = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FUSE_DEQUANTIZE", false); + disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT", false); disable_all = disable_all || disable_fuse_all; if (disable_all) { @@ -169,9 +169,8 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty { // When only fused quantized_fullyconnected and requantize, set min/max_cablib_range, // When fused quantized_fullyconnected + requantize + dequantize, set dequantize flag to true. if (dequantize_node != nullptr) { - fc_node->attrs.dict["fuse_dequantize"] = "True"; + fc_node->attrs.dict["enable_float_output"] = "True"; } else { - fc_node->attrs.dict["fuse_requantize"] = "True"; fc_node->attrs.dict["min_calib_range"] = std::to_string(requantize_param.min_calib_range.value()); fc_node->attrs.dict["max_calib_range"] = @@ -184,7 +183,7 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty { SubgraphSelectorPtr CreateSubgraphSelector() const override { auto selector = std::make_shared(disable_all, - disable_fuse_dequantize); + disable_float_output); return selector; } @@ -200,7 +199,7 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty { private: bool disable_all; bool disable_fuse_all; - bool disable_fuse_dequantize; + bool disable_float_output; }; MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_FC_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);