From 29d892c13cf88c7659647cec532169caa7abd2b9 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 16 Aug 2017 14:19:38 +0800 Subject: [PATCH 1/4] Add Clone Method For OperatorBase * Clone method will create a new object instance, which is as same as itself. * This is the first step to remove shared_ptr for OperatorBase --- paddle/framework/op_registry.h | 15 +++++++++++++-- paddle/framework/operator.h | 14 ++++++++++---- paddle/framework/operator_test.cc | 19 +++++++++++++++++++ paddle/operators/net_op.cc | 7 +++++++ paddle/operators/net_op.h | 13 +++++++++++++ paddle/operators/net_op_test.cc | 17 +++++++++++++++++ paddle/operators/recurrent_op.h | 22 ++++++++++++++++++---- 7 files changed, 97 insertions(+), 10 deletions(-) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 3b793628aa6fd..b5b4668074cc7 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -271,7 +271,13 @@ class OpKernelRegistrar : public Registrar { #define REGISTER_OP(op_type, op_class, op_maker_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ - static ::paddle::framework::OpRegistrar \ + class _OpClass_##op_type##_ : public op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ + DEFINE_OP_CTOR(_OpClass_##op_type##_, op_class); \ + }; \ + static ::paddle::framework::OpRegistrar<_OpClass_##op_type##_, \ + op_maker_class> \ __op_registrar_##op_type##__(#op_type); \ int TouchOpRegistrar_##op_type() { \ __op_registrar_##op_type##__.Touch(); \ @@ -285,7 +291,12 @@ class OpKernelRegistrar : public Registrar { STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_gradient_op__##op_type##_##grad_op_type, \ "REGISTER_GRADIENT_OP must be called in global namespace"); \ - static ::paddle::framework::GradOpRegistrar \ + class _OpGradClass_##op_type##_ : public grad_op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \ + 
DEFINE_OP_CTOR(_OpGradClass_##op_type##_, grad_op_class); \ + }; \ + static ::paddle::framework::GradOpRegistrar<_OpGradClass_##op_type##_> \ __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \ #grad_op_type); \ int TouchOpGradientRegistrar_##op_type() { \ diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 4a72ced6ced92..92032478661be 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -69,10 +69,6 @@ class OperatorBase { OperatorBase(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const AttributeMap& attrs); - OperatorBase(const OperatorBase& o) = delete; - OperatorBase& operator=(const OperatorBase& o) = delete; - OperatorBase(OperatorBase&& o) = delete; - virtual ~OperatorBase() {} template @@ -115,6 +111,8 @@ class OperatorBase { std::string Type() const { return type_; } const AttributeMap& Attrs() const { return attrs_; } + virtual OperatorBase* Clone() const = 0; + public: std::string type_; // NOTE: in case of OpGrad, inputs_ contains: @@ -129,6 +127,14 @@ class OperatorBase { AttributeMap attrs_; }; +#define DEFINE_OP_CLONE_METHOD(CLS) \ + OperatorBase* Clone() const final { return new CLS(*this); } + +#define DEFINE_OP_CTOR(CLS, PARENT_CLS) \ + CLS(const std::string& type, const VarNameMap& inputs, \ + const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \ + : PARENT_CLS(type, inputs, outputs, attrs) {} + class InferShapeContext { public: InferShapeContext(const OperatorBase& op, const Scope& scope) diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 6804841587730..ceba7f5e6ea48 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -242,3 +242,22 @@ TEST(OpKernel, multi_inputs) { auto op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(scope, cpu_device_context); } + +class OperatorClone : public paddle::framework::OperatorBase { + public: + 
DEFINE_OP_CLONE_METHOD(OperatorClone); + OperatorClone(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, + const paddle::framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void InferShape(const paddle::framework::Scope& scope) const override {} + void Run(const paddle::framework::Scope& scope, + const paddle::platform::DeviceContext& dev_ctx) const override {} +}; + +TEST(Operator, Clone) { + OperatorClone a("ABC", {}, {}, {}); + auto* b = a.Clone(); + ASSERT_EQ(a.Type(), b->Type()); + delete b; +} \ No newline at end of file diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 1d1b290440ec1..896550f9d0c55 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -87,5 +87,12 @@ NetOp::NetOp(const std::string& type, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} +framework::OperatorBase* NetOp::Clone() const { + PADDLE_ENFORCE( + add_op_done_, + "Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone"); + return new NetOp(*this); +} + } // namespace operators } // namespace paddle diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 4a3408c158a02..deee543065ab7 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -41,6 +41,18 @@ class NetOp : public framework::OperatorBase { NetOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); + NetOp(const NetOp& o) + : framework::OperatorBase( + static_cast(o)) { + this->ops_.reserve(o.ops_.size()); + std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_), + [](const std::shared_ptr& op) + -> std::shared_ptr { + return std::shared_ptr(op->Clone()); + }); + this->CompleteAddOp(); + } + /** * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch @@ -97,6 +109,7 @@ class NetOp : public 
framework::OperatorBase { bool IsNetOp() const override; std::vector OutputVars(bool has_intermediate) const override; + framework::OperatorBase* Clone() const override; std::vector> ops_; diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index f7aa56262ef71..40e43f46df79f 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -13,6 +13,7 @@ static int run_cnt = 0; class TestOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; + DEFINE_OP_CLONE_METHOD(TestOp); void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { @@ -23,6 +24,7 @@ class TestOp : public framework::OperatorBase { class EmptyOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; + DEFINE_OP_CLONE_METHOD(EmptyOp); void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} }; @@ -77,5 +79,20 @@ TEST(NetOp, insert_op) { ASSERT_EQ(3UL, net.ops_.size()); } +TEST(NetOp, Clone) { + NetOp net; + net.AddOp(std::shared_ptr(new EmptyOp{"empty", {}, {}, {}})); + net.AddOp(std::shared_ptr(new EmptyOp{"empty2", {}, {}, {}})); + net.CompleteAddOp(true); + auto* new_net_op = net.Clone(); + ASSERT_NE(new_net_op, nullptr); + ASSERT_TRUE(new_net_op->IsNetOp()); + auto* new_net = static_cast(new_net_op); + ASSERT_EQ(2, new_net->ops_.size()); + ASSERT_EQ(new_net->ops_[0]->Type(), "empty"); + ASSERT_EQ(new_net->ops_[1]->Type(), "empty2"); + delete new_net; +} + } // namespace operators } // namespace paddle diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index 8f4f2444d844b..cc40eff0cf94c 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -99,13 +99,20 @@ class RecurrentGradientAlgorithm { mutable size_t seq_len_; }; -class RecurrentOp final : public 
framework::OperatorBase { +class RecurrentOp : public framework::OperatorBase { public: RecurrentOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); + + RecurrentOp(const RecurrentOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement copy ctor well. + PADDLE_THROW("Not implemented"); + } /** - * InferShape must be called before Run. - */ + * InferShape must be called before Run. + */ void InferShape(const framework::Scope& scope) const override { alg_.InferShape(scope); } @@ -121,12 +128,19 @@ class RecurrentOp final : public framework::OperatorBase { RecurrentAlgorithm alg_; }; -class RecurrentGradientOp final : public framework::OperatorBase { +class RecurrentGradientOp : public framework::OperatorBase { public: RecurrentGradientOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); + RecurrentGradientOp(const RecurrentGradientOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement Copy ctor. + PADDLE_THROW("Not Implemented"); + } + /** * InferShape must be called before Run. */ From 3e52343dc1c31d0c23a6fdcdee0c7c0492310014 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 16 Aug 2017 14:24:10 +0800 Subject: [PATCH 2/4] Add comments --- paddle/framework/operator.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 92032478661be..9e4d0d5e398d7 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -111,6 +111,8 @@ class OperatorBase { std::string Type() const { return type_; } const AttributeMap& Attrs() const { return attrs_; } + // Return a new operator instance, which is as same as this. + // NOTE: It is caller's responsibility to delete that operator instance. 
virtual OperatorBase* Clone() const = 0; public: @@ -127,9 +129,16 @@ class OperatorBase { AttributeMap attrs_; }; +// Macro for define a clone method. +// If you are writing an kernel operator, `Clone` will be defined when you +// register it. #define DEFINE_OP_CLONE_METHOD(CLS) \ OperatorBase* Clone() const final { return new CLS(*this); } +// Macro for define a default constructor for Operator. +// You can also use +// using PARENT_CLASS::PARENT_CLASS; +// to use parent's constructor. #define DEFINE_OP_CTOR(CLS, PARENT_CLS) \ CLS(const std::string& type, const VarNameMap& inputs, \ const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \ From a0d77533f01c5da0fa811d4cc91235f5610f745f Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 16 Aug 2017 14:49:18 +0800 Subject: [PATCH 3/4] Rename Ctor -> Constructor Make code more clearer --- paddle/framework/op_registry.h | 4 ++-- paddle/framework/operator.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index b5b4668074cc7..c0654b375dd09 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -274,7 +274,7 @@ class OpKernelRegistrar : public Registrar { class _OpClass_##op_type##_ : public op_class { \ public: \ DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ - DEFINE_OP_CTOR(_OpClass_##op_type##_, op_class); \ + DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ }; \ static ::paddle::framework::OpRegistrar<_OpClass_##op_type##_, \ op_maker_class> \ @@ -294,7 +294,7 @@ class OpKernelRegistrar : public Registrar { class _OpGradClass_##op_type##_ : public grad_op_class { \ public: \ DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \ - DEFINE_OP_CTOR(_OpGradClass_##op_type##_, grad_op_class); \ + DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \ }; \ static ::paddle::framework::GradOpRegistrar<_OpGradClass_##op_type##_> \ 
      __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type,        \
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 9e4d0d5e398d7..4a1dee6fb00d7 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -118,7 +118,7 @@ class OperatorBase {
  public:
   std::string type_;
   // NOTE: in case of OpGrad, inputs_ contains:
-  // I (Inputs)opear
+  // I (Inputs)
   // O (Outputs)
   // OG (Output Gradients)
   VarNameMap inputs_;
@@ -139,7 +139,7 @@ class OperatorBase {
 // You can also use
 //   using PARENT_CLASS::PARENT_CLASS;
 // to use parent's constructor.
-#define DEFINE_OP_CTOR(CLS, PARENT_CLS)                                        \
+#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS)                                 \
   CLS(const std::string& type, const VarNameMap& inputs,                       \
       const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
       : PARENT_CLS(type, inputs, outputs, attrs) {}

From 1425387570d5559ad0e82bd690b0fcc424911ca1 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 16 Aug 2017 15:52:48 +0800
Subject: [PATCH 4/4] Using unique_ptr instead of raw ptr

Fit google C++ style
---
 paddle/framework/operator.h       | 10 ++++++----
 paddle/framework/operator_test.cc |  3 +--
 paddle/operators/net_op.cc        |  6 +++---
 paddle/operators/net_op.h         |  3 ++-
 paddle/operators/net_op_test.cc   |  5 ++---
 5 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 4a1dee6fb00d7..9e8aef6f8563b 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -112,8 +112,8 @@ class OperatorBase {
   const AttributeMap& Attrs() const { return attrs_; }
 
   // Return a new operator instance, which is as same as this.
-  // NOTE: It is caller's responsibility to delete that operator instance.
-  virtual OperatorBase* Clone() const = 0;
+  // Use unique_ptr to prevent caller forget to delete this pointer.
+  virtual std::unique_ptr<OperatorBase> Clone() const = 0;
 
  public:
   std::string type_;
@@ -132,8 +132,10 @@ class OperatorBase {
 // Macro for define a clone method.
// If you are writing an kernel operator, `Clone` will be defined when you // register it. -#define DEFINE_OP_CLONE_METHOD(CLS) \ - OperatorBase* Clone() const final { return new CLS(*this); } +#define DEFINE_OP_CLONE_METHOD(CLS) \ + std::unique_ptr Clone() const final { \ + return std::unique_ptr(new CLS(*this)); \ + } // Macro for define a default constructor for Operator. // You can also use diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index ceba7f5e6ea48..8836217126581 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -257,7 +257,6 @@ class OperatorClone : public paddle::framework::OperatorBase { TEST(Operator, Clone) { OperatorClone a("ABC", {}, {}, {}); - auto* b = a.Clone(); + auto b = a.Clone(); ASSERT_EQ(a.Type(), b->Type()); - delete b; } \ No newline at end of file diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 896550f9d0c55..77eb07e2f9018 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -85,13 +85,13 @@ NetOp::NetOp(const std::string& type, const framework::OperatorBase::VarNameMap& inputs, const framework::OperatorBase::VarNameMap& outputs, const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} + : framework::OperatorBase(type, inputs, outputs, attrs) {} -framework::OperatorBase* NetOp::Clone() const { +std::unique_ptr NetOp::Clone() const { PADDLE_ENFORCE( add_op_done_, "Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone"); - return new NetOp(*this); + return std::unique_ptr(new NetOp(*this)); } } // namespace operators diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index deee543065ab7..743f0e67dbeaa 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -109,7 +109,8 @@ class NetOp : public framework::OperatorBase { bool IsNetOp() const override; std::vector OutputVars(bool has_intermediate) const override; - framework::OperatorBase* 
Clone() const override;
+
+  std::unique_ptr<framework::OperatorBase> Clone() const override;
 
   std::vector<std::shared_ptr<OperatorBase>> ops_;
 
diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc
index 40e43f46df79f..6d6f8bd354817 100644
--- a/paddle/operators/net_op_test.cc
+++ b/paddle/operators/net_op_test.cc
@@ -84,14 +84,13 @@ TEST(NetOp, Clone) {
   net.AddOp(std::shared_ptr<OperatorBase>(new EmptyOp{"empty", {}, {}, {}}));
   net.AddOp(std::shared_ptr<OperatorBase>(new EmptyOp{"empty2", {}, {}, {}}));
   net.CompleteAddOp(true);
-  auto* new_net_op = net.Clone();
+  auto new_net_op = net.Clone();
   ASSERT_NE(new_net_op, nullptr);
   ASSERT_TRUE(new_net_op->IsNetOp());
-  auto* new_net = static_cast<NetOp*>(new_net_op);
+  auto* new_net = static_cast<NetOp*>(new_net_op.get());
   ASSERT_EQ(2, new_net->ops_.size());
   ASSERT_EQ(new_net->ops_[0]->Type(), "empty");
   ASSERT_EQ(new_net->ops_[1]->Type(), "empty2");
-  delete new_net;
 }
 
 } // namespace operators