Refactorize framework/*.proto #3322

Merged: 45 commits, Aug 14, 2017
Changes from 38 commits
Commits (45):
e0e9a81
Update CMakeLists
wangkuiyi Aug 8, 2017
662aeed
Update operator/CMakeLists.txt
wangkuiyi Aug 8, 2017
72e3ba5
update framework.proto
wangkuiyi Aug 8, 2017
7e83011
Try make pass
reyoung Aug 8, 2017
dba618c
Make Compile Pass
reyoung Aug 8, 2017
d97a2b4
Merge pull request #3 from reyoung/feature/refactorize_framework_proto
wangkuiyi Aug 8, 2017
9544068
Resolve conflicts manually
wangkuiyi Aug 8, 2017
4a78885
Add a temporary test case otherwise there would be linking error with…
wangkuiyi Aug 8, 2017
7702ab1
Merge branch 'develop' of github.com:baidu/Paddle into feature/refact…
reyoung Aug 9, 2017
3d4ba22
Merge branch 'develop' of github.com:baidu/Paddle into feature/refact…
reyoung Aug 9, 2017
b368c6c
Rename op_proto_name/var_names -> parameter/arguments
reyoung Aug 9, 2017
5a59111
Modify rnn op unit test after refactoring framework proto.
qingqing01 Aug 9, 2017
78af6e6
Add OutputVars method to get all outputs or outputs without intermediate
reyoung Aug 9, 2017
030f430
Merge branch 'develop' of github.com:baidu/Paddle into feature/refact…
reyoung Aug 9, 2017
665e1a3
Update grad_op_builder after refactoring framework proto.
qingqing01 Aug 9, 2017
5f6e5ed
Merge pull request #7 from qingqing01/grad_op_builder
wangkuiyi Aug 9, 2017
36709d0
Merge pull request #5 from qingqing01/rnn_test_for_refactorize_framew…
wangkuiyi Aug 10, 2017
7202f42
Merge branch 'refactorize_framework_proto' into feature/refactorize_f…
qingqing01 Aug 10, 2017
c7e8c1a
Merge pull request #6 from reyoung/feature/refactorize_framework_proto
reyoung Aug 10, 2017
71acaff
Tiny fix
qingqing01 Aug 10, 2017
7fab7dd
Merge branch 'develop' of github.com:baidu/Paddle into feature/refact…
reyoung Aug 10, 2017
5ac3641
Merge pull request #8 from reyoung/feature/refactorize_framework_proto
reyoung Aug 10, 2017
0f84bb3
Fix merge error
reyoung Aug 10, 2017
0515d40
Merge pull request #9 from reyoung/feature/refactorize_framework_proto
qingqing01 Aug 10, 2017
ac5893e
Fix grad_op_builder
qingqing01 Aug 10, 2017
8810490
update code
qingqing01 Aug 10, 2017
f485815
Merge pull request #10 from qingqing01/framework_proto
wangkuiyi Aug 10, 2017
c99f84a
Fix python unit tests
reyoung Aug 11, 2017
133a8ea
Polish Error message
reyoung Aug 11, 2017
dfb4ea7
make unit test of backward_test pass.
qingqing01 Aug 11, 2017
aad49da
Merge remote-tracking branch 'wangkuiyi/refactorize_framework_proto' …
qingqing01 Aug 11, 2017
5422776
Merge pull request #12 from qingqing01/framework_proto
wangkuiyi Aug 11, 2017
96fc9e7
Merge pull request #11 from reyoung/fix_python_tests
wangkuiyi Aug 11, 2017
d6d4641
Merge branch 'develop' of github.com:baidu/Paddle into final_fixes
reyoung Aug 12, 2017
610a258
Fix all unit tests in Python
reyoung Aug 12, 2017
509d320
Fix CI and style
reyoung Aug 12, 2017
0b1052f
Get `DEFINE_OPERATOR_CTOR` Back to code
reyoung Aug 12, 2017
88a3d8d
Merge pull request #13 from reyoung/final_fixes
qingqing01 Aug 14, 2017
4a604c2
Polish Our code by YuYang's review
reyoung Aug 14, 2017
ef29b52
Simplify unit test code
reyoung Aug 14, 2017
f09cb65
Follow comments from WangYi
reyoung Aug 14, 2017
1ed5f02
Merge pull request #14 from reyoung/feature/refactorize_framework_proto
reyoung Aug 14, 2017
63b2e45
Fix CI Test
reyoung Aug 14, 2017
8371f67
Merge branch 'refactorize_framework_proto' of /~https://github.com/wang…
reyoung Aug 14, 2017
64a4dfe
Fix CI
reyoung Aug 14, 2017
16 changes: 6 additions & 10 deletions paddle/framework/CMakeLists.txt
@@ -15,23 +15,19 @@ cc_test(variable_test SRCS variable_test.cc)
 cc_library(scope SRCS scope.cc)
 cc_test(scope_test SRCS scope_test.cc DEPS scope)
 
-proto_library(attribute_proto SRCS attribute.proto)
-proto_library(op_proto SRCS op_proto.proto DEPS attribute_proto)
-proto_library(op_desc SRCS op_desc.proto DEPS attribute_proto)
-cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
-cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
+proto_library(framework_proto SRCS framework.proto)
 
-cc_library(attribute SRCS attribute.cc DEPS op_desc op_proto)
+cc_library(attribute SRCS attribute.cc DEPS framework_proto)
 
-cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope attribute)
+cc_library(operator SRCS operator.cc DEPS framework_proto device_context tensor scope attribute)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
 
-cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
-cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
+cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)
+cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
 cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
 
-py_proto_compile(framework_py_proto SRCS attribute.proto op_proto.proto op_desc.proto)
+py_proto_compile(framework_py_proto SRCS framework.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
 add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
 add_dependencies(framework_py_proto framework_py_proto_init)
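
With attribute.proto, op_proto.proto and op_desc.proto collapsed into a single framework_proto target, downstream code compiles against one generated header, framework.pb.h. The sketch below is illustrative only: the inputs/outputs/type field names and the generated add_inputs()/set_parameter()/add_arguments() accessors are assumptions inferred from the "parameter/arguments" rename commit and the call sites later in the backward.cc diff, not copied from framework.proto itself.

// Hypothetical sketch (not part of the PR): populate the consolidated OpDesc
// message through the protobuf-generated C++ API.
#include "paddle/framework/framework.pb.h"

int main() {
  paddle::framework::OpDesc op;
  op.set_type("add");  // assumed `type` field naming the operator

  // Inputs and outputs become keyed entries: a parameter name plus its list of
  // argument (variable) names, mirroring the {{"X", {...}}} / {{"Out", {...}}}
  // call sites in the backward.cc diff below.
  auto* x = op.add_inputs();  // assumed repeated `inputs` field
  x->set_parameter("X");
  x->add_arguments("x0");
  x->add_arguments("x1");

  auto* out = op.add_outputs();  // assumed repeated `outputs` field
  out->set_parameter("Out");
  out->add_arguments("sum");
  return 0;
}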
2 changes: 1 addition & 1 deletion paddle/framework/attribute.cc
@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() {
   return STRINGS;
 }
 
-Attribute GetAttrValue(const AttrDesc& attr_desc) {
+Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
   switch (attr_desc.type()) {
     case paddle::framework::AttrType::INT: {
       return attr_desc.i();
5 changes: 2 additions & 3 deletions paddle/framework/attribute.h
@@ -20,8 +20,7 @@ limitations under the License. */
 #include <unordered_set>
 #include <vector>
 
-#include "paddle/framework/attribute.pb.h"
-#include "paddle/framework/op_desc.pb.h"
+#include "paddle/framework/framework.pb.h"
 #include "paddle/platform/enforce.h"
 #include "paddle/platform/variant.h"
 
@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap;
 template <typename T>
 AttrType AttrTypeID();
 
-Attribute GetAttrValue(const AttrDesc& attr_desc);
+Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
 
 // check whether a value(attribute) fit a certain limit
 template <typename T>
28 changes: 0 additions & 28 deletions paddle/framework/attribute.proto

This file was deleted.
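
With attribute.proto gone, the attribute helpers operate on the nested OpDesc::Attr message from framework.pb.h. A minimal usage sketch follows, with the caveat that set_type()/set_i() are the standard protobuf-generated setter names assumed for the `type` and `i` fields that attribute.cc reads, and that unpacking the returned Attribute variant with boost::get is an assumption about the surrounding framework:

// Illustration only: exercise GetAttrValue() against the new OpDesc::Attr.
#include <iostream>

#include <boost/variant/get.hpp>

#include "paddle/framework/attribute.h"  // now pulls in framework.pb.h

int main() {
  paddle::framework::OpDesc::Attr attr_desc;
  attr_desc.set_type(paddle::framework::AttrType::INT);  // assumed setter names
  attr_desc.set_i(10);

  paddle::framework::Attribute attr =
      paddle::framework::GetAttrValue(attr_desc);
  std::cout << boost::get<int>(attr) << "\n";  // prints 10
  return 0;
}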

65 changes: 41 additions & 24 deletions paddle/framework/backward.cc
@@ -21,15 +21,24 @@
 namespace paddle {
 namespace framework {
 
-static bool AllInSet(const std::vector<std::string>& names,
-                     const std::string& suffix,
-                     const std::unordered_set<std::string>& set) {
+template <typename Map, typename T>
+static void ForEachVarName(Map& names, T callback) {
   for (auto& name : names) {
-    if (set.find(name + suffix) == set.end()) {
-      return false;
+    for (auto& n : name.second) {
+      if (callback(n)) return;
     }
   }
-  return true;
+}
+
+static bool AllInSet(
+    const std::map<std::string, std::vector<std::string>>& names,
+    const std::string& suffix, const std::unordered_set<std::string>& set) {
+  bool all_in_set = true;
+  ForEachVarName(names, [&all_in_set, &set, &suffix](const std::string& n) {
+    all_in_set = set.find(n + suffix) != set.end();
+    return !all_in_set;
+  });
+  return all_in_set;
 }
 
 static std::shared_ptr<OperatorBase> NOP() {
@@ -68,10 +77,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   // Then all input gradients cannot be computed at all, and we put them into
   // `no_grad_names` set. Return an NOP.
   if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) {
-    for (auto& name : forwardOp.inputs_) {
-      // Mark all input is not need
-      no_grad_names.insert(name + kGradVarSuffix);
-    }
+    ForEachVarName(forwardOp.inputs_,
+                   [&no_grad_names](const std::string& name) -> bool {
+                     no_grad_names.insert(GradVarName(name));
+                     return false;
+                   });
     return NOP();
   }
 
@@ -93,9 +103,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     auto fwd = *it;
     auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
     net->AddOp(bwd);
-    for (auto& out : bwd->outputs_) {
-      dup_output_ops[out].emplace_back(local_op_id);
-    }
+    ForEachVarName(bwd->outputs_,
+                   [&dup_output_ops, local_op_id](const std::string& out) {
+                     dup_output_ops[out].emplace_back(local_op_id);
+                     return false;
+                   });
   }
   // Get unique ID for this method.
   auto uid = uniq_id++;
@@ -117,7 +129,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
       insert_position.push_back(
          {dup_op.back(),
           OpRegistry::CreateOp(
-              "add", {dup_outputs}, {name},
+              "add", {{"X", {dup_outputs}}}, {{"Out", {name}}},
              {{"input_format",
                std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
   }

Review comment (Contributor): The input_format is out of use now. It can be removed.

@@ -131,7 +143,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
 
   } else {
     std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
-    for (std::string& grad_input : grad_op->inputs_) {
+
+    ForEachVarName(grad_op->inputs_, [&no_grad_names,
+                                      &net](std::string& grad_input) {
       if (no_grad_names.count(grad_input)) {
         // +1 for \0
         std::string prefix = grad_input.substr(
@@ -140,16 +154,19 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
 
         // If part of input gradient of that operator is not calculated, fill
         // zero variables to that input gradient.
-        net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
-                                        {grad_input}, {}));
+        net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {{"Src", {prefix}}},
+                                        {{"Dst", {grad_input}}}, {}));
       }
-    }
-
-    for (std::string& grad_output : grad_op->outputs_) {
-      if (no_grad_names.count(grad_output)) {
-        grad_output = kEmptyVarName;
-      }
-    }
+      return false;
+    });
+
+    ForEachVarName(grad_op->outputs_,
+                   [&no_grad_names](std::string& grad_output) {
+                     if (no_grad_names.count(grad_output)) {
+                       grad_output = kEmptyVarName;
+                     }
+                     return false;
+                   });
 
     if (net->ops_.empty()) {  // Current no aux op is added to network
       return grad_op;
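
The backward pass now sees operator inputs and outputs as a map from parameter name to a list of argument names instead of a flat name vector, and the new ForEachVarName helper hides the resulting double loop. The self-contained sketch below reuses the helper exactly as added in this diff; the sample map and main() are illustrative only, and the second call shows how a callback returning true stops the traversal early, which is how AllInSet bails out on the first miss.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Same helper the PR adds to backward.cc: visit every argument name in a
// parameter -> arguments map; a callback returning true stops the traversal.
template <typename Map, typename T>
static void ForEachVarName(Map& names, T callback) {
  for (auto& name : names) {
    for (auto& n : name.second) {
      if (callback(n)) return;
    }
  }
}

int main() {
  // Operator inputs in the new representation: parameter -> argument names.
  std::map<std::string, std::vector<std::string>> inputs = {
      {"X", {"x0", "x1"}}, {"b", {"bias"}}};

  // Visit every argument; returning false means never stop early.
  ForEachVarName(inputs, [](const std::string& n) {
    std::cout << n << "\n";
    return false;
  });

  // Stop at the first hit.
  bool found = false;
  ForEachVarName(inputs, [&found](const std::string& n) {
    found = (n == "bias");
    return found;
  });
  std::cout << std::boolalpha << found << "\n";  // prints true
  return 0;
}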