MKLDNN RNN Inference Integration(fp32 LSTM and vRNN with tanh and relu) #14713
@lihaofd We need to upgrade MKL-DNN in a separate PR, but you can use this PR for CI testing.
FYI @anirudh2290 @szha we are starting the MKLDNN RNN integration :)
Great to hear! Looking forward to this
sync code to latest
auto concat_pd = concat::primitive_desc(dst_desc, concat_dimension, srcs_pd);
MKLDNNStream::Get()->RegisterPrim(concat(concat_pd, inputs, dst));
MKLDNNStream::Get()->Submit();
}
Can we leverage the concat implementation in mkldnn_concat.cc? Do you think the concat primitive here needs to be cached?
There are too many data segments with different sizes, dims, counts, concat_dim values, etc. Caching them would make the mkldnn cache much more complicated without much benefit to performance.
mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H};  // ldsnc

std::vector<float> weights_scales(ngates * H);
if (!cached) {
What's cached? How is it cached?
On the first call, it prepares the data (for example, concatenating and reordering wx and wh) for multi-layer unidirectional or bidirectional networks and saves the results in mkldnn cached memory. Subsequent calls use the cached data directly.
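The prepare-once, reuse-afterwards pattern described in this reply can be sketched as follows. This is a minimal illustration, not the PR's code: `PreparedWeights`, `GetPackedWeights`, and the `prepare_calls` counter are hypothetical names, plain `std::vector<float>` stands in for MKL-DNN memory, and the sketch assumes inference-only use where the weights do not change between calls.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the prepared (concatenated) weight buffer.
struct PreparedWeights {
  std::vector<float> packed;  // wx followed by wh
  bool initialized = false;   // set after the first preparation
  int prepare_calls = 0;      // how many times preparation actually ran
};

// Concatenates wx and wh on the first call only; later calls return the
// buffer prepared earlier. Assumes the weights are constant (inference).
const std::vector<float>& GetPackedWeights(PreparedWeights* cache,
                                           const std::vector<float>& wx,
                                           const std::vector<float>& wh) {
  if (!cache->initialized) {
    cache->packed.clear();
    cache->packed.reserve(wx.size() + wh.size());
    cache->packed.insert(cache->packed.end(), wx.begin(), wx.end());
    cache->packed.insert(cache->packed.end(), wh.begin(), wh.end());
    cache->initialized = true;
    ++cache->prepare_calls;
  }
  return cache->packed;
}
```

The trade-off is that the flag must be reset if the weights ever change, which is one reason a sketch like this only fits the inference path.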
sync to latest code
fix min max on zero-sized ndarray (apache#14745)
sync to lastest code
sync code to latest
sync to latest code
…to mkldnn_lstm_infer_fp32
…to mkldnn_lstm_infer_fp32
@TaoLv please take a look and review.
The PR is almost done and we're waiting for the local test.
src/operator/rnn-inl.h
Outdated
hy_ptr,
cy_ptr,
param_.mode);
#if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__)
Was this __CUDACC__ check left in by mistake, or is it intentional?
Yes, I think it can be removed as well. @zixuanweeei
It has been modified. Thanks.
The integration looks good to me in general. We can revisit the GRU integration and the training part in follow-up PRs.
cy_ptr,
param_.mode);
#if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__)
if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) {
@szha Please review. We added a new environment variable here. When it is set to 0, the RNN operator falls back to the original CPU implementation. Otherwise, the MKL-DNN RNN primitive is invoked.
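The dispatch rule in the diff line above (the flag defaults to 1, and GRU always takes the original path) can be sketched without any MXNet dependencies. This is an illustration under stated assumptions: `RnnMode`, `ParseIntFlag`, and `UseMkldnnRnn` are hypothetical names, and `std::atoi` only approximates how `dmlc::GetEnv` parses the value.

```cpp
#include <cassert>
#include <cstdlib>

// Hypothetical mirror of the RNN modes referenced in the PR.
enum class RnnMode { kRnnRelu, kRnnTanh, kLstm, kGru };

// Returns default_value when the variable is unset or empty, otherwise
// its integer value (std::atoi is an approximation of dmlc::GetEnv).
int ParseIntFlag(const char* raw, int default_value) {
  if (raw == nullptr || *raw == '\0') return default_value;
  return std::atoi(raw);
}

// True when the MKL-DNN path should be taken: the flag is non-zero
// (default 1) and the mode is not GRU.
bool UseMkldnnRnn(const char* raw_flag, RnnMode mode) {
  return ParseIntFlag(raw_flag, 1) != 0 && mode != RnnMode::kGru;
}
```

A caller would pass `std::getenv("MXNET_USE_MKLDNN_RNN")` as `raw_flag`, so running with `MXNET_USE_MKLDNN_RNN=0` selects the original CPU implementation.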
In general, it looks good, even though it is not perfect yet.
Our team will continue working to improve the RNN interface :)
We also tested the GPU performance of this PR against master. The relative DIFF is calculated as (Our_PR - MASTER) / MASTER. In summary, our modifications cause no significant performance regression on GPU. Results were reported for the following configurations (the result tables are not preserved here):
Layer = 1, bidirectional = False
Layer = 1, bidirectional = True
Layer = 5, bidirectional = False
Layer = 5, bidirectional = True
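The relative DIFF formula quoted in this comment, (Our_PR - MASTER) / MASTER, can be written as a small helper. `RelativeDiff` is a hypothetical name, and the numbers in the usage note are made up for illustration; they are not results from this PR.

```cpp
#include <cassert>
#include <cmath>

// Relative difference of the PR's measurement against master:
// (pr_value - master_value) / master_value.
// Returns NaN when the baseline is zero, since the ratio is undefined.
double RelativeDiff(double pr_value, double master_value) {
  if (master_value == 0.0) return std::nan("");
  return (pr_value - master_value) / master_value;
}
```

For example, a hypothetical PR measurement of 105 against a master measurement of 100 gives a relative DIFF of 0.05, that is, 5%. Whether a positive value is good or bad depends on whether the metric is a throughput or a latency.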
I think it prudent to resolve the GPU platform issues with rnn-inl.h introduced by commit 1c49e40 before finally accepting this PR [see issue #15034]. Besides introducing failures of test_rnntanh_bidirectional on P40 GPUs, I have noticed that the codebase no longer compiles against cuDNN versions < 7.0. I intend to submit a PR to resolve both of these issues within 24 hours, probably less.
The PR is ready to be merged. @szha @DickJC123 do we need to wait for #15056?
Merging this one first since #15056 is still a work in progress.
…u) (apache#14713)
* trigger the ci
* integrate mkldnn rnn fp32 inference(LSTM and vRNN with tanh and relu)
* fix bug about comparison between signed and unsigned integer expressions
* fix unix-gpu issue
* fix unix gpu bug
* fix unix-gpu issues
* fix some comments
* fix issue
* fix comment
* rename `cached` to `initialized`
* support IType
* TODO for MKLDNN GRU
* fix bugs in memory adjustment
* Reformat TODO for MKLDNN GRU
* Reserve original RNN path
* Remove MKLDNN GRU
* Fix bug for rnn forward
* Remove `__CUDAACC__`
* Move `RNNStatefulComputeCPU` to rnn.cc
* Remove redundent macro of `__CUDACC__`
* Remove the last macro `__CUDACC__` from rnn*
Description
This PR integrates MKLDNN RNN inference (fp32 LSTM and vRNN with tanh and relu).
@pengzhao-intel, @TaoLv , @ciyongch
Feature changes
New features
Unit-test changes
Performance
We tested the performance of FusedRNN (USE_MKLDNN = 0 and 1) on a local Skylake-8180 machine with 1 socket and 28 cores, using MKL as the BLAS library.
The test input size follows the DS2 default parameters (seq_length = 300, batch_size = 20, input_size = 800, hidden_size = 800), with MKLDNN at commit 57e1203092f63941475ec4088ccd3cf609ed9d7a.
Layer=1 bidirectional = False
Layer=5 bidirectional = True
Checklist