Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix reshape to add in-place back #14903

Merged
merged 13 commits into from
May 14, 2019
6 changes: 6 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const OpReqType &req,
const NDArray &output);

/*!
 * \brief Forward computation of the reshape operator on the MKL-DNN path.
 *        Declared here; defined in mkldnn_reshape.cc.
 */
void MKLDNNReshapeForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &data,
const OpReqType &req,
const NDArray &output);

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
Expand Down
194 changes: 194 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_reshape.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_reshape.cc
* \brief Implement reshape operator via MKL-DNN reorder primitive
* \author Tao Lv
*/

#if MXNET_USE_MKLDNN == 1

#include <mkldnn.hpp>
#include "../../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {

/*!
 * \brief Tell whether reshape may take the MKL-DNN path for this input.
 *
 * The answer is true only when the input array has at most 4 dimensions, holds
 * float32 data, and the requested `shape` parameter has at most 4 dimensions.
 *
 * \param param parsed reshape parameters (only `shape` is inspected)
 * \param data  input array
 * \return true when all three conditions hold, false otherwise
 */
bool SupportMKLDNNReshape(const ReshapeParam &param,
                          const NDArray &data) {
  const bool input_rank_ok = !(data.shape().ndim() > 4);
  const bool dtype_ok = (data.dtype() == mshadow::kFloat32);
  const bool target_rank_ok = !(param.shape.ndim() > 4);
  return input_rank_ok && dtype_ok && target_rank_ok;
}

// Cache key type for reshape: an operator signature parameterized on ReshapeParam.
using MKLDNNReshapeSignature = ParamOpSign<ReshapeParam>;

/*!
 * \brief Reshape forward pass expressed as a list of MKL-DNN reorder primitives.
 *
 * The constructor creates three mkldnn::memory objects with null data handles and
 * records which reorders have to run for the given request type and input layout.
 * Execute() later binds the real buffers to those memory objects and submits the
 * recorded primitives.
 */
class MKLDNNReshapeForward {
// Source memory, created from the input's own primitive descriptor.
std::shared_ptr<mkldnn::memory> data_;
// Destination memory; shares temp_'s primitive descriptor (default format, input shape).
std::shared_ptr<mkldnn::memory> out_;
// Intermediate memory in the default format; its handle is set from the workspace pointer.
std::shared_ptr<mkldnn::memory> temp_;
// Reorder primitives registered, in insertion order, by every Execute() call.
std::vector<mkldnn::primitive> prims_;

// When true, Execute() calls InvalidateMKLDNNData() on the input after submitting.
bool needInvalidateInput = false;

public:
/*!
 * \brief Build the memory descriptors and the reorder list.
 * \param param  reshape parameters (not read inside this constructor)
 * \param req    write request; only kWriteInplace and kWriteTo are accepted
 * \param input  input array, used for its MKL-DNN descriptor, shape and layout flag
 * \param output output array (not read inside this constructor)
 */
MKLDNNReshapeForward(const ReshapeParam &param,
const OpReqType &req,
const NDArray &input,
const NDArray &output) {
auto engine = CpuEngine::Get()->get_engine();

// data_: same primitive descriptor as the input, no buffer attached yet
auto in_mem = input.GetMKLDNNData();
auto in_pd = in_mem->get_primitive_desc();
data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);

// temp_: input shape and data type, but in the default format for that descriptor
auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);

// destination: described exactly like temp_
out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);

if (req == kWriteInplace) {
// If the input has an MKL-DNN internal layout, we need to reorder it into a temporary buffer
// with the default layout and then copy from the temporary buffer back to the output buffer,
// which has the same address as the input buffer.
// If the input already has the default layout, nothing needs to be done: prims_ stays empty.
if (input.IsMKLDNNData()) {
prims_.push_back(mkldnn::reorder(*data_, *temp_)); // reorder to default
prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy back
needInvalidateInput = true;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the situation of else?

} else if (req == kWriteTo) {
if (input.IsMKLDNNData()) {
prims_.push_back(mkldnn::reorder(*data_, *temp_)); // reorder to default
prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy to the output buffer
needInvalidateInput = false;
} else {
prims_.push_back(mkldnn::reorder(*data_, *out_)); // copy directly from input to output
needInvalidateInput = false;
}
} else {
// any other request type is rejected
LOG(FATAL) << "not supported req type: " << req;
}
}

// Size reported by temp_'s primitive descriptor, or 0 if temp_ is unset.
// Note the value is returned as int.
int GetWorkspaceSize() {
return temp_ ? temp_->get_primitive_desc().get_size() : 0;
}

// Point data_, out_ and (when a workspace is given) temp_ at the buffers to use for this run.
void SetNewMem(const NDArray &input, const NDArray &output, void* workspace = nullptr) {
// input: take the handle of its MKL-DNN memory if it has one, otherwise its raw data pointer
if (input.IsMKLDNNData()) {
this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
} else {
MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
this->data_->set_data_handle(input.data().dptr<DTYPE>());
})
}

// output: same selection rule as for the input
if (output.IsMKLDNNData()) {
this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
} else {
MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
this->out_->set_data_handle(output.data().dptr<DTYPE>());
})
}

// temp_ keeps its previous handle when no workspace is passed
if (workspace) {
this->temp_->set_data_handle(workspace);
}
}

// Bind the buffers, register every recorded primitive on the MKL-DNN stream, submit the
// stream, and finally invalidate the input's MKL-DNN data if the constructor requested it.
void Execute(const NDArray &input,
const NDArray &output,
void* workspace = nullptr) {
// set memory handles
SetNewMem(input, output, workspace);
// register primitives
auto stream = MKLDNNStream::Get();
for (auto &v : this->prims_) {
stream->RegisterPrim(v);
}
stream->Submit();
// invalidate mkldnn memory in input
if (needInvalidateInput) {
const_cast<NDArray &>(input).InvalidateMKLDNNData();
}
}
};

/*!
 * \brief Look up (or create and cache) the reshape forward object for this call.
 *
 * The cache is a thread-local map keyed by a signature built from the reshape
 * parameters, the request type, the input array and the output array.
 *
 * \return reference to the entry stored in the thread-local cache
 */
static MKLDNNReshapeForward &GetReshapeForward(const ReshapeParam& param,
                                               const OpReqType &req,
                                               const NDArray &input,
                                               const NDArray &output) {
  using ReshapeFwdMap = std::unordered_map<MKLDNNReshapeSignature,
                                           MKLDNNReshapeForward, OpHash>;
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local ReshapeFwdMap cache;
#else
  static MX_THREAD_LOCAL ReshapeFwdMap cache;
#endif
  // Assemble the lookup signature in the same order every time.
  MKLDNNReshapeSignature signature(param);
  signature.AddSign(req);
  signature.AddSign(input);
  signature.AddSign(output);

  auto cached = cache.find(signature);
  if (cached != cache.end()) {
    return cached->second;
  }

  // Cache miss: construct a new forward object and store it under the signature.
  MKLDNNReshapeForward created(param, req, input, output);
  cached = AddToCache(&cache, signature, created);
  return cached->second;
}

/*!
 * \brief Run reshape through the cached MKL-DNN reorder implementation.
 *
 * Returns immediately for kNullOp and aborts via CHECK for kAddTo. Otherwise the
 * cached forward object is fetched, a temporary workspace of the size it reports
 * is taken from ctx.requested[0], and the forward object is executed.
 *
 * \param attrs  node attributes holding the parsed ReshapeParam
 * \param ctx    operator context providing the CPU stream and the temp-space resource
 * \param input  input array
 * \param req    write request type
 * \param output output array
 */
void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
                          const OpContext &ctx,
                          const NDArray &input,
                          const OpReqType &req,
                          const NDArray &output) {
  const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
  if (req == kNullOp) return;
  CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";

  // Bind by reference: GetReshapeForward returns a reference to the cache entry, and a
  // plain `auto` would drop the reference and copy the whole forward object on every call.
  auto &fwd = GetReshapeForward(param, req, input, output);
  const auto ws_size = fwd.GetWorkspaceSize();
  void* ws_ptr = nullptr;
  if (ws_size) {
    // Request ws_size chars of temporary space to back the intermediate buffer.
    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
    mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
      .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
    ws_ptr = reinterpret_cast<void*>(ws.dptr_);
  }

  fwd.Execute(input, output, ws_ptr);
}
} // namespace op
} // namespace mxnet
#endif
19 changes: 19 additions & 0 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ struct ReshapeParam : public dmlc::Parameter<ReshapeParam> {
"If set to true, then the first dim in target_shape is ignored,"
"and always fixed as input");
}

/*! \brief Equality: all of target_shape, keep_highest, shape and reverse must match. */
bool operator==(const ReshapeParam &other) const {
  if (!(this->target_shape == other.target_shape)) return false;
  if (!(this->keep_highest == other.keep_highest)) return false;
  if (!(this->shape == other.shape)) return false;
  return this->reverse == other.reverse;
}
};

template<typename IType>
Expand Down Expand Up @@ -2867,6 +2874,18 @@ struct hash<mxnet::op::TransposeParam> {
return ret;
}
};

/*!
 * \brief std::hash specialization for ReshapeParam.
 *
 * Combines target_shape, keep_highest, shape and reverse, i.e. the same fields that
 * ReshapeParam::operator== compares. The call operator is const so the hasher can be
 * invoked through a const object, as the standard Hash requirements expect.
 */
template<>
struct hash<mxnet::op::ReshapeParam> {
  size_t operator()(const mxnet::op::ReshapeParam& val) const {
    size_t ret = 0;
    ret = dmlc::HashCombine(ret, val.target_shape);
    ret = dmlc::HashCombine(ret, val.keep_highest);
    ret = dmlc::HashCombine(ret, val.shape);
    ret = dmlc::HashCombine(ret, val.reverse);
    return ret;
  }
};
} // namespace std

#endif // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_
28 changes: 3 additions & 25 deletions src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,6 @@ DMLC_REGISTER_PARAMETER(DepthToSpaceParam);
DMLC_REGISTER_PARAMETER(SplitParam);

