Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
api change
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhennanQin committed Jan 17, 2019
1 parent 208beaa commit f1a00cd
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 21 deletions.
10 changes: 7 additions & 3 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,13 @@ class NDArray {
/*
* Create NDArray from mkldnn memory.
* mkldnn_mem The mkldnn memory to be managed.
* static_data If true, mkldnn memory won't be freed on destruction.
*/
explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true);
explicit NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
/*
* Create NDArray from mkldnn memory descriptor.
* mem_pd The mkldnn memory descriptor to be created.
*/
explicit NDArray(mkldnn::memory::primitive_desc mem_pd);
/*
* Test if the data is stored in one of special MKLDNN format.
*/
Expand Down Expand Up @@ -776,7 +780,7 @@ class NDArray {
/*!
* \ Fix mkldnn memory descriptor mismatch from NDArray.
*/
void UpdateMKLDNNMemDesc();
void UpdateMKLDNNMemDesc(mkldnn::memory::format format);
#endif

/*!
Expand Down
39 changes: 24 additions & 15 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,28 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {

#if MXNET_USE_MKLDNN == 1

NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data)
NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd)
: storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
auto mem_pd = mkldnn_mem->get_primitive_desc();
auto mem_desc = mem_pd.desc();
shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
dtype_ = get_mxnet_type(mem_desc.data.data_type);
auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_);
ptr_ = std::make_shared<Chunk>(data, 0);
ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
ptr_->CheckAndAlloc(mem_pd.get_size());
ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mem_pd, ptr_->shandle.dptr);
ptr_->static_data = static_data;
}

NDArray::NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem)
: storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
auto mem_pd = mkldnn_mem->get_primitive_desc();
auto mem_desc = mem_pd.desc();
shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
dtype_ = get_mxnet_type(mem_desc.data.data_type);
ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
ptr_->shandle.dptr = mkldnn_mem->get_data_handle();
ptr_->shandle.size = mem_pd.get_size();
ptr_->delay_alloc = false;
ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mkldnn_mem);
ptr_->static_data = true;
}

NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
Expand Down Expand Up @@ -717,19 +729,16 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
return ptr_->mkl_mem_->GetRaw();
}

void NDArray::UpdateMKLDNNMemDesc() {
void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) {
const mkldnn::memory *mem = GetMKLDNNData();
auto mem_desc = mem->get_primitive_desc().desc();
auto this_dtype = get_mkldnn_type(dtype());
if (this_dtype != mem_desc.data.data_type) {
mkldnn::memory::desc data_md(
mkldnn::memory::dims(mem_desc.data.dims,
mem_desc.data.dims + mem_desc.data.ndims),
this_dtype, static_cast<mkldnn::memory::format>(mem_desc.data.format));
mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
}
mkldnn::memory::desc data_md(
mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
this_dtype, format);
mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
}
#endif

Expand Down
12 changes: 9 additions & 3 deletions src/operator/subgraph/mkldnn/mkldnn_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
}
if (!inplace_) {
auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
const_cast<NDArray &>(outputs[kOut]).CopyFrom(*in_mkl_mem);
output = NDArray(outputs[kOut].GetMKLDNNData());
auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
mkldnn_mem_ptr tmp_mem(
new mkldnn::memory(in_mkl_mem->get_primitive_desc(), out_mkl_mem->get_data_handle()));
MKLDNNStream::Get()->RegisterMem(tmp_mem);
mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get());
output = NDArray(tmp_mem);
}
}

Expand Down Expand Up @@ -388,7 +392,9 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,

if (mkldnn_param.with_sum) {
auto out = const_cast<NDArray &>(outputs[kOut]);
out.UpdateMKLDNNMemDesc();
auto format = static_cast<mkldnn::memory::format>(
fwd_->fwd_pd.dst_primitive_desc().desc().data.format);
out.UpdateMKLDNNMemDesc(format);
}
}

Expand Down

0 comments on commit f1a00cd

Please sign in to comment.