Skip to content

Commit

Permalink
Integrate MKLDNN Conv1d and support 3d layout (apache#13530)
Browse files Browse the repository at this point in the history
* add 3d layout support for MKLDNN Conv and Activation

* fix lint

* code refactor

* add testcase for group1 conv and skip quantization for conv1d

* fix lint

* avoid conv1d quantization

* code refactor and add activation ut

* del todo
  • Loading branch information
xinyu-intel authored and haohuw committed Jun 23, 2019
1 parent 980ea17 commit f6e8b24
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 120 deletions.
10 changes: 2 additions & 8 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,24 +453,18 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {

mkldnn::memory::dims dims;
// These are shapes supprted by MKLDNN.
if (shape.ndim() == 1 || shape.ndim() == 2 || shape.ndim() == 4
|| shape.ndim() == 5) {
if (shape.ndim() >= 1 && shape.ndim() <= 5) {
dims.resize(shape.ndim());
for (size_t i = 0; i < dims.size(); i++)
dims[i] = shape[i];
} else if (shape.ndim() == 3) {
// If there are 3 dimensions, we'll force it to 4 dimensions.
dims.resize(shape.ndim() + 1);
dims[0] = 1;
for (size_t i = 0; i < shape.ndim(); i++)
dims[i + 1] = shape[i];
} else {
LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions";
}
mkldnn::memory::format layout = mkldnn::memory::format::format_undef;
switch (dims.size()) {
case 1: layout = mkldnn::memory::format::x; break;
case 2: layout = mkldnn::memory::format::nc; break;
case 3: layout = mkldnn::memory::format::ncw; break;
case 4: layout = mkldnn::memory::format::nchw; break;
// This isn't the right layout when the data has 5 dimensions in MXNet.
// MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
Expand Down
5 changes: 3 additions & 2 deletions src/operator/nn/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
if (SupportMKLDNN(inputs[0])) {
if (SupportMKLDNNAct(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]);
MKLDNN_OPCHECK_RUN(ActivationCompute<cpu>, attrs, ctx, inputs, req, outputs);
Expand All @@ -115,7 +116,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type));
if (SupportMKLDNN(inputs[0])) {
if (SupportMKLDNNAct(param, inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
const bool relu = param.act_type == activation::kReLU;
Expand Down
9 changes: 9 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_act.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
|| param.act_type == activation::kTanh;
}

bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
// MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout
if ((input.shape().ndim() < 1) ||
(input.shape().ndim() > 4) ||
(input.dtype() != mshadow::kFloat32))
return false;
return SupportMKLDNNAct(param);
}

static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
switch (param.act_type) {
case activation::kReLU:
Expand Down
24 changes: 17 additions & 7 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,11 @@ struct ConvolutionParam;
struct DeconvolutionParam;
struct SoftmaxParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
}
} // namespace op

static int GetTypeSize(int dtype) {
int size = -1;
Expand Down Expand Up @@ -250,15 +251,24 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr) {

inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
int num_groups) {
auto ndim = arr.shape().ndim();
mkldnn::memory::dims tz = mkldnn::memory::dims{0};
if (num_groups == 1) {
return GetMemDesc(arr);
} else {
CHECK_EQ(arr.shape().ndim(), 4U);
mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
static_cast<int>(arr.shape()[0] / num_groups),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]),
static_cast<int>(arr.shape()[3])};
CHECK((ndim == 3) || (ndim == 4))
<< "MKL-DNN weight currectly supports 3d and 4d layout";
const int N = 0, H = 2, W = 3, C = 1;
if (ndim == 3) {
tz = mkldnn::memory::dims{
num_groups, static_cast<int>(arr.shape()[N] / num_groups),
static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H])};
} else {
tz = mkldnn::memory::dims{
num_groups, static_cast<int>(arr.shape()[N] / num_groups),
static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H]),
static_cast<int>(arr.shape()[W])};
}
return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
mkldnn::memory::format::any};
}
Expand Down
99 changes: 73 additions & 26 deletions src/operator/nn/mkldnn/mkldnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,39 +239,49 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
return mem;

