This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-498] Test MKLDNN backward operators #11232

Merged: 39 commits, merged Jun 20, 2018
Changes from all commits
Commits (39)
cc1bf0a
add act backwards test
azai91 Jun 11, 2018
92990a9
use only verifyfn template
azai91 Jun 11, 2018
1e0488b
fix param name
azai91 Jun 11, 2018
9d7f30b
update number of inputs
azai91 Jun 11, 2018
1ea1f11
fix assertion for act backwards
azai91 Jun 11, 2018
54766c3
limit rand num range
azai91 Jun 11, 2018
96f19aa
change to assert
azai91 Jun 11, 2018
3b1d194
wait to read on correct vector
azai91 Jun 11, 2018
ea80b00
add writeinplace test
azai91 Jun 11, 2018
b72e475
fix params
azai91 Jun 11, 2018
83b6f45
add copy backwards test
azai91 Jun 11, 2018
fe23881
add missing fixture
azai91 Jun 11, 2018
28a3409
fix lint
azai91 Jun 11, 2018
4832ab7
add sum backwards verify
azai91 Jun 11, 2018
6a61ac2
use correct num of inputs for sum backwards
azai91 Jun 11, 2018
7d4b9b3
switch input / output
azai91 Jun 12, 2018
9e71415
wait for both outputs
azai91 Jun 12, 2018
c3c8a96
limit input/output
azai91 Jun 15, 2018
dc17fa2
limit input/outputs for relu/sum
azai91 Jun 15, 2018
da13928
fix var source
azai91 Jun 15, 2018
49d432a
reorder backwards if view
azai91 Jun 15, 2018
804c7de
add another entry to reqs in ttest
azai91 Jun 15, 2018
33f25f9
uncomment write in place sumbackwards
azai91 Jun 15, 2018
2b44d94
refactor testunary and testbinary into testop
azai91 Jun 15, 2018
4e42d4b
remove special testbackwardsop and use testop
azai91 Jun 15, 2018
77dc89c
fill reqs vector with num of outputs
azai91 Jun 15, 2018
b2d73f8
change req size to num outputs
azai91 Jun 18, 2018
76db6a6
create multiple output ndarrays
azai91 Jun 18, 2018
528e515
wait for all outputs
azai91 Jun 18, 2018
2a54003
remove unused comments
azai91 Jun 18, 2018
b977341
remove redundant VerifyCopyResult method
azai91 Jun 18, 2018
5a1d899
remove redundant VerifySumResult
azai91 Jun 18, 2018
923d9d1
remove unused var
azai91 Jun 18, 2018
09164ac
use only InitDefaultArray
azai91 Jun 18, 2018
831b7d0
move MKLDNNSum near copy test
azai91 Jun 18, 2018
dbe80b0
use fallback compute for backwards sum
azai91 Jun 18, 2018
b009a74
fix verifydefmem test
azai91 Jun 18, 2018
bfc729e
fix lint
azai91 Jun 18, 2018
6085cce
move MKLDNNSum test back to bottom
azai91 Jun 20, 2018
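
The commits above converge on a single test driver plus per-operator verification callbacks, replacing the special-cased VerifyCopyResult and VerifySumResult helpers. As a rough illustration only (the callback type, the helper name, and the exact signatures are assumptions for this sketch, not necessarily the PR's code), a backward check for elementwise add could look like this:

#include <gtest/gtest.h>
#include <mxnet/ndarray.h>
#include <functional>
#include <vector>

// Assumed callback shape: inputs and outputs of the op under test.
using VerifyFunc = std::function<void(const std::vector<mxnet::NDArray *> &in,
                                      const std::vector<mxnet::NDArray *> &out)>;

// Backward of z = x + y: both output gradients must equal the single
// incoming gradient in[0].
void VerifySumBackwardsResult(const std::vector<mxnet::NDArray *> &in,
                              const std::vector<mxnet::NDArray *> &out) {
  const float *og = in[0]->data().dptr<float>();   // incoming gradient dz
  const float *dx = out[0]->data().dptr<float>();  // gradient w.r.t. x
  const float *dy = out[1]->data().dptr<float>();  // gradient w.r.t. y
  for (size_t i = 0; i < in[0]->shape().Size(); ++i) {
    ASSERT_EQ(dx[i], og[i]);
    ASSERT_EQ(dy[i], og[i]);
  }
}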
16 changes: 12 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -184,14 +184,22 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
     return;
   }
 
+  NDArray out_buffer = out_grad;
+  if (out_grad.IsView() && out_grad.IsMKLDNNData())
+    out_buffer = out_grad.Reorder2Default();
+
+  NDArray in_buffer = in_data;
+  if (in_data.IsView() && in_data.IsMKLDNNData())
+    in_buffer = in_data.Reorder2Default();
+
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]);
-  auto diff_dst_memory = out_grad.GetMKLDNNData();
-  auto input_mem = in_data.GetMKLDNNData();
+  auto diff_dst_memory = out_buffer.GetMKLDNNData();
+  auto input_mem = in_buffer.GetMKLDNNData();
   // We need to make sure the two inputs to eltwise_backward has the same memory
   // descriptor. Otherwise, the perf will suffer.
   if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
-    input_mem = in_data.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
+    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
   mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
   mkldnn::memory::desc data_md = data_mpd.desc();
   mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc();
@@ -201,7 +209,7 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   auto alg = GetMKLDNNActAlgo(param);
   mkldnn_output_t diff_src_memory;
 
-  MSHADOW_REAL_TYPE_SWITCH(in_data.dtype(), DType, {
+  MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, {
     DType alpha = 0;
     mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
                                           alg, data_md, alpha);
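
The fix above guards both backward inputs: if an NDArray is a view over MKL-DNN-formatted memory, its buffer cannot be consumed directly, so it is first reordered back to the default layout. A minimal sketch of that guard, factored into a helper (the helper name is an assumption for illustration; the NDArray methods are the ones used in the diff):

#include <mxnet/ndarray.h>

// Return an array whose buffer MKL-DNN primitives can read directly.
// A view of MKL-DNN-formatted data is converted back to the default
// (plain) layout first; everything else passes through unchanged.
static mxnet::NDArray EnsureMKLDNNReadable(const mxnet::NDArray &arr) {
  if (arr.IsView() && arr.IsMKLDNNData())
    return arr.Reorder2Default();
  return arr;
}

// Usage mirroring the patch:
//   NDArray out_buffer = EnsureMKLDNNReadable(out_grad);
//   NDArray in_buffer  = EnsureMKLDNNReadable(in_data);
//   auto diff_dst_memory = out_buffer.GetMKLDNNData();
//   auto input_mem       = in_buffer.GetMKLDNNData();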
5 changes: 5 additions & 0 deletions src/operator/tensor/elemwise_binary_op_basic.cc
@@ -111,6 +111,11 @@ static void _backward_ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
     MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]);
     MKLDNNCopy(attrs, ctx, inputs[0], req[1], outputs[1]);
     return;
+  } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
+    FallBackCompute(
+        ElemwiseBinaryOp::BackwardUseNone<cpu, mshadow_op::identity, mshadow_op::identity>,
+        attrs, ctx, inputs, req, outputs);
+    return;
   }
 #endif
   ElemwiseBinaryOp::BackwardUseNoneEx<cpu, mshadow_op::identity, mshadow_op::identity>(
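
For reference, the dispatch this hunk produces when MKL-DNN is enabled, written out as a sketch (the function body follows the MXNet operator convention and the identifiers visible in the diff; treat it as an illustration, not the verbatim file):

static void BackwardElemwiseAddExSketch(const nnvm::NodeAttrs& attrs,
                                        const OpContext& ctx,
                                        const std::vector<NDArray>& inputs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<NDArray>& outputs) {
#if MXNET_USE_MKLDNN == 1
  if (inputs[0].IsMKLDNNData()) {
    // For z = x + y, dL/dx = dL/dy = dL/dz, so the incoming gradient is
    // simply copied to each output gradient on the MKL-DNN path.
    MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]);
    MKLDNNCopy(attrs, ctx, inputs[0], req[1], outputs[1]);
    return;
  } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
    // Added in this PR: plain dense inputs skip the MKL-DNN path and run
    // the stock CPU identity kernel instead.
    FallBackCompute(
        ElemwiseBinaryOp::BackwardUseNone<cpu, mshadow_op::identity, mshadow_op::identity>,
        attrs, ctx, inputs, req, outputs);
    return;
  }
#endif
  // Everything else (e.g. sparse storage) takes the generic ex path.
  ElemwiseBinaryOp::BackwardUseNoneEx<cpu, mshadow_op::identity, mshadow_op::identity>(
      attrs, ctx, inputs, req, outputs);
}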