Split operator with CPU kernel #4046
Conversation
@@ -185,10 +185,9 @@ def check_output_with_place(self, place):
        for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
            if out_dup:
                sub_out = self.outputs[out_name]
-               for sub_out_name in sub_out:
+               for sub_out_name, expect in sub_out:
for sub_out_name, expect in sub_out.iteritems():
sub_out is a list of tuples, so maybe no iteritems here.
You can add some type-checking code, so that if something goes wrong, the Python traceback message will be more readable.
Done.
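For reference, a minimal sketch of the kind of check being suggested; the names and shapes below are illustrative, not the actual test-framework code:

import numpy as np

# sub_out is a list of (name, expected_array) tuples, not a dict,
# so dict methods such as iteritems() do not apply here.
sub_out = [("out0", np.zeros((1, 5))), ("out1", np.zeros((3, 5)))]

# Failing early with an explicit message gives a clearer Python
# traceback than an unpacking error deep inside the loop.
if not isinstance(sub_out, list):
    raise AssertionError("sub_out type %s is not list" % type(sub_out))

for sub_out_name, expect in sub_out:
    print(sub_out_name, expect.shape)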
self.op_type = "split"
axis = 0
indices = 2
sections = [1, 3]
Seems sections is not used.
Done.
paddle/operators/split_op.cc
Outdated
size_t num = static_cast<size_t>(ctx.Attr<int>("num"));
std::vector<int> sections =
    static_cast<std::vector<int>>(ctx.Attr<std::vector<int>>("sections"));
size_t n = outs.size();
const size_t n
Done.
AddInput("X", "the input tensor of split operator."); | ||
AddOutput("Out", "the output tensors of split operator.").AsDuplicable(); | ||
AddComment(R"DOC( | ||
Split the input tensor into multiple sub-tensors. |
Should also have some description of the attribute "num".
Done.
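For intuition about what those attribute docs should say, the intended semantics can be mirrored with NumPy; a sketch assuming "num" means an even split and "sections" gives explicit per-output sizes, as the kernel code in this PR suggests:

import numpy as np

x = np.arange(12).reshape(4, 3)
axis = 0

# "num": split the input evenly into num sub-tensors along `axis`.
num = 2
even_parts = np.split(x, num, axis=axis)             # two tensors of shape (2, 3)

# "sections": split into pieces of the given sizes along `axis`.
sections = [1, 3]
split_points = np.cumsum(sections)[:-1]              # split point: [1]
uneven_parts = np.split(x, split_points, axis=axis)  # shapes (1, 3) and (3, 3)

assert sum(sections) == x.shape[axis]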
paddle/operators/split_op.h
Outdated
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
size_t before = 1, after = 1;
size_t n = outs.size();
const size_t n
Done.
paddle/operators/split_op.h
Outdated
size_t n = outs.size();
size_t input_axis_dim = 0;
for (size_t i = 0; i < n; i++) {
  input_axis_dim += outs[i]->dims()[axis];
Why not use in->dims()[axis]?
Sure! It's my mistake.
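To see why the input's own dimension is the natural source of truth here, a small NumPy sketch of the same consistency check (illustrative only, not the kernel code):

import numpy as np

x = np.random.rand(4, 3)   # the input being split
axis = 0
sections = [1, 3]          # output sizes along `axis`

# Reading the extent from the input directly (the in->dims()[axis]
# suggestion) avoids recomputing it by summing over every output,
# and the output sizes must add up to it in any case.
input_axis_dim = x.shape[axis]
assert sum(sections) == input_axis_dim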
: NetOp(type, inputs, outputs, attrs) {
  auto out_grad = Inputs(framework::GradVarName("Out"));
  auto x_grad = Output(framework::GradVarName("X"));
  AppendOp(framework::OpRegistry::CreateOp("concat", {{"X", out_grad}},
Can we call the kernel directly instead of using NetOp?
Maybe it's OK, but it's better that we have a common solution in #4099, and I will fix this in another PR.
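For reference, the gradient wiring above reduces to concatenating the output gradients along the split axis; in NumPy terms (a sketch of the math, not the framework code):

import numpy as np

axis = 0
# Gradients arriving at each of the split outputs...
out_grads = [np.ones((1, 3)), np.ones((3, 3))]

# ...are concatenated along the split axis to recover the gradient of
# the input, which is what composing a "concat" op inside the grad
# NetOp accomplishes.
x_grad = np.concatenate(out_grads, axis=axis)
assert x_grad.shape == (4, 3)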
LGTM++
Related issue #3929