mkldnn::memory::data_type type = get_mkldnn_type(arr.dtype());
mkldnn::memory::dims tz = mkldnn::memory::dims{0};
mkldnn::memory::format format = mkldnn::memory::format::format_undef;
auto engine = CpuEngine::Get()->get_engine();
const int O = 0, I = 1, H = 2, W = 3;
if (arr.shape().ndim() == 2) {
mkldnn::memory::dims tz = mkldnn::memory::dims{
static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1])};
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, mkldnn::memory::format::oi};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
} else if (arr.shape().ndim() == 4 && num_groups == 1) {
mkldnn::memory::dims tz = mkldnn::memory::dims{
static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]), static_cast<int>(arr.shape()[3])};
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, mkldnn::memory::format::oihw};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
tz = mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
static_cast<int>(arr.shape()[I])};
format = mkldnn::memory::format::oi;
} else if (arr.shape().ndim() == 3) {
tz = num_groups > 1
? mkldnn::memory::dims{num_groups,
static_cast<int>(arr.shape()[O] /
num_groups),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H])}
: mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H])};
format = num_groups > 1 ? mkldnn::memory::format::goiw
: mkldnn::memory::format::oiw;
} else if (arr.shape().ndim() == 4) {
mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
static_cast<int>(arr.shape()[0] / num_groups),
static_cast<int>(arr.shape()[1]),
static_cast<int>(arr.shape()[2]),
static_cast<int>(arr.shape()[3])};
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, mkldnn::memory::format::goihw};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
tz = num_groups > 1
? mkldnn::memory::dims{num_groups,
static_cast<int>(arr.shape()[O] /
num_groups),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H]),
static_cast<int>(arr.shape()[W])}
: mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H]),
static_cast<int>(arr.shape()[W])};
format = num_groups > 1 ? mkldnn::memory::format::goihw
: mkldnn::memory::format::oihw;
} else {
LOG(FATAL) << "The weight array has an unsupported number of dimensions";
return nullptr;
}
mkldnn::memory::desc md =
mkldnn::memory::desc{tz, type, format};
mkldnn::memory::primitive_desc pd =
mkldnn::memory::primitive_desc{md, engine};
mem = arr.GetMKLDNNData(pd);
if (mem == nullptr)
mem = arr.GetMKLDNNDataReorder(target_pd);
if (mem->get_primitive_desc() == target_pd) return mem;
Expand All @@ -285,6 +295,7 @@ mkldnn_memory_format_t GetDefaultFormat(int num_dims) {
switch (num_dims) {
case 1: return mkldnn_x;
case 2: return mkldnn_nc;
case 3: return mkldnn_ncw;
case 4: return mkldnn_nchw;
case 5: return mkldnn_goihw;
default:
Expand All @@ -301,6 +312,30 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
return mkldnn_oi;
else
return desc.data.format;
} else if (desc.data.ndims == 3) {
switch (desc.data.format) {
case mkldnn_ncw:
case mkldnn_nwc:
case mkldnn_nCw8c:
case mkldnn_nCw16c:
return mkldnn_ncw;
case mkldnn_oiw:
case mkldnn_wio:
case mkldnn_Owi8o:
case mkldnn_OIw8i8o:
case mkldnn_OIw8o8i:
case mkldnn_OIw16i16o:
case mkldnn_OIw16o16i:
case mkldnn_Oiw16o:
case mkldnn_Owi16o:
case mkldnn_OIw8i16o2i:
case mkldnn_OIw8o16i2o:
case mkldnn_IOw16o16i:
return mkldnn_oiw;
default:
LOG(FATAL) << "Unknown MKLDNN format for 3 dimensions: " << desc.data.format;
return mkldnn_format_undef;
}
} else if (desc.data.ndims == 4) {
switch (desc.data.format) {
case mkldnn_nchw:
Expand Down Expand Up @@ -329,6 +364,18 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
case mkldnn_Ohwi16o:
case mkldnn_OhIw16o4i:
return mkldnn_oihw;
case mkldnn_goiw:
case mkldnn_gOwi8o:
case mkldnn_gOIw8o8i:
case mkldnn_gOIw8i8o:
case mkldnn_gOIw16i16o:
case mkldnn_gOIw16o16i:
case mkldnn_gOiw16o:
case mkldnn_gOwi16o:
case mkldnn_gOIw8i16o2i:
case mkldnn_gOIw8o16i2o:
case mkldnn_gIOw16o16i:
return mkldnn_goiw;
default:
LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
return mkldnn_format_undef;
Expand Down
Loading

0 comments on commit f6e8b24

Please sign in to comment.