This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

remove fuse_requantize and change fuse_dequantize to enable_float_output.
ciyongch committed Mar 2, 2019
1 parent 1913d8e commit 73f95dc
Showing 4 changed files with 33 additions and 39 deletions.
9 changes: 3 additions & 6 deletions src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
```diff
@@ -38,19 +38,16 @@ namespace op {
 
 struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
   bool quantized;
-  bool fuse_requantize;
-  bool fuse_dequantize;
+  bool enable_float_output;
   bool with_relu;
   dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
 
   DMLC_DECLARE_PARAMETER(MKLDNNFCParam) {
     DMLC_DECLARE_FIELD(quantized).set_default(false)
     .describe("enable quantization");
-    DMLC_DECLARE_FIELD(fuse_requantize).set_default(false)
-    .describe("Whether to fuse requantize");
-    DMLC_DECLARE_FIELD(fuse_dequantize).set_default(false)
-    .describe("Whether to fuse dequantize");
+    DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
+    .describe("Whether to enable float32 output");
     DMLC_DECLARE_FIELD(with_relu).set_default(false)
     .describe("Add post relu");
     DMLC_DECLARE_FIELD(min_calib_range)
```
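The renamed switch still reaches the operator as a string attribute on the fused node and is parsed by the dmlc parameter machinery. A minimal sketch of that path, assuming the `MKLDNNFCParam` declaration above and `dmlc::Parameter::Init` (which accepts a container of string key/value pairs); the kwargs values here are illustrative, not from the commit:

```cpp
#include <map>
#include <string>

// Sketch only: how a subgraph pass's string attributes become typed fields.
MKLDNNFCParam MakeFloatOutputParam() {
  std::map<std::string, std::string> kwargs = {
      {"quantized", "True"},
      {"enable_float_output", "True"},  // pre-commit name: "fuse_dequantize"
  };
  MKLDNNFCParam param;
  param.Init(kwargs);  // unset fields (with_relu, calib ranges) keep defaults
  return param;
}
```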
11 changes: 6 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
```diff
@@ -44,16 +44,17 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
   mkldnn::primitive_attr attr;
   mkldnn::post_ops ops;
   if (full_param.mkldnn_param.with_relu) {
-    float scale = 1.0f;
-    float alpha = 0.0f;
-    float beta = 1.0f;
+    const float scale = 1.0f;
+    const float alpha = 0.0f;
+    const float beta = 1.0f;
     ops.append_eltwise(scale, eltwise_relu, alpha, beta);
   }
   attr.set_post_ops(ops);
 
   if (full_param.mkldnn_param.quantized) {
-    if (full_param.mkldnn_param.fuse_requantize ||
-        full_param.mkldnn_param.fuse_dequantize) {
+    if ((full_param.mkldnn_param.min_calib_range.has_value() &&
+         full_param.mkldnn_param.max_calib_range.has_value()) ||
+        full_param.mkldnn_param.enable_float_output) {
       int mask = 0;
       std::vector<float> scales = {0.0};
       if (full_param.requantize_scales.size()) {
```
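The hunk is cut off by the collapsed view; below is a hedged sketch of how such a branch plausibly continues, handing one scale to the primitive attributes. It assumes the MKL-DNN 0.x C++ API (`primitive_attr::set_output_scales`) and mirrors the variable names above; it is not a verbatim quote of the hidden source:

```cpp
#include <vector>
#include "mkldnn.hpp"

// Sketch: a mask of 0 applies a single scale across the whole output tensor.
// The scale comes from requantize_scales (int8/uint8 output) or
// output_scales (the enable_float_output path).
void SetFCOutputScale(mkldnn::primitive_attr *attr,
                      const std::vector<float> &requantize_scales,
                      const std::vector<float> &output_scales) {
  const int mask = 0;
  std::vector<float> scales = {1.0f};
  if (!requantize_scales.empty()) {
    scales[0] = requantize_scales[0];
  } else if (!output_scales.empty()) {
    scales[0] = output_scales[0];
  }
  attr->set_output_scales(mask, scales);
}
```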
35 changes: 16 additions & 19 deletions src/operator/subgraph/mkldnn/mkldnn_fc.cc
```diff
@@ -97,7 +97,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
     min_bias = in_data[base_num_inputs + 4].data().dptr<float>()[0];
     max_bias = in_data[base_num_inputs + 5].data().dptr<float>()[0];
   }
-  if (!mkldnn_param.fuse_dequantize) {
+  if (!mkldnn_param.enable_float_output) {
     total_num_outputs = base_num_outputs * 3;
     min_output_ptr = out_data[1].data().dptr<float>();
     max_output_ptr = out_data[2].data().dptr<float>();
```
```diff
@@ -160,22 +160,17 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
       }
     }
 
-    if (mkldnn_param.fuse_dequantize) {
+    if (mkldnn_param.enable_float_output) {
       full_param_.output_scales[0] = 1.0 / data_scale / weight_scale;
       full_param_.requantize_scales.resize(0);
-    } else if (mkldnn_param.fuse_requantize) {
+    } else if (mkldnn_param.min_calib_range.has_value() &&
+               mkldnn_param.max_calib_range.has_value()) {
       full_param_.output_scales.resize(0);
-      if (mkldnn_param.min_calib_range.has_value() &&
-          mkldnn_param.max_calib_range.has_value()) {
-        *min_output_ptr = mkldnn_param.min_calib_range.value();
-        *max_output_ptr = mkldnn_param.max_calib_range.value();
+      *min_output_ptr = mkldnn_param.min_calib_range.value();
+      *max_output_ptr = mkldnn_param.max_calib_range.value();
 
-        full_param_.requantize_scales[0] = quantized_out_range /
-          MaxAbs(*min_output_ptr, *max_output_ptr) / data_scale / weight_scale;
-      } else {
-        LOG(FATAL)<<
-        "Failed to fuse requantize due to no min_calib_range and max_calib_range found.";
-      }
+      full_param_.requantize_scales[0] = quantized_out_range /
+        MaxAbs(*min_output_ptr, *max_output_ptr) / data_scale / weight_scale;
     } else {
       Stream<cpu> *s = ctx.get_stream<cpu>();
       mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
```
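The two branches compute different scales. A worked example with made-up calibration numbers (the values are illustrative, not from the commit):

```cpp
#include <algorithm>
#include <cmath>

void ScaleExample() {
  const float data_scale   = 25.4f;   // e.g. 127 / max|data|
  const float weight_scale = 50.8f;   // e.g. 127 / max|weight|

  // enable_float_output: a single combined scale undoes both input
  // quantizations, so the primitive emits fp32 directly.
  const float output_scale = 1.0f / data_scale / weight_scale;  // ~7.75e-4

  // Calibrated path: map the int32 accumulator straight onto the quantized
  // output range, folding the separate requantize op away.
  const float quantized_out_range = 127.0f;  // 255 when relu makes it uint8
  const float min_calib = -2.0f, max_calib = 3.0f;
  const float requantize_scale = quantized_out_range
      / std::max(std::fabs(min_calib), std::fabs(max_calib))  // MaxAbs
      / data_scale / weight_scale;                            // ~3.28e-2
  (void)output_scale; (void)requantize_scale;
}
```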
```diff
@@ -246,7 +241,7 @@ static std::vector<std::string> SgMKLDNNFCListInputNames(const NodeAttrs &attrs)
 static std::vector<std::string> SgMKLDNNFCListOutputNames(const NodeAttrs &attrs) {
   auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
   if (full_param.mkldnn_param.quantized) {
-    if (full_param.mkldnn_param.fuse_dequantize)
+    if (full_param.mkldnn_param.enable_float_output)
       return std::vector<std::string>{"output"};
     else
       return std::vector<std::string>{"output", "min_output", "max_output"};
```
```diff
@@ -288,7 +283,7 @@ static bool SgMKLDNNFCInferShape(const nnvm::NodeAttrs &attrs,
   }
 
   out_shapes->at(0) = base_out_shapes[0];
-  if (!full_param.mkldnn_param.fuse_dequantize) {
+  if (!full_param.mkldnn_param.enable_float_output) {
     SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
     SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
   }
```
```diff
@@ -316,10 +311,11 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs,
     }
   }
 
-  if (full_param.mkldnn_param.fuse_dequantize) {
+  if (full_param.mkldnn_param.enable_float_output) {
     TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
   } else {
-    if (full_param.mkldnn_param.fuse_requantize) {
+    if (full_param.mkldnn_param.min_calib_range.has_value() &&
+        full_param.mkldnn_param.max_calib_range.has_value()) {
       if (full_param.mkldnn_param.with_relu) {
         TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
       } else {
```
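Note (context, not part of the diff): a fused relu guarantees non-negative activations, which is why type inference can assign uint8 and use the full unsigned output range; without relu the result can be negative and must stay int8.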
```diff
@@ -359,7 +355,7 @@ static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs,
   }
 
   out_attrs->at(0) = base_out_attrs[0];
-  if (!full_param.mkldnn_param.fuse_dequantize) {
+  if (!full_param.mkldnn_param.enable_float_output) {
     type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
     type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
   }
```
```diff
@@ -412,7 +408,8 @@ NNVM_REGISTER_OP(_sg_mkldnn_fully_connected)
 })
 .set_num_outputs([](const NodeAttrs& attrs) {
   auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
-  return (full_param.mkldnn_param.quantized && !full_param.mkldnn_param.fuse_dequantize) ? 3 : 1;
+  return (full_param.mkldnn_param.quantized &&
+          !full_param.mkldnn_param.enable_float_output) ? 3 : 1;
 })
 .set_attr_parser(SgMKLDNNFCParamParser)
 .set_attr<nnvm::FListInputNames>("FListInputNames", SgMKLDNNFCListInputNames)
```
17 changes: 8 additions & 9 deletions src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc
```diff
@@ -41,15 +41,15 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
 
  private:
   bool disable_all;
-  bool disable_fuse_dequantize;
+  bool disable_float_output;
   SelectStatus status;
   std::vector<const nnvm::Node *> matched_list;
 
  public:
   explicit SgMKLDNNFCPostQuantizeSelector(const bool dis_all,
-                                          const bool dis_fuse_dequantize)
+                                          const bool dis_float_output)
       : disable_all(dis_all),
-        disable_fuse_dequantize(dis_fuse_dequantize) {}
+        disable_float_output(dis_float_output) {}
 
   bool Select(const nnvm::Node &n) override {
     if ((!disable_all) && n.op() == Op::Get(QUANTIZED_FC_NAME)) {
```
```diff
@@ -94,7 +94,7 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
         }
       }
       case kRequantize:
-        if ((!disable_fuse_dequantize) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
+        if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
           matched_list.push_back(&new_node);
           status = kSuccess;
           return true;
```
Expand Down Expand Up @@ -128,7 +128,7 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty {
SgMKLDNNFCPostQuantizeProperty() {
disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_POST_OPT", false);
disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL", false);
disable_fuse_dequantize = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FUSE_DEQUANTIZE", false);
disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT", false);

disable_all = disable_all || disable_fuse_all;
if (disable_all) {
Expand Down Expand Up @@ -169,9 +169,8 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty {
// When only fused quantized_fullyconnected and requantize, set min/max_cablib_range,
// When fused quantized_fullyconnected + requantize + dequantize, set dequantize flag to true.
if (dequantize_node != nullptr) {
fc_node->attrs.dict["fuse_dequantize"] = "True";
fc_node->attrs.dict["enable_float_output"] = "True";
} else {
fc_node->attrs.dict["fuse_requantize"] = "True";
fc_node->attrs.dict["min_calib_range"] =
std::to_string(requantize_param.min_calib_range.value());
fc_node->attrs.dict["max_calib_range"] =
```diff
@@ -184,7 +183,7 @@
   SubgraphSelectorPtr CreateSubgraphSelector() const override {
     auto selector =
         std::make_shared<SgMKLDNNFCPostQuantizeSelector>(disable_all,
-                                                         disable_fuse_dequantize);
+                                                         disable_float_output);
     return selector;
   }
```

```diff
@@ -200,7 +199,7 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty {
  private:
   bool disable_all;
   bool disable_fuse_all;
-  bool disable_fuse_dequantize;
+  bool disable_float_output;
 };
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_FC_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
```
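To summarize what the rewritten pass leaves on the fused node, here is an illustrative sketch of the two possible attribute sets (hypothetical helper functions; `nnvm::NodeAttrs::dict` is a string-to-string map):

```cpp
#include <string>
#include <unordered_map>

// Illustrative only: the attribute sets this pass can write onto the fused
// FC node, mirroring the dequantize_node branch above.
using AttrDict = std::unordered_map<std::string, std::string>;

AttrDict FloatOutputAttrs() {  // FC + requantize + dequantize matched
  return {{"quantized", "True"}, {"enable_float_output", "True"}};
}

AttrDict CalibratedAttrs(float min_calib, float max_calib) {  // FC + requantize
  return {{"quantized", "True"},
          {"min_calib_range", std::to_string(min_calib)},
          {"max_calib_range", std::to_string(max_calib)}};
}
```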
