Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MKLDNN]Add quantized relu #14604

Merged
merged 13 commits into from
Apr 18, 2019
74 changes: 74 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_act-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
huangzhiyuan marked this conversation as resolved.
Show resolved Hide resolved

/*!
* Copyright (c) 2019 by Contributors
* \file mkldnn_act-inl.h
* \brief MKLDNN(Quantized) Activation operator based on subgraph
* /author Zhiyuan Huang
*/

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_


#if MXNET_USE_MKLDNN == 1
#include <vector>
#include <utility>
#include "../activation-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param);
mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
const ActivationParam& param, bool is_train,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unify the format of "&"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const mkldnn::memory &input_mem, int dtype);

class MKLDNNActForward {
public:
const mkldnn::eltwise_forward::primitive_desc fwd_pd;

MKLDNNActForward(const ActivationParam& param, bool is_train,
const NDArray &data, const mkldnn::memory &mem): fwd_pd(
GetActFwdDescImpl(param, is_train, mem, data.dtype())) {}
void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output);
const mkldnn::eltwise_forward &GetFwd() const;

private:
std::shared_ptr<mkldnn::eltwise_forward> fwd_;
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;
};

typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
MKLDNNActForward &GetActForward(const ActivationParam& param,
const OpContext &ctx, const NDArray &in_data,
const mkldnn::memory &in_mem);

void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
91 changes: 32 additions & 59 deletions src/operator/nn/mkldnn/mkldnn_act.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
#include <string>
#include <utility>
#include "../../operator_common.h"
#include "../activation-inl.h"
#include "./mkldnn_base-inl.h"
#include "mkldnn_act-inl.h"

#if MXNET_USE_MKLDNN == 1

Expand All @@ -58,7 +57,7 @@ bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
return SupportMKLDNNAct(param);
}

static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
switch (param.act_type) {
case activation::kReLU:
return mkldnn::algorithm::eltwise_relu;
Expand All @@ -74,75 +73,49 @@ static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
}
}

typedef std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> mkldnn_act_pdesc_ptr;

static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
const ActivationParam& param, bool is_train,
const mkldnn::memory &input_mem, int dtype) {
mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
auto cpu_engine = data_mpd.get_engine();

auto alg = GetMKLDNNActAlgo(param);
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
DType alpha = 0;
mkldnn::eltwise_forward::desc desc = is_train
? mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_training,
alg, data_md, alpha)
: mkldnn::eltwise_forward::desc(mkldnn::prop_kind::forward_scoring,
alg, data_md, alpha);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
});
LOG(FATAL) << "Unsupported data type for MKLDNN activation";
mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, alg, data_md, 0.0);

auto prop = is_train ? mkldnn::prop_kind::forward_training :
mkldnn::prop_kind::forward_scoring;
auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, 0.0f);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
}

typedef ParamOpSign<ActivationParam> MKLDNNActSignature;

class MKLDNNActForward {
std::shared_ptr<mkldnn::eltwise_forward> fwd;
std::shared_ptr<mkldnn::memory> data;
std::shared_ptr<mkldnn::memory> out;

public:
const mkldnn::eltwise_forward::primitive_desc fwd_pd;

MKLDNNActForward(const ActivationParam& param, bool is_train,
const NDArray &data, const mkldnn::memory &mem): fwd_pd(
GetActFwdDescImpl(param, is_train, mem, data.dtype())) {
}

void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
if (this->data == nullptr)
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
data.get_primitive_desc(), data.get_data_handle()));
else
this->data->set_data_handle(data.get_data_handle());

CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
if (this->out == nullptr)
this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
fwd_pd.dst_primitive_desc(), output.get_data_handle()));
else
this->out->set_data_handle(output.get_data_handle());

if (this->fwd == nullptr) {
this->fwd = std::shared_ptr<mkldnn::eltwise_forward>(
new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data),
*this->out));
}
void MKLDNNActForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
if (this->data_ == nullptr)
this->data_ = std::make_shared<mkldnn::memory>(data.get_primitive_desc(),
data.get_data_handle());
else
this->data_->set_data_handle(data.get_data_handle());

CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
if (this->out_ == nullptr)
this->out_ = std::make_shared<mkldnn::memory>(fwd_pd.dst_primitive_desc(),
output.get_data_handle());
else
this->out_->set_data_handle(output.get_data_handle());

