Forward implementation for LSTM operator. #4929
Conversation
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
//
struct SeqInfo {
Move struct SeqInfo out of this function.
Done by @reyoung
// input LodTensor. It is also the maximum length of input sequence.

paddle::framework::LoD batch_lods;
batch_lods.push_back(std::vector<size_t>{0});
I am not sure, but does push_back of a std::vector<size_t> work on the GPU?
Maybe emplace_back is better.
On the GPU, both push_back and emplace_back work fine, and both can be assigned directly to a thrust::host_vector; the unit tests show no problem.
paddle::framework::LoD batch_lods;
batch_lods.push_back(std::vector<size_t>{0});
batch_lods.push_back(std::vector<size_t>{0});
Same as above
Kudos for the detailed LSTM comments.
paddle/operators/CMakeLists.txt (Outdated)
@@ -126,6 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute math_function)
No need to add math_function here; that dependency is already added on line 12.
Done.
paddle/operators/lstm_op.cc (Outdated)
AddInput("Input", | ||
"(LoDTensor) the first input is a LodTensor, which support " | ||
"variable-time length input sequence. The underlying tensor in " | ||
"this LoDTenosr is a matrix with shape (T X 4D), where, T is the " |
No comma is needed after "where".
Done.
paddle/operators/lstm_op.cc (Outdated)
"batch size. `H0` and `C0` can be NULL but only at the same time"); | ||
AddInput("Weight", | ||
"(Tensor) the learnable hidden-hidden weights." | ||
" - The shape is (D x 4*D), where D is the hidden size. " |
Change 4*D to 4D; line 89 uses 4D, same below. Or use the 4*D format everywhere.
Done.
paddle/operators/lstm_op.cc (Outdated)
AddInput("Bias", | ||
"(Tensor) the learnable weights, which contains two parts: " | ||
"input-hidden bias weight and peephole connections weight if " | ||
"seting `usePeepholes` True. " |
seting → setting
Done.
paddle/operators/lstm_op.cc (Outdated)
" - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."); | ||
AddOutput("BatchGate", | ||
"(LoDTensor) This LoDTensor contains input gate, forget gate " | ||
"and output gate aftern the nonlinear computation. This " |
aftern → after, a typo.
Done.
paddle/operators/lstm_op.h (Outdated)
lstm_value.checkOg = lstm_value.checkFg + frame_size;
lstm_value.prevStateValue = nullptr;

framework::LoDTensor batch_out;
With using LoDTensor = framework::LoDTensor; writing LoDTensor directly here would be cleaner.
Could lines 79, 81, and 83 be written as one line:
LoDTensor batch_out, batch_cell, batch_cell_pre_act
Done.
T rState;
T rPrevState = 0;
T rStateAtv;
T rOut;
Could lines 33-43 be merged somewhat?
T rValueIn, rValueIg, rValueFg, rValueOg;
T rCheckI, rCheckF, rCheckO;
T rState, rPrevState = 0, rStateAtv;
T rOut;
Good suggestion. The code style will be revised later; we'll change this along with that. Thanks!
T rCheckO;
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
Same as above: can these be grouped and merged? Same below.
Same as above.
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
How does line 51 come about? Could the comment be more detailed?
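As an illustration of how that example could be derived (a self-contained sketch, not the Paddle code): each seq_info entry is (start offset, length, original sequence index) read from the LoD offsets {0, 4, 9, 12}, then sorted by length in descending order.

#include <algorithm>
#include <cstdio>
#include <tuple>
#include <vector>

int main() {
  // LoD offsets for s0 (len 4), s1 (len 5), s2 (len 3).
  std::vector<size_t> lod = {0, 4, 9, 12};
  // Each entry: (start offset, length, original sequence index).
  std::vector<std::tuple<size_t, size_t, size_t>> seq_info;
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    seq_info.emplace_back(lod[i], lod[i + 1] - lod[i], i);
  }
  // Sort by length, longest sequence first.
  std::sort(seq_info.begin(), seq_info.end(),
            [](const std::tuple<size_t, size_t, size_t>& a,
               const std::tuple<size_t, size_t, size_t>& b) {
              return std::get<1>(a) > std::get<1>(b);
            });
  for (const auto& s : seq_info) {
    std::printf("(%zu, %zu, %zu) ", std::get<0>(s), std::get<1>(s),
                std::get<2>(s));
  }
  // Prints: (4, 5, 1) (0, 4, 0) (9, 3, 2)
  return 0;
}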
if (!is_reverse) {
  seq2batch_idx[batch_id] = start + n;
} else {
  seq2batch_idx[batch_id] = start + seq_len - 1 - n;
seq2batch_idx[batch_id] = is_reverse ? start + seq_len - 1 - n : start + n;
Done.
Please note #4952 has been merged
math::LstmUnitFunctor<Place, T>::compute(ctx.device_context(), lstm_value,
                                         frame_size, cur_batch_size,
                                         gate_act, cell_act, cand_act);
lstm_value.prevStateValue = lstm_value.stateValue;
Is this intentional? Doesn't lstm_value.prevStateValue always equal lstm_value.stateValue, since lstm_value.stateValue is never modified by the compute function? Moving this line before line 117 would have the same effect.
#ifndef __NVCC__

template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
This is pass-by-value, so the caller's value is never modified by this function.
Right, but value holds T* pointers, so the data they point to does change. I'll pass value itself by reference in the next PR.
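A minimal sketch of that distinction, with hypothetical names rather than the Paddle types: copying the struct copies the pointers, so writes through them are visible to the caller, while reassigning a pointer member is not.

#include <cstdio>

struct MetaValue {
  float* stateValue;
};

void compute(MetaValue value) {   // pass by value: 'value' is a copy
  value.stateValue[0] = 42.0f;    // visible to the caller (shared buffer)
  value.stateValue = nullptr;     // NOT visible: only the copy changes
}

int main() {
  float state[1] = {0.0f};
  MetaValue v{state};
  compute(v);
  // Prints "42.000000 1": the data changed, the pointer did not.
  std::printf("%f %d\n", state[0], v.stateValue != nullptr);
  return 0;
}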
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
It looks like avx should also be false unless T is float. Is the double type untested in the unit tests?
The unit test uses the double type. Where the AVX functions are called there is a type check, std::is_same<T, float>::value:
if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
  avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, ...)
}
HL_ACTIVATION_RELU = 1,
HL_ACTIVATION_TANH = 2,
HL_ACTIVATION_LINEAR = 3,
HL_ACTIVATION_END
HL_ACTIVATION_END is not needed. Besides, even if it were needed, NUM_OF_ACTIVATIONS would be a better name.
Thanks! Will change in a follow-up PR.
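If a sentinel were kept at all, the suggestion amounts to something like this sketch (the enum name and the explicit = 4 are illustrative, not from the source):

enum ActivationType {
  HL_ACTIVATION_RELU = 1,
  HL_ACTIVATION_TANH = 2,
  HL_ACTIVATION_LINEAR = 3,
  NUM_OF_ACTIVATIONS = 4  // sentinel named for what it counts, not a real mode
};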
Please review qingqing01#3 for several enhancements
Basically LGTM. However, the develop branch should be merged.
Do not let me block merging this PR; please have others approve it if I am not online after the develop branch is merged.
Several Enhancements
@reyoung Thanks very much for your enhancements!
LGTM
fix #4629
fix #4675
The detailed computation kernels are under math/detail. A SequenceToBatch functor is used to reorganize and sort the input sequence; I'll try to use TensorArray in the future.
Directory structure: