Feature/remove shared ptr #3538

Merged
merged 10 commits on Aug 17, 2017
42 changes: 20 additions & 22 deletions paddle/framework/backward.cc
@@ -15,6 +15,8 @@
#include "paddle/framework/backward.h"

#include <list>
#include <memory>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
@@ -43,11 +45,11 @@ static bool AllInSet(
return all_in_set;
}

static std::shared_ptr<OperatorBase> NOP() {
auto net_op = std::make_shared<operators::NetOp>();
static std::unique_ptr<OperatorBase> NOP() {
auto net_op = new operators::NetOp();
net_op->SetType("@NOP@");
net_op->CompleteAddOp();
return net_op;
return std::unique_ptr<OperatorBase>(net_op);
}

// Get backward operator from a forward operator, a recursive implementation.
@@ -62,11 +64,7 @@ static std::shared_ptr<OperatorBase> NOP() {
// operator, in a complex situation, it maybe a NetOp.
//
// See Backward.h for details
static std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);

std::shared_ptr<OperatorBase> BackwardRecursive(
static std::unique_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// If all input gradients of forwarding operator do not need to calculate,
@@ -91,7 +89,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
}

// Returned gradient network
auto net = std::make_shared<operators::NetOp>();
auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());
Contributor: Can we have a make_unique helper?

Collaborator (Author): Well, it could be in another PR. make_unique is not in C++11 but in C++14; we can implement it in C++11.
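For context, a minimal C++11 version of such a helper could look like the sketch below. This is hypothetical and not part of this PR; the name MakeUnique is an assumption, and std::make_unique itself is only available from C++14 onward.

```cpp
#include <memory>   // std::unique_ptr
#include <utility>  // std::forward

// Minimal C++11 stand-in for std::make_unique (non-array types only).
template <typename T, typename... Args>
std::unique_ptr<T> MakeUnique(Args&&... args) {
  return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
```

With such a helper, the construction above could be written as `auto net = MakeUnique<operators::NetOp>();` instead of wrapping a raw `new` in a `std::unique_ptr` by hand.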


if (forwardOp.IsNetOp()) {
// Because forwardOp is a net op, it can static_cast.
@@ -105,14 +103,14 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// reversely travel forwardNet and collect all duplicate outputs.
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it, ++local_op_id) {
auto fwd = *it;
auto& fwd = *it;
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd);
ForEachVarName(bwd->Outputs(),
[&dup_output_ops, local_op_id](const std::string& out) {
dup_output_ops[out].emplace_back(local_op_id);
return false;
});
net->AddOp(std::move(bwd));
}
// Get unique ID for this method.
auto uid = uniq_id++;
@@ -122,7 +120,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// to handle this case. For each duplicate output, rename it to an alias
// (original name with a offset), append an `add` op for its operator,
// and finally sum all the alias variable to the final output variable y.
using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
std::list<Pos> insert_position;
for (auto& dup_output_op : dup_output_ops) {
const std::string& name = dup_output_op.first;
@@ -150,13 +148,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
[](const Pos& l, const Pos& r) { return l.first > r.first; });

for (auto& pos : insert_position) {
net->InsertOp(pos.first + 1, pos.second);
net->InsertOp(pos.first + 1, std::move(pos.second));
}
} else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));

ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net,
grad_op](const std::string& grad_input) {
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
const std::string& grad_input) {
if (no_grad_names.count(grad_input)) {
// +1 for \0
std::string prefix = grad_input.substr(
@@ -190,23 +188,23 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
rnn_grad_op->set_stepnet(
std::static_pointer_cast<operators::NetOp>(grad_stepnet));
BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
}

if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op;
}
net->AddOp(grad_op);
net->AddOp(std::move(grad_op));
}
net->SetType("@GENERATED_BACKWARD@");
net->CompleteAddOp();
return net;
} // namespace framework
return std::unique_ptr<OperatorBase>(
static_cast<OperatorBase*>(net.release()));
}

// See header for comments
std::shared_ptr<OperatorBase> Backward(
std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_names;
2 changes: 1 addition & 1 deletion paddle/framework/backward.h
@@ -20,7 +20,7 @@ namespace framework {

// Create the backward operator from a forward operator.
// TODO(yuyang18): Add more API reference comment.
extern std::shared_ptr<OperatorBase> Backward(
extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework
3 changes: 1 addition & 2 deletions paddle/framework/backward_test.cc
@@ -180,8 +180,7 @@ TEST(Backward, simple_op_not_need_grad) {
auto no_input_gop = f::Backward(*fwd, {"x", "b"});
ASSERT_NE(no_input_gop, nullptr);
ASSERT_TRUE(no_input_gop->IsNetOp());
ASSERT_EQ(0UL,
std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
ASSERT_EQ(0UL, static_cast<ops::NetOp *>(no_input_gop.get())->ops_.size());
}

TEST(Backward, net_fc_backward_normal) {
11 changes: 5 additions & 6 deletions paddle/framework/op_registry.cc
@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace framework {

std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
AttributeMap attrs) {
@@ -28,10 +28,10 @@ std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
"Operator '%s' has not been registered.", type);
it->second.checker_->Check(attrs);
auto op = it->second.creator_(type, inputs, outputs, attrs);
return std::shared_ptr<OperatorBase>(op);
return std::unique_ptr<OperatorBase>(op);
}

std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
@@ -55,10 +55,9 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
return ret_val;
}

std::shared_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
return grad_op;
return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
}

} // namespace framework
6 changes: 3 additions & 3 deletions paddle/framework/op_registry.h
@@ -77,17 +77,17 @@ class OpRegistry {
}
}

static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
AttributeMap attrs);

static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);

static VarNameMap ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars);

static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);

static std::unordered_map<std::string, const OpInfo>& op_info_map() {
static std::unordered_map<std::string, const OpInfo> op_info_map_;
6 changes: 2 additions & 4 deletions paddle/framework/op_registry_test.cc
@@ -76,8 +76,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale);

std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
@@ -118,8 +117,7 @@ TEST(OpRegistry, DefaultValue) {

ASSERT_TRUE(op_desc.IsInitialized());

std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
141 changes: 56 additions & 85 deletions paddle/framework/pybind.cc
@@ -48,29 +48,6 @@ namespace framework {

using Tensor = framework::Tensor;

template <typename ClassType>
void ExposeOperator(ClassType &m) {
m.def("infer_shape", &ClassType::type::InferShape)
.def("run", &ClassType::type::Run)
.def("type",
[](const typename ClassType::type &op) -> std::string {
return op.Type();
})
.def("outputs",
[](const typename ClassType::type &op)
-> std::map<std::string, std::vector<std::string>> {
return op.Outputs();
})
.def("inputs",
[](const typename ClassType::type &op) { return op.Inputs(); })
.def("__str__", &ClassType::type::DebugString)
.def("no_intermediate_outputs",
[](const typename ClassType::type &op) {
return op.OutputVars(false);
})
.def("support_gpu", &ClassType::type::SupportGPU);
}

static size_t UniqueIntegerGenerator() {
static std::atomic<size_t> generator;
return generator.fetch_add(1);
@@ -207,75 +184,69 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<>())
.def("__str__", string::to_string<const platform::CPUPlace &>);

py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
m, "Operator");

operator_base.def_static("create", [](py::bytes protobin) {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
});

operator_base.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
return Backward(forwardOp, no_grad_vars);
});

ExposeOperator(operator_base);

py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");

net.def_static("create",
[]() -> std::shared_ptr<operators::NetOp> {
auto retv = std::make_shared<operators::NetOp>();
retv->SetType("plain_net");
return retv;
})
.def("add_op", &operators::NetOp::AddOp)
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
py::class_<OperatorBase>(m, "Operator")
.def_static("create",
[](py::bytes protobin) {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
})
.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
return Backward(forwardOp, no_grad_vars).release();
})
.def("infer_shape", &OperatorBase::InferShape)
.def("run", &OperatorBase::Run)
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
[](const OperatorBase &op)
-> std::map<std::string, std::vector<std::string>> {
return op.Outputs();
})
.def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
.def("__str__", &OperatorBase::DebugString)
.def("no_intermediate_outputs",
[](const OperatorBase &op) { return op.OutputVars(false); })
.def("support_gpu", &OperatorBase::SupportGPU);

py::class_<operators::NetOp, OperatorBase>(m, "Net")
.def_static("create",
[]() -> operators::NetOp * {
auto *retv = new operators::NetOp;
retv->SetType("plain_net");
return retv;
})
.def("add_op", [](operators::NetOp &self,
const OperatorBase &op) { self.AddOp(op); })
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
});

ExposeOperator(net);

// recurrent_op
py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
rnn(m, "RecurrentOp");

rnn.def_static(
"create",
[](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
})
.def("set_stepnet",
[](operators::RecurrentOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.set_stepnet(net);
});
ExposeOperator(rnn);
py::class_<operators::RecurrentOp, OperatorBase>(m, "RecurrentOp")
.def_static(
"create",
[](py::bytes protobin) -> operators::RecurrentOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
.def("set_stepnet", [](operators::RecurrentOp &self,
const operators::NetOp &net) -> void {
self.set_stepnet(net.Clone());
});

m.def("unique_integer", UniqueIntegerGenerator);