if (this->fwd_ == nullptr) {
this->fwd_ = std::shared_ptr<mkldnn::eltwise_forward>(
new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
*this->out_));
}
}

const mkldnn::eltwise_forward &GetFwd() const {
return *fwd;
}
};
const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
return *fwd_;
}

static MKLDNNActForward &GetActForward(const ActivationParam& param,
const OpContext &ctx, const NDArray &in_data,
const mkldnn::memory &in_mem) {
MKLDNNActForward &GetActForward(const ActivationParam& param,
const OpContext &ctx, const NDArray &in_data,
const mkldnn::memory &in_mem) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActForward, OpHash> fwds;
#else
Expand Down
55 changes: 55 additions & 0 deletions src/operator/quantization/mkldnn/mkldnn_quantized_act.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
huangzhiyuan marked this conversation as resolved.
Show resolved Hide resolved
/*!
* Copyright (c) 2019 by Contributors
* \file mkldnn_quantized_act.cc
* \brief MKLDNN(Quantized) Activation operator based on subgraph
* /author Zhiyuan Huang
*/
#if MXNET_USE_MKLDNN == 1

#include "../../nn/mkldnn/mkldnn_act-inl.h"
#include "../quantization_utils.h"
TaoLv marked this conversation as resolved.
Show resolved Hide resolved

namespace mxnet {
namespace op {

static void MKLDNNQuantizedActForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& in_data,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data) {
CHECK(in_data[0].dtype() == mshadow::kUint8 ||
in_data[0].dtype() == mshadow::kInt8)
<< "_contrib_quantized_act op only supports uint8 and int8 as input "
"type";

MKLDNNActivationForward(attrs, ctx, in_data[0], req[0], out_data[0]);
out_data[1].data().dptr<float>()[0] = in_data[1].data().dptr<float>()[0];
out_data[2].data().dptr<float>()[0] = in_data[2].data().dptr<float>()[0];
}

NNVM_REGISTER_OP(_contrib_quantized_act)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedActForward);

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
29 changes: 19 additions & 10 deletions src/operator/quantization/quantize_graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ std::vector<NodeEntry> OfflineParams(std::vector<NodeEntry>&& outputs,
return outputs;
}

inline bool NeedQuantize(const NodePtr node,
const std::unordered_set<std::string>& excluded_nodes) {
inline NodePtr NeedQuantize(NodePtr node, const std::unordered_set<std::string>& excluded_nodes) {
std::unordered_map<NodePtr, NodePtr> quantized_node;
static auto& quantized_op_map = Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");
static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
const auto& op = node->op();

if (op && quantized_op_map.count(op)) {
bool need = true;
if (excluded_nodes.count(node->attrs.name)) {
Expand All @@ -112,14 +113,24 @@ inline bool NeedQuantize(const NodePtr node,
});
}
}
return need;

if (need) {
auto n_ptr = quantized_op_map[node->op()];
auto tmp_node = n_ptr(node->attrs);
if (tmp_node->op()) {
quantized_node[node] = tmp_node;
} else {
quantized_node[node] = nullptr;
}
} else {
quantized_node[node] = nullptr;
}
}
return false;
return quantized_node[node];
}

Graph QuantizeGraph(Graph &&src) {
static const auto& flist_outputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
static const auto& quantized_op_map = Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp");
static const auto& need_requantize_map = Op::GetAttr<mxnet::FNeedRequantize>("FNeedRequantize");
static const auto& avoid_quantize_input_map =
Op::GetAttr<mxnet::FAvoidQuantizeInput>("FAvoidQuantizeInput");
Expand All @@ -136,11 +147,9 @@ Graph QuantizeGraph(Graph &&src) {
NodePtr new_node = Node::Create();
// If the currently visited node needs quantization, insert a quantize op node before the
// current node and replace the current node with the quantized version in the new graph.
if (NeedQuantize(node, excluded_nodes)) {
auto fquantized_op = quantized_op_map[node->op()];
// If the currently visited node's op registered the FQuantizedOp property, new_node is a
// quantizated version of a that op, such as quantized_conv2d.
new_node = fquantized_op(node->attrs);
TaoLv marked this conversation as resolved.
Show resolved Hide resolved
auto tmp_node = NeedQuantize(node, excluded_nodes);
if (tmp_node) {
new_node = tmp_node;

// add data into quantized op input
for (size_t i = 0; i < node->inputs.size(); ++i) {
Expand Down
Loading