-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add SQRT/LAST/FIRST strategy for Seqpool #4788
Conversation
paddle/operators/sequence_pool_op.cc
Outdated
AddOutput( | ||
"Out", | ||
"A float LoDTensor, the variable-length output of SequencePoolOp."); | ||
AddInput("X", "A LoDTensor, the variable-length input of SequencePoolOp"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AddInput("X",
"(LoDTensor), the variable-length input of SequencePoolOp");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/sequence_pool_op.cc
Outdated
"A float LoDTensor, the variable-length output of SequencePoolOp."); | ||
AddInput("X", "A LoDTensor, the variable-length input of SequencePoolOp"); | ||
AddOutput("Out", | ||
"A LoDTensor, the variable-length output of SequencePoolOp."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AddOutput("Out",
"(Tensor), the output of SequencePoolOp is a Tensor, which does not contain LoD infomation.");
这里支持双层吗? 输出还变长吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done。由于目前还没写到双层,这里先修改成Tensor。等以后加入双层特征后,再对应修改注释。
paddle/operators/sequence_pool_op.h
Outdated
@@ -98,6 +109,10 @@ class SequencePoolGradKernel : public framework::OpKernel<T> { | |||
int64_t w = in->numel() / dims[0]; | |||
|
|||
in_g->mutable_data<T>(context.GetPlace()); | |||
if (strategy > 2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if ( strategy == LAST || strategy == FIRST || strategy == MAX) {
// ..
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done。Max的时候,也不需要先置0.(处理方式见下一个PR)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need more detailed doc.
AVERAGE, SUM, SQRT, MAX, LAST, FIRST
Except for the simple example, these arguments should also be explained.
I will add more explanation in next PR. |
solve part2 of #4186
MAX strategy will be in next PR.