-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add reshape operator #3949
Add reshape operator #3949
Changes from 2 commits
12eaa22
899c7d6
02da0d1
dd64349
477d92b
7ae72f7
31cbb34
dd92649
0289a00
5915138
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
|
||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/reshape_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class ReshapeOp : public framework::OperatorWithKernel { | ||
public: | ||
ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, | ||
const framework::VariableNameMap &outputs, | ||
const framework::AttributeMap &attrs) | ||
: OperatorWithKernel(type, inputs, outputs, attrs) {} | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto *in = ctx.Input<framework::Tensor>("X"); | ||
auto shape = ctx.Attr<std::vector<int>>("shape"); | ||
PADDLE_ENFORCE_EQ((unsigned)shape.size(), in->dims().size(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed this unreasonable line. |
||
"The dimension of Input(X) mismatches with Attr(shape)."); | ||
size_t shape_size = 1; | ||
for (auto dim : shape) { | ||
shape_size *= dim; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. int64_t capacity = 1;
for (auto dim : shape) {
PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive.");
capacity *= dim;
} or use int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
size_t in_size = framework::product(in->dims()); | ||
PADDLE_ENFORCE_EQ(shape_size, in_size, | ||
"The size of Input(X) mismatches with Attr(shape)."); | ||
ctx.Output<framework::Tensor>("Out")->Resize(in->dims()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modified |
||
} | ||
}; | ||
|
||
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
ReshapeOpMaker(framework::OpProto *proto, | ||
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "The input tensor of reshape operator."); | ||
AddOutput("Out", "The output tensor of reshape operator."); | ||
AddAttr<std::vector<int>>("shape", "Target shape of reshape operator."); | ||
AddComment(R"DOC(Reshape operator | ||
|
||
Reshape Input(X) into the shape specified by Attr(shape). | ||
)DOC"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need more comments : /~https://github.com/PaddlePaddle/Paddle/pull/3765/files#diff-6e4a3431c20e09cd11b1b56451b5754eR52 doc要求: #3885 (comment) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
} | ||
}; | ||
|
||
class ReshapeGradOp : public framework::OperatorWithKernel { | ||
public: | ||
ReshapeGradOp(const std::string &type, | ||
const framework::VariableNameMap &inputs, | ||
const framework::VariableNameMap &outputs, | ||
const framework::AttributeMap &attrs) | ||
: OperatorWithKernel(type, inputs, outputs, attrs) {} | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to check nonempty for the inputs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
auto dims = ctx.Input<framework::Tensor>("X")->dims(); | ||
auto *d_in = ctx.Output<framework::Tensor>(framework::GradVarName("X")); | ||
d_in->Resize(dims); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OP(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, reshape_grad, | ||
ops::ReshapeGradOp); | ||
REGISTER_OP_CPU_KERNEL(reshape, | ||
ops::ReshapeKernel<paddle::platform::CPUPlace, float>); | ||
REGISTER_OP_CPU_KERNEL( | ||
reshape_grad, ops::ReshapeGradKernel<paddle::platform::CPUPlace, float>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/reshape_op.h" | ||
|
||
REGISTER_OP_GPU_KERNEL( | ||
reshape, | ||
paddle::operators::ReshapeKernel<paddle::platform::GPUPlace, float>); | ||
REGISTER_OP_GPU_KERNEL( | ||
reshape_grad, | ||
paddle::operators::ReshapeGradKernel<paddle::platform::GPUPlace, float>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
|
||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using Tensor = framework::Tensor; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do not use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
template <typename Place, typename T> | ||
class ReshapeKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const { | ||
auto* out = ctx.Output<Tensor>("Out"); | ||
auto* in = ctx.Input<Tensor>("X"); | ||
out->mutable_data<T>(ctx.GetPlace()); | ||
|
||
auto shape = ctx.Attr<std::vector<int>>("shape"); | ||
std::vector<int64_t> tmp; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please do not mix use of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
for (auto dim : shape) { | ||
tmp.push_back(dim); | ||
} | ||
auto out_dims = framework::make_ddim(tmp); | ||
out->CopyFrom<T>(*in, ctx.GetPlace()); | ||
out->Resize(out_dims); | ||
} | ||
}; | ||
|
||
template <typename Place, typename T> | ||
class ReshapeGradKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const { | ||
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
d_x->mutable_data<T>(ctx.GetPlace()); | ||
|
||
auto in_dims = d_x->dims(); | ||
d_x->CopyFrom<T>(*d_out, ctx.GetPlace()); | ||
d_x->Resize(in_dims); | ||
} | ||
}; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import unittest | ||
import numpy as np | ||
from gradient_checker import GradientChecker, Operator | ||
from op_test_util import OpTestMeta | ||
|
||
|
||
class TestReshapeOp(unittest.TestCase): | ||
__metaclass__ = OpTestMeta | ||
|
||
def setUp(self): | ||
self.type = "reshape" | ||
self.inputs = {'X': np.random.random((37, 51)).astype("float32"), } | ||
self.attrs = {'shape': [51, 37]} | ||
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} | ||
|
||
|
||
class ReshapeGradOpTest(GradientChecker): | ||
def test_normal(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. now can use the function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
op = Operator("reshape", X='X', Out='Out', shape=[5, 40]) | ||
inputs = {"X": np.random.random((10, 20)).astype("float32")} | ||
self.check_grad(op, inputs, set("X"), "Out") | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to check non empty for Input('X'), like : /~https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/squared_l2_distance_op.cc#L26
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done