#if MXNET_USE_MKLDNN == 1
// Reorder in_data's MKL-DNN memory into out_data's buffer using the default format,
// then invalidate the MKL-DNN data held by out_data.
void MKLDNNReshape(const NDArray &in_data, const NDArray &out_data) {
MSHADOW_TYPE_SWITCH(in_data.dtype(), DType, {
auto this_mem = in_data.GetMKLDNNData();
auto out_dptr = out_data.data().dptr<DType>();
// Describe a memory with the input's dims and data type but the default format.
mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
mkldnn::memory::desc this_desc = this_pd.desc();
mkldnn::memory::dims dims(this_desc.data.dims,
this_desc.data.dims + this_desc.data.ndims);
auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
mkldnn::memory::desc data_md(dims, this_dtype, this_format);
mkldnn::memory::primitive_desc pd(data_md, this_pd.get_engine());
// Wrap the output pointer in that description and reorder the input into it.
auto temp_mem = mkldnn::memory(pd, out_dptr);
MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*this_mem, temp_mem));
MKLDNNStream::Get()->Submit();

// Remove out_data's mkl_mem_ so that the data is stored in the default format
const_cast<NDArray &>(out_data).InvalidateMKLDNNData();
});
}

static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
Expand All @@ -137,8 +116,8 @@ static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
// If inputs are supposed to be in MKLDNN format and
// MKLDNN supports the data type or the shape, then convert
// it to the output format and shape
if (SupportMKLDNNArray(inputs[0].dtype(), inputs[0].shape()) && req[0] != kAddTo) {
MKLDNNReshape(inputs[0], outputs[0]);
if (SupportMKLDNNArray(inputs[0].dtype(), inputs[0].shape())) {
MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
return;
}
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
Expand Down Expand Up @@ -234,7 +213,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#else
#endif
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
Expand All @@ -243,7 +222,6 @@ If the argument `reverse` is set to 1, then the special values are inferred from
[](const NodeAttrs& attrs){
return std::vector<bool>{true};
})
#endif
.add_argument("data", "NDArray-or-Symbol", "Input data to reshape.")
.add_arguments(ReshapeParam::__FIELDS__());

Expand Down
41 changes: 39 additions & 2 deletions tests/python/mkl/test_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def test_reshape_after_conv(dst_shape):
res = mx.symbol.reshape(data=conv, shape=dst_shape)
exe = res.simple_bind(mx.cpu(), data=shape, grad_req='null')

val1 = np.random.uniform(-1, 1, (4, 4))
val2 = np.random.uniform(-1, 1, (1, 1, 1, 1))
val1 = np.random.uniform(-1, 1, shape)
val2 = np.random.uniform(-1, 1, (16, 1, 1, 1))
val3 = np.random.uniform(-1 ,1, (1))

exe.arg_arrays[0][:] = val1
Expand Down Expand Up @@ -489,5 +489,42 @@ def test_conv_transpose():
np.allclose(t.asnumpy(), n)


# This test case was contributed by @awsbillz in https://github.com/apache/incubator-mxnet/issues/14766
@with_seed()
def test_reshape_transpose_6d():
    """Run conv -> reshape -> reshape -> 6-D transpose -> reshape through a hybridized net.

    Besides checking that the forward pass completes, the output shape is verified
    so that the test also fails when the reshape chain produces a wrong layout.
    """
    class Reshape2D(gluon.HybridBlock):
        def __init__(self, factor):
            super(Reshape2D, self).__init__()
            self._factors = (int(factor),) * 2

        def hybrid_forward(self, F, x):
            f1, f2 = self._factors
            # (N, f1*f2*C, H, W)
            x = F.reshape(x, (0, -4, -1, f1 * f2, 0, 0))  # (N, C, f1*f2, H, W)
            x = F.reshape(x, (0, 0, -4, f1, f2, 0, 0))    # (N, C, f1, f2, H, W)
            x = F.transpose(x, (0, 1, 4, 2, 5, 3))        # (N, C, H, f1, W, f2)
            x = F.reshape(x, (0, 0, -3, -3))              # (N, C, H*f1, W*f2)
            return x

    class Net(gluon.HybridBlock):
        def __init__(self, **kwargs):
            super(Net, self).__init__(**kwargs)
            with self.name_scope():
                self.conv1 = nn.Conv2D(8, kernel_size=5)
                self.reshape2D = Reshape2D(2)

        def hybrid_forward(self, F, x):
            x = self.conv1(x)
            x = self.reshape2D(x)
            return x

    net = Net()
    net.initialize(mx.init.Xavier(), ctx=mx.cpu())
    net.hybridize()
    data = mx.nd.random_normal(shape=(1, 3, 600, 600))
    output = net(data)
    # asnumpy() blocks until the computation has actually finished.
    a = output.asnumpy()
    # conv: (1, 3, 600, 600) -> (1, 8, 596, 596); Reshape2D(2): C = 8 / 4 = 2, H = W = 596 * 2.
    assert a.shape == (1, 2, 1192, 1192)

if __name__ == '__main__':
install.test_mkldnn_install()