-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add reshape operator #3949
Add reshape operator #3949
Conversation
paddle/operators/reshape_op.cc
Outdated
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto *in = ctx.Input<framework::Tensor>("X"); | ||
auto shape = ctx.Attr<std::vector<int>>("shape"); | ||
PADDLE_ENFORCE_EQ((unsigned)shape.size(), in->dims().size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unsigned long
please. the size returned by ddim is ssize_t, which is 8byte width.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed this unreasonable line.
paddle/operators/reshape_op.h
Outdated
namespace paddle { | ||
namespace operators { | ||
|
||
using Tensor = framework::Tensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do not use using
in the header file, according to the google style.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
|
||
class ReshapeGradOpTest(GradientChecker): | ||
def test_normal(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
now can use the function setUp
to generate forward operator.
please take a look at this case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/reshape_op.h
Outdated
out->mutable_data<T>(ctx.GetPlace()); | ||
|
||
auto shape = ctx.Attr<std::vector<int>>("shape"); | ||
std::vector<int64_t> tmp; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please do not mix use of int64_t
and size_t
, which is ugly for users who read our code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refine this PR by following the review comments. Please continue to review.
paddle/operators/reshape_op.cc
Outdated
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto *in = ctx.Input<framework::Tensor>("X"); | ||
auto shape = ctx.Attr<std::vector<int>>("shape"); | ||
PADDLE_ENFORCE_EQ((unsigned)shape.size(), in->dims().size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed this unreasonable line.
paddle/operators/reshape_op.h
Outdated
namespace paddle { | ||
namespace operators { | ||
|
||
using Tensor = framework::Tensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/reshape_op.h
Outdated
out->mutable_data<T>(ctx.GetPlace()); | ||
|
||
auto shape = ctx.Attr<std::vector<int>>("shape"); | ||
std::vector<int64_t> tmp; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
|
||
class ReshapeGradOpTest(GradientChecker): | ||
def test_normal(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems LGTM.
Maybe @qingqing01 have more comment.
paddle/operators/reshape_op.cc
Outdated
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto *in = ctx.Input<framework::Tensor>("X"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to check non empty for Input('X'), like : /~https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/squared_l2_distance_op.cc#L26
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/reshape_op.cc
Outdated
} else { | ||
capacity *= dim; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int64_t capacity = 1;
for (auto dim : shape) {
PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive.");
capacity *= dim;
}
or use std::accumulate
:
int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/reshape_op.cc
Outdated
int64_t in_size = framework::product(in->dims()); | ||
PADDLE_ENFORCE_EQ(capacity, in_size, | ||
"The size of Input(X) mismatches with Attr(shape)."); | ||
ctx.Output<framework::Tensor>("Out")->Resize(in->dims()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not Resize(shape)
? The dims of output is not shape
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified
paddle/operators/reshape_op.cc
Outdated
AddComment(R"DOC(Reshape operator | ||
|
||
Reshape Input(X) into the shape specified by Attr(shape). | ||
)DOC"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need more comments : /~https://github.com/PaddlePaddle/Paddle/pull/3765/files#diff-6e4a3431c20e09cd11b1b56451b5754eR52
doc要求: #3885 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
: OperatorWithKernel(type, inputs, outputs, attrs) {} | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to check nonempty for the inputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated, please continue to review @qingqing01
: OperatorWithKernel(type, inputs, outputs, attrs) {} | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/reshape_op.cc
Outdated
int64_t in_size = framework::product(in->dims()); | ||
PADDLE_ENFORCE_EQ(capacity, in_size, | ||
"The size of Input(X) mismatches with Attr(shape)."); | ||
ctx.Output<framework::Tensor>("Out")->Resize(in->dims()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified
paddle/operators/reshape_op.cc
Outdated
} else { | ||
capacity *= dim; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/reshape_op.cc
Outdated
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto *in = ctx.Input<framework::Tensor>("X"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
paddle/operators/reshape_op.cc
Outdated
AddComment(R"DOC(Reshape operator | ||
|
||
Reshape Input(X) into the shape specified by Attr(shape). | ||
)DOC"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Resolve #4009