-
Notifications
You must be signed in to change notification settings - Fork 6.8k
add batch norm test #13625
add batch norm test #13625
Conversation
@TaoLv please review |
@@ -735,6 +760,128 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { | |||
} | |||
} | |||
|
|||
|
|||
void TestOpExBNBackward(const OpAttrs &forward_attrs, | |||
const OpAttrs &backwards_attrs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fix indent.
const std::vector<NDArray*> &inputs, | ||
const std::vector<NDArray*> &outputs, | ||
const NDArrayAttrs &in_arr, | ||
const NDArrayAttrs &out_arr) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out_arr is not used?
|
||
std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs); | ||
std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs); | ||
std::vector<OpReqType> back_req(backwards_attrs.num_outputs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
backwards_req
backwards_input[6] = inputs[3]; // moving mean | ||
backwards_input[7] = inputs[4]; // moving var | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
redundant blank line
Engine::Get()->WaitForAll(); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
} | ||
|
||
|
||
for (int i = 0; i < backwards_attrs.num_outputs; i++) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
merge this for loop into the previous one?
Seems CI doesn't work. Can we expect these test cases to work without those operator changes in #13084? |
Yes, the test cases for BN is (for now) less stringent than the other C++ test cases. Normally, we run through a large set of fixtures (regular NDArray, mkldnn NDarray, reshaped, etc). However, in the BN implementation we are unable to handle reshaped mkldnn array. For example, we cannot call GetMKLDNNData on views (reshaped) arrays /~https://github.com/apache/incubator-mxnet/blob/3d64d15e69ce6afba728a92b18753a868b6c3298/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h#L198 Normally, we call Reorder2Default in this case to produce a new array that isn't a view and then continue on. (/~https://github.com/apache/incubator-mxnet/blob/3d64d15e69ce6afba728a92b18753a868b6c3298/src/operator/nn/mkldnn/mkldnn_act.cc#L164). However, BN is special because we write to the inputs (/~https://github.com/apache/incubator-mxnet/blob/3d64d15e69ce6afba728a92b18753a868b6c3298/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h#L260), so making a copy of inputs, then writing to those copies will not have the desired effect. For now, all this PR does is provide test cases for our current implementations. |
* upstream/master: (38 commits) Feature/mkldnn static (apache#13628) Fix the bug of BidirectionalCell (apache#13575) Set install path for libmxnet.so dynamic lib on Mac OS (apache#13629) add batch norm test (apache#13625) Scripts for building dependency libraries of MXNet (apache#13282) fix quantize pass error when the quantization supported Op are excluded in the model (apache#13596) Optimize C++ API (apache#13496) Fix warning in waitall doc (apache#13618) [MXNET-1225] Always use config.mk in make install instructions (apache#13364) [MXNET-1224]: improve scala maven jni build and packing. (apache#13493) [MXNET-1155] Add scala packageTest utility (apache#13046) fix the Float not showing correctly problem (apache#13617) apache#13385 [Clojure] - Turn examples into integration tests (apache#13554) Add Intel MKL blas to Jenkins (apache#13607) Revert "[MXNET-1198] MXNet Java API (apache#13162)" Reducing the length of setup tutorial (apache#13306) [MXNET-1182] Predictor example (apache#13237) [MXNET-1187] Added Java SSD Inference Tutorial for website (apache#13201) add defaults and clean up the tests (apache#13295) [MXNET-1181] Added command line alternative to IntelliJ in install instructions (apache#13267) ...
Description
Add tests for MKLDNN batch norm operator The creates ndarray inputs and then runs it separately through the native CPU BN operator and MKLDNN BN operator. Creating the BN test requires a different test helper as the BN operation writes to the input, so identical arrays of equal value must be created and then ran through the native / MKLDNN operators.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments