-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split operator with CPU kernel #4046
Changes from all commits
d372565
288eb8a
594b75d
4c9b5ee
fd8cac0
1f1ba22
f5777e1
0b2a22c
4897e89
1963d0a
1483dcf
c26671d
54b7ef8
b6f9034
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/split_op.h" | ||
#include "paddle/operators/net_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
using framework::Tensor; | ||
|
||
class SplitOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
// infershape | ||
auto *in = ctx.Input<framework::Tensor>("X"); | ||
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out"); | ||
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis")); | ||
size_t num = static_cast<size_t>(ctx.Attr<int>("num")); | ||
std::vector<int> sections = | ||
static_cast<std::vector<int>>(ctx.Attr<std::vector<int>>("sections")); | ||
const size_t n = outs.size(); | ||
|
||
if (num > 0) { | ||
int64_t in_axis_dim = in->dims()[axis]; | ||
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0, | ||
"tensor split does not result" | ||
" in an equal division"); | ||
size_t out_axis_dim = in_axis_dim / num; | ||
for (size_t i = 0; i < n; ++i) { | ||
auto dim = in->dims(); | ||
dim[axis] = out_axis_dim; | ||
outs[i]->Resize(dim); | ||
} | ||
} else if (sections.size() > 0) { | ||
PADDLE_ENFORCE_EQ(sections.size(), n, | ||
"tensor split sections size" | ||
"should be equal to output size."); | ||
for (size_t i = 0; i < n; ++i) { | ||
auto dim = in->dims(); | ||
dim[axis] = sections[i]; | ||
outs[i]->Resize(dim); | ||
} | ||
} else { | ||
PADDLE_ENFORCE_NOT_NULL(nullptr, "split operator should", | ||
" specify indices or sections."); | ||
} | ||
} | ||
}; | ||
|
||
class SplitOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the inputs, outputs, attributes and documentation of the
  // split operator for the op registry.
  SplitOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "the input tensor of split operator.");
    // "Out" is duplicable: the op produces a variable number of outputs.
    AddOutput("Out", "the output tensors of split operator.").AsDuplicable();
    AddComment(R"DOC(
      Split the input tensor into multiple sub-tensors.
      Example:
        Input = [[1,2],
                 [3,4],
                 [5,6]]
        sections = [2,1]
        axis = 0
        Output[0] = [[1,2],
                     [3,4]]
        Output[1] = [[5,6]]
    )DOC");
    // Fixed broken doc strings: "for eachoutput along with the specify axis"
    // and the "splited" typo.
    AddAttr<std::vector<int>>("sections",
                              "the length of each "
                              "output along the specified axis.")
        .SetDefault(std::vector<int>{});
    AddAttr<int>("num",
                 "number of the sub-tensors, it must evenly divide "
                 "Input.dims()[axis]")
        .SetDefault(0);
    AddAttr<int>("axis", "The axis which the input will be split on.")
        .SetDefault(0);
  }
};
|
||
class SplitOpGrad : public NetOp { | ||
public: | ||
SplitOpGrad(const std::string &type, const framework::VariableNameMap &inputs, | ||
const framework::VariableNameMap &outputs, | ||
const framework::AttributeMap &attrs) | ||
: NetOp(type, inputs, outputs, attrs) { | ||
auto out_grad = Inputs(framework::GradVarName("Out")); | ||
auto x_grad = Output(framework::GradVarName("X")); | ||
AppendOp(framework::OpRegistry::CreateOp("concat", {{"X", out_grad}}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we call the kernel directly instead of using netOP? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe it's OK, but it's better we have a common solution in #4099 , and I will fix this in another PR. |
||
{{"Out", {x_grad}}}, attrs)); | ||
CompleteAddOp(false); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators;
// The gradient network built by SplitOpGrad uses the CPU-only "concat" op,
// so pull its registration into this translation unit.
USE_CPU_ONLY_OP(concat);
// Register the forward op, its maker, and the gradient net-op.
REGISTER_OP(split, ops::SplitOp, ops::SplitOpMaker, split_grad,
            ops::SplitOpGrad);
// CPU kernel is registered for float only.
REGISTER_OP_CPU_KERNEL(split,
                       ops::SplitKernel<paddle::platform::CPUPlace, float>);
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename Place, typename T> | ||
class SplitKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
auto* in = ctx.Input<framework::Tensor>("X"); | ||
auto outs = ctx.MultiOutput<framework::Tensor>("Out"); | ||
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); | ||
size_t before = 1, after = 1; | ||
const size_t n = outs.size(); | ||
size_t input_axis_dim = in->dims()[axis]; | ||
|
||
for (int64_t i = 0; i < in->dims().size(); ++i) { | ||
if (i == axis) { | ||
continue; | ||
} | ||
if (i < axis) { | ||
before *= in->dims()[i]; | ||
} else { | ||
after *= in->dims()[i]; | ||
} | ||
} | ||
size_t input_offset = 0; | ||
for (size_t i = 0; i < n; i++) { | ||
auto& out = outs[i]; | ||
size_t axis_dim = out->dims()[axis]; | ||
for (size_t j = 0; j < before; j++) { | ||
size_t len = axis_dim * after * sizeof(T); | ||
T* dest = | ||
out->mutable_data<T>(platform::CPUPlace()) + axis_dim * after * j; | ||
const T* src = | ||
in->data<T>() + input_offset + input_axis_dim * after * j; | ||
memcpy(dest, src, len); | ||
} | ||
input_offset += axis_dim * after; | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import unittest | ||
import numpy as np | ||
from op_test import OpTest | ||
|
||
|
||
class TestSplitOp(OpTest):
    """Checks the forward output and gradient of the split operator.

    Splits a (4, 2) float32 tensor into 2 equal pieces along axis 0 and
    compares against numpy.split.
    """

    def setUp(self):
        self.op_type = "split"
        axis = 0
        num = 2
        x = np.random.random((4, 2)).astype('float32')
        out = np.split(x, num, axis)
        self.inputs = {'X': x}
        self.attrs = {'axis': axis, 'num': num}
        # enumerate instead of the Python-2-only xrange; works on 2 and 3.
        self.outputs = {
            'Out': [('out%d' % i, o) for i, o in enumerate(out)]
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Gradient flows from both named outputs back to X.
        self.check_grad(['X'], ['out0', 'out1'])


if __name__ == '__main__':
    unittest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The op maker should also include a description of the "num" attribute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.