-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add reshape operator #3949
Add reshape operator #3949
Changes from 9 commits
12eaa22
899c7d6
02da0d1
dd64349
477d92b
7ae72f7
31cbb34
dd92649
0289a00
5915138
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
|
||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/reshape_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class ReshapeOp : public framework::OperatorWithKernel { | ||
public: | ||
ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, | ||
const framework::VariableNameMap &outputs, | ||
const framework::AttributeMap &attrs) | ||
: OperatorWithKernel(type, inputs, outputs, attrs) {} | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
// input check | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null"); | ||
auto shape = ctx.Attr<std::vector<int>>("shape"); | ||
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty."); | ||
for (auto dim : shape) { | ||
PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive."); | ||
} | ||
// capacity check | ||
int64_t capacity = | ||
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); | ||
auto *in = ctx.Input<framework::Tensor>("X"); | ||
int64_t in_size = framework::product(in->dims()); | ||
PADDLE_ENFORCE_EQ(capacity, in_size, | ||
"The size of Input(X) mismatches with Attr(shape)."); | ||
// resize output | ||
std::vector<int64_t> shape_int64(shape.size(), 0); | ||
std::transform(shape.begin(), shape.end(), shape_int64.begin(), | ||
[](int a) { return static_cast<int64_t>(a); }); | ||
auto out_dims = framework::make_ddim(shape_int64); | ||
ctx.Output<framework::Tensor>("Out")->Resize(out_dims); | ||
} | ||
}; | ||
|
||
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
ReshapeOpMaker(framework::OpProto *proto, | ||
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "The input tensor of reshape operator."); | ||
AddOutput("Out", "The output tensor of reshape operator."); | ||
AddAttr<std::vector<int>>("shape", "Target shape of reshape operator."); | ||
AddComment(R"DOC(Reshape operator | ||
|
||
Reshape Input(X) into the shape specified by Attr(shape). | ||
|
||
An example: | ||
Given a 2-D tensor X with 2 rows and 2 columns | ||
|
||
[[1, 2], [3, 4]] | ||
|
||
with target shape = [1, 4], the reshape operator will tansform | ||
the tensor X into a 1-D tensor: | ||
|
||
[1, 2, 3, 4] | ||
|
||
)DOC"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need more comments : /~https://github.com/PaddlePaddle/Paddle/pull/3765/files#diff-6e4a3431c20e09cd11b1b56451b5754eR52 doc要求: #3885 (comment) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
} | ||
}; | ||
|
||
class ReshapeGradOp : public framework::OperatorWithKernel { | ||
public: | ||
ReshapeGradOp(const std::string &type, | ||
const framework::VariableNameMap &inputs, | ||
const framework::VariableNameMap &outputs, | ||
const framework::AttributeMap &attrs) | ||
: OperatorWithKernel(type, inputs, outputs, attrs) {} | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to check nonempty for the inputs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null."); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), | ||
"Input(Out@GRAD) shouldn't be null."); | ||
auto dims = ctx.Input<framework::Tensor>("X")->dims(); | ||
auto *d_in = ctx.Output<framework::Tensor>(framework::GradVarName("X")); | ||
d_in->Resize(dims); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OP(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, reshape_grad, | ||
ops::ReshapeGradOp); | ||
REGISTER_OP_CPU_KERNEL(reshape, | ||
ops::ReshapeKernel<paddle::platform::CPUPlace, float>); | ||
REGISTER_OP_CPU_KERNEL( | ||
reshape_grad, ops::ReshapeGradKernel<paddle::platform::CPUPlace, float>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/reshape_op.h" | ||
|
||
REGISTER_OP_GPU_KERNEL( | ||
reshape, | ||
paddle::operators::ReshapeKernel<paddle::platform::GPUPlace, float>); | ||
REGISTER_OP_GPU_KERNEL( | ||
reshape_grad, | ||
paddle::operators::ReshapeGradKernel<paddle::platform::GPUPlace, float>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
|
||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename Place, typename T> | ||
class ReshapeKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const { | ||
auto* out = ctx.Output<framework::Tensor>("Out"); | ||
auto* in = ctx.Input<framework::Tensor>("X"); | ||
out->mutable_data<T>(ctx.GetPlace()); | ||
|
||
auto shape = ctx.Attr<std::vector<int>>("shape"); | ||
std::vector<int64_t> shape_int64(shape.size(), 0); | ||
std::transform(shape.begin(), shape.end(), shape_int64.begin(), | ||
[](int a) { return static_cast<int64_t>(a); }); | ||
auto out_dims = framework::make_ddim(shape_int64); | ||
out->CopyFrom<T>(*in, ctx.GetPlace()); | ||
out->Resize(out_dims); | ||
} | ||
}; | ||
|
||
template <typename Place, typename T> | ||
class ReshapeGradKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const { | ||
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); | ||
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X")); | ||
d_x->mutable_data<T>(ctx.GetPlace()); | ||
|
||
auto in_dims = d_x->dims(); | ||
d_x->CopyFrom<T>(*d_out, ctx.GetPlace()); | ||
d_x->Resize(in_dims); | ||
} | ||
}; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import unittest | ||
import numpy as np | ||
from op_test import OpTest | ||
|
||
|
||
class TestReshapeOp(OpTest): | ||
def setUp(self): | ||
self.op_type = "reshape" | ||
self.inputs = {'X': np.random.random((10, 20)).astype("float32")} | ||
self.attrs = {'shape': [10 * 20]} | ||
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
def test_check_grad(self): | ||
self.check_grad(["X"], "Out") | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or use
std::accumulate
:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done