
add batch norm test #13625

Merged: 7 commits, Dec 13, 2018
Changes from all commits
151 changes: 150 additions & 1 deletion tests/cpp/operator/mkldnn_operator_test.cc
@@ -347,6 +347,31 @@ OpAttrs GetDeconvBackwardOp(int kernel, int num_filters, int dim, int stride, in
  return attrs;
}

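// BatchNorm forward takes five inputs (data, gamma, beta, moving mean,
// moving var) and produces three outputs (out, saved mean, saved var),
// hence num_inputs = 5 and num_outputs = 3 below.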
OpAttrs GetBNOp() {
  OpAttrs attrs;
  attrs.attrs.op = Op::Get("BatchNorm");
  attrs.num_inputs = 5;
  attrs.num_outputs = 3;
  attrs.accept_dims.insert(4);
  attrs.requests.insert(OpReqType::kWriteTo);
  attrs.attrs.op->attr_parser(&attrs.attrs);
  attrs.input_types = ArrayTypes::Normal |
                      ArrayTypes::MKLDNN;
  attrs.output_types = ArrayTypes::Normal |
                       ArrayTypes::MKLDNN;
  return attrs;
}

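// The backward op consumes eight inputs (the output gradient, the saved
// mean/var from the forward pass, and the five forward inputs) and writes
// three gradients (w.r.t. data, gamma, and beta); the exact wiring is in
// TestOpExBNBackward below.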
OpAttrs GetBNBackwardOp() {
  OpAttrs attrs;
  attrs.attrs.op = Op::Get("_backward_BatchNorm");
  attrs.num_inputs = 8;
  attrs.num_outputs = 3;
  attrs.attrs.op->attr_parser(&attrs.attrs);
  attrs.requests.insert(OpReqType::kWriteTo);
  return attrs;
}

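// Element-wise comparison helper: checks that the FCompute and FComputeEx
// results match within the given relative/absolute tolerances.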
void AssertEqual(const std::vector<NDArray *> &in_arrs,
                 const std::vector<NDArray *> &out_arrs,
                 float rtol = 1e-5, float atol = 1e-8) {
@@ -710,7 +735,7 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {

    // If the array is a view, we shouldn't write data to it.
    if (in_arr.arr.IsView())
      continue;

    NDArrayAttrs orig(in_arr.arr.Copy(in_arr.arr.ctx()), "InPlace Copy");
    for (int i = 0; i < forward_attrs.num_inputs; i++)
@@ -735,6 +760,124 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
  }
}


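// Feeds identical inputs through the plain FCompute kernel and the MKLDNN
// FComputeEx kernel of _backward_BatchNorm, then verifies the two sets of
// gradients agree via AssertEqual.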
void TestOpExBNBackward(const OpAttrs &forward_attrs,
                        const OpAttrs &backwards_attrs,
                        const OpReqType &req,
                        const std::vector<NDArray*> &inputs,
                        const std::vector<NDArray*> &outputs,
                        const NDArrayAttrs &in_arr,
                        NDArrayAttrs* out_arr) {
  std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);

  // Start empty and reserve up front: the loop below stores raw pointers into
  // these buffers, and an emplace_back that reallocates would invalidate them.
  std::vector<NDArray> backwards_buffer;
  std::vector<NDArray> backwards_buffer2;
  backwards_buffer.reserve(backwards_attrs.num_outputs);
  backwards_buffer2.reserve(backwards_attrs.num_outputs);

  std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
  std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
  std::vector<OpReqType> backwards_req(backwards_attrs.num_outputs);

  if (req == kWriteTo) {
    backwards_input[0] = &(out_arr->arr);  // output grad
    backwards_input[1] = outputs[1];       // mean
    backwards_input[2] = outputs[2];       // var
    backwards_input[3] = inputs[0];        // data
    backwards_input[4] = inputs[1];        // gamma
    backwards_input[5] = inputs[2];        // beta
    backwards_input[6] = inputs[3];        // moving mean
    backwards_input[7] = inputs[4];        // moving var

    for (size_t i = 0; i < backwards_attrs.num_outputs; i++) {
      auto tmp_output = in_arr.arr;
      backwards_buffer.emplace_back(tmp_output.Copy(Context()));
      backwards_buffer2.emplace_back(tmp_output.Copy(Context()));
      backwards_outputs[i] = &backwards_buffer.back();
      backwards_ex_outputs[i] = &backwards_buffer2.back();
      Engine::Get()->WaitForAll();
      backwards_req[i] = kWriteTo;
    }

    std::cout << "Backwards: ";
    PrintVerifyMsg(*out_arr, in_arr);
    Imperative::Get()->InvokeOp(
        Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
        backwards_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
    Imperative::Get()->InvokeOp(
        Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
        backwards_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
    Engine::Get()->WaitForAll();
    AssertEqual(backwards_outputs, backwards_ex_outputs);
  }
}

// Compares the outputs of the default FCompute path and the MKLDNN
// FComputeEx path for BatchNorm, forward and (optionally) backward.
void TestOpExBN(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
  std::vector<NDArray*> inputs2(forward_attrs.num_inputs);
  std::vector<NDArray> inputs_buffer(forward_attrs.num_inputs);
  std::vector<NDArray> inputs2_buffer(forward_attrs.num_inputs);
  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
  std::vector<OpReqType> req(forward_attrs.num_outputs);

  TestArrayShapes tas = GetTestArrayShapes();
  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, false);
  std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
  std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);

  if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) {
    for (size_t i1 = 0; i1 < in_arrs.size(); i1++) {
      auto in_arr = in_arrs[i1];

      // Only ranks listed in accept_dims are tested (4D for BatchNorm).
      CHECK_NE(forward_attrs.accept_dims.size(), 0);
      if (forward_attrs.accept_dims.find(in_arr.arr.shape().ndim()) ==
          forward_attrs.accept_dims.end())
        continue;
      for (int i = 0; i < forward_attrs.num_outputs; i++) {
        out_arrs[i] =
            GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
        ex_out_arrs[i] =
            GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
      }
      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
        inputs_buffer.clear();
        inputs2_buffer.clear();

        for (int i = 0; i < forward_attrs.num_inputs; i++) {
          inputs_buffer.emplace_back(in_arr.arr.Copy(Context()));
          inputs2_buffer.emplace_back(in_arr.arr.Copy(Context()));
          Engine::Get()->WaitForAll();
          inputs[i] = &inputs_buffer.back();
          inputs2[i] = &inputs2_buffer.back();
        }
        for (int i = 0; i < forward_attrs.num_outputs; i++) {
          req[i] = kWriteTo;
          outputs[i] = &out_arrs[i][output_i].arr;
          ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
        }
        Imperative::Get()->set_is_training(true);

        PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
        Imperative::Get()->InvokeOp(
            Context(), forward_attrs.attrs, inputs, outputs, req,
            DispatchMode::kFCompute, mxnet::OpStatePtr());
        Imperative::Get()->InvokeOp(
            Context(), forward_attrs.attrs, inputs2, ex_outputs, req,
            DispatchMode::kFComputeEx, mxnet::OpStatePtr());
        Engine::Get()->WaitForAll();
        AssertEqual(outputs, ex_outputs);

        if (!backwards_attrs.requests.empty()) {
          TestOpExBNBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo,
                             inputs, outputs, in_arr, &out_arrs[0][output_i]);
        }
      }
    }
  }
}

// Computes second dimension of FC weight matrix based on input shape
uint32_t GetFCWeightDim2(const nnvm::TShape arr) {
  uint32_t dim = 1;
@@ -1204,4 +1347,10 @@ TEST(IMPERATIVE, DeconvOp) {
  }
}

TEST(IMPERATIVE, BNOp) {
  OpAttrs forward_attrs = GetBNOp();
  OpAttrs backwards_attrs = GetBNBackwardOp();
  TestOpExBN(forward_attrs, backwards_attrs);
}

#endif
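
For anyone reproducing this locally: the new case is a regular Google Test registration, so once the C++ unit tests are built with MKLDNN enabled it can be run in isolation with a gtest filter. A plausible invocation, assuming the usual unit-test binary (the exact name and path depend on the local build layout and are not part of this PR):

    mxnet_unit_tests --gtest_filter=IMPERATIVE.BNOp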