From e95b551a69a54aac39b7c4b8d6e8423b1a782c2a Mon Sep 17 00:00:00 2001 From: ciyong Date: Thu, 16 May 2019 06:30:06 +0800 Subject: [PATCH] Add primitive cache for MKL-DNN sum (elemwise_add operator) (#14914) * Add primitive cache for mkldnn sum * fix cpp test failure --- src/operator/nn/mkldnn/mkldnn_sum.cc | 105 +++++++++++++++--- .../tensor/elemwise_binary_op_basic.cc | 8 +- 2 files changed, 94 insertions(+), 19 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc index dfb0e254c128..724b8a2613d6 100644 --- a/src/operator/nn/mkldnn/mkldnn_sum.cc +++ b/src/operator/nn/mkldnn/mkldnn_sum.cc @@ -24,6 +24,7 @@ */ #include +#include "../../operator_common.h" #include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" @@ -58,37 +59,105 @@ void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2, MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, inputs, out)); } +class MKLDNNSumFwd { + public: + mkldnn::sum::primitive_desc fwd_pd; + + MKLDNNSumFwd(const std::vector &scales, + const std::vector &data_md) + : fwd_pd(scales, data_md) { + data_.resize(data_md.size()); + } + + void SetNewMem(const std::vector &in_data, const mkldnn::memory &output); + + const mkldnn::sum &GetFwd() const { return *fwd_; } + + private: + std::shared_ptr fwd_; + std::vector> data_; + std::vector data_mem_; + std::shared_ptr out_; +}; + +static MKLDNNSumFwd &GetSumForward( + const std::vector &scales, const std::vector &in_data, + const std::vector &data_md) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map fwds; +#else + static MX_THREAD_LOCAL std::unordered_map fwds; +#endif + OpSignature key; + key.AddSign(in_data); + + auto it = fwds.find(key); + if (it == fwds.end()) { + MKLDNNSumFwd fwd(scales, data_md); + it = AddToCache(&fwds, key, fwd); + } + return it->second; +} + +void MKLDNNSumFwd::SetNewMem(const std::vector &in_data, + const mkldnn::memory &output) { + auto num_inputs = data_.size(); + 
CHECK_EQ(in_data.size(), num_inputs); + for (index_t i = 0; i < static_cast(num_inputs); ++i) { + if (this->data_[i] == nullptr) { + this->data_[i] = std::shared_ptr( + new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle())); + this->data_mem_.push_back(*this->data_[i]); + } else { + this->data_[i]->set_data_handle(in_data[i]->get_data_handle()); + } + } + if (this->out_ == nullptr) + this->out_ = std::shared_ptr( + new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle())); + else + this->out_->set_data_handle(output.get_data_handle()); + + if (this->fwd_ == nullptr) + this->fwd_.reset(new mkldnn::sum(fwd_pd, this->data_mem_, *this->out_)); +} + void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const OpReqType &req, const NDArray &out_data) { - if (req == kNullOp) { - return; - } - TmpMemMgr::Get()->Init(ctx.requested[0]); - std::vector in_prims; - std::vector in_pds(inputs.size()); - std::vector scales(inputs.size(), 1); - in_prims.reserve(inputs.size()); - std::vector in_bufs(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { + auto num_inputs = inputs.size(); + std::vector data_md; + std::vector data_mem; + std::vector scales(num_inputs, 1); + std::vector in_bufs(num_inputs); + + data_md.reserve(num_inputs); + data_mem.reserve(num_inputs); + + for (index_t i = 0; i < static_cast(num_inputs); ++i) { const mkldnn::memory *in_mem; if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) { in_bufs[i] = inputs[i].Reorder2Default(); in_mem = in_bufs[i].GetMKLDNNData(); } else { + in_bufs[i] = inputs[i]; in_mem = inputs[i].GetMKLDNNData(); } - in_prims.push_back(*in_mem); - in_pds[i] = in_mem->get_primitive_desc(); + mkldnn::memory::primitive_desc tmp_pd = in_mem->get_primitive_desc(); + data_md.push_back(tmp_pd); + data_mem.push_back(in_mem); } - mkldnn::sum::primitive_desc pdesc(scales, in_pds); - auto mem = CreateMKLDNNMem(out_data, pdesc.dst_primitive_desc(), 
req, &inputs[0]); - MKLDNNStream *stream = MKLDNNStream::Get(); - stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *mem.second)); - CommitOutput(out_data, mem); - stream->Submit(); + MKLDNNSumFwd &fwd = GetSumForward(scales, in_bufs, data_md); + mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data, + fwd.fwd_pd.dst_primitive_desc(), + req, + &in_bufs[0]); + fwd.SetNewMem(data_mem, *out_mem.second); + MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); + CommitOutput(out_data, out_mem); + MKLDNNStream::Get()->Submit(); } } // namespace op diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index 0ff73f4251cd..c5e30c68de7e 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -30,6 +30,12 @@ namespace mxnet { namespace op { +bool SupportMKLDNNSum(const NDArray& input) { + int ndim = input.shape().ndim(); + return input.dtype() == mshadow::kFloat32 && (ndim >= 1 && ndim <= 4) && + input.storage_type() == kDefaultStorage; +} + static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -38,7 +44,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); #if MXNET_USE_MKLDNN == 1 - if (SupportMKLDNN(inputs[0]) && SupportMKLDNN(inputs[1])) { + if (SupportMKLDNNSum(inputs[0]) && SupportMKLDNNSum(inputs[1])) { MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]); return; } else if (inputs[0].storage_type() == kDefaultStorage