Skip to content

Commit

Permalink
Remove empty constructor for operator
Browse files — browse the repository at this point in the history
  • Loading branch information
reyoung committed Aug 12, 2017
1 parent 0b1052f commit 11c3560
Show file tree
Hide file tree
Showing 24 changed files with 158 additions and 116 deletions.
7 changes: 4 additions & 3 deletions paddle/framework/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext;

class EmptyOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
Expand Down Expand Up @@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {

class FcOp : public operators::NetOp {
public:
DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp)
void Init() override {
FcOp(const std::string &type, const VarNameMap &inputs,
const VarNameMap &outputs, const AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
AddOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
{{"Out", {Output("mul_result")}}}, {}));
Expand Down
34 changes: 20 additions & 14 deletions paddle/framework/grad_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ class OpRegistry;

enum class OpArgType { IN, OUT };

static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
const OpArgType& src_type, const OpArgType& dst_type,
bool is_grad) {
static void TransOpArg(const OperatorBase* src_op,
OperatorBase::VarNameMap* vars,
const OpArgType& src_type, bool is_grad) {
const auto& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
auto& dst_inout =
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
auto& dst_inout = *vars;

const OpProto& proto = OpProtos().at(src_op->type_);
const auto& src_arg_list =
Expand All @@ -47,15 +46,22 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
}

OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
grad_op->attrs_ = op->attrs_;
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, false); // I
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, false); // O
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, true); // OG
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, true); // IG
return grad_op;
auto gop_type_it = OpRegistry::grad_ops().find(op->type_);
PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(),
"Operator %s do not register gradient type", op->type_);
auto& grad_op_type = gop_type_it->second;
OperatorBase::VarNameMap inputs;
OperatorBase::VarNameMap outputs;
TransOpArg(op, &inputs, OpArgType::IN, false); // I
TransOpArg(op, &inputs, OpArgType::OUT, false); // O
TransOpArg(op, &inputs, OpArgType::OUT, true); // OG
TransOpArg(op, &outputs, OpArgType::IN, true); // IG
auto gop_it = OpRegistry::op_creators().find(grad_op_type);
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
"Operator %s 's Gradient %s's creator cannot be found",
op->type_, grad_op_type);

return gop_it->second(grad_op_type, inputs, outputs, op->attrs_);
}

} // namespace framework
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/grad_op_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace framework {

class NOP : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
Expand Down
46 changes: 17 additions & 29 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,19 @@ class OpProtoAndCheckerMaker {
};

class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarNameMap = std::map<std::string, std::vector<std::string>>;
using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VarNameMap& /*inputs*/,
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;

public:
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
op_creators()[op_type] = [] { return new OpType; };
op_creators()[op_type] = [](
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs);
};
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = OpProtos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
Expand All @@ -138,29 +144,25 @@ class OpRegistry {
template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; };
op_creators()[grad_op_type] = [](
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) {
return new GradOpType(type, inputs, outputs, attrs);
};
grad_ops()[op_type] = grad_op_type;
}

static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
const AttributeMap& attrs) {
AttributeMap attrs) {
auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);
op_checkers().at(type).Check(attrs);

auto op = op_create_it->second();
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;

op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);

GenerateTempVariableName(op);
auto op = op_create_it->second(type, inputs, outputs, attrs);

op->Init();
return std::shared_ptr<OperatorBase>(op);
}

Expand Down Expand Up @@ -195,7 +197,6 @@ class OpRegistry {
PADDLE_ENFORCE(!op.IsNetOp(),
"Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
grad_op->Init();
return grad_op;
}

Expand All @@ -214,19 +215,6 @@ class OpRegistry {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
}

static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : op->outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += op->type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
};

class Registrar {
Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase);
using OperatorBase::OperatorBase;
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {}
Expand All @@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

class MyTestOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
Expand Down
16 changes: 16 additions & 0 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,21 @@ void OperatorBase::Rename(const std::string& old_name,
}
}

OperatorBase::OperatorBase(const std::string& type,
const OperatorBase::VarNameMap& inputs,
const OperatorBase::VarNameMap& outputs,
const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
} // namespace framework
} // namespace paddle
27 changes: 7 additions & 20 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,8 @@ class OperatorBase {
public:
using VarNameMap = std::map<std::string, std::vector<std::string>>;

OperatorBase() = default;
OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
const VarNameMap& outputs, const AttributeMap& attrs);

OperatorBase(const OperatorBase& o) = delete;
OperatorBase& operator=(const OperatorBase& o) = delete;
Expand All @@ -86,10 +84,6 @@ class OperatorBase {

virtual std::string DebugString() const;

/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}

/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const Scope& scope) const = 0;
Expand Down Expand Up @@ -154,23 +148,14 @@ class OperatorBase {
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
std::map<std::string, std::vector<std::string>> inputs_;
VarNameMap inputs_;

// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
std::map<std::string, std::vector<std::string>> outputs_;
VarNameMap outputs_;
AttributeMap attrs_;
};

#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: ParentClass(type, inputs, outputs, attrs) {}

class InferShapeContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
Expand Down Expand Up @@ -310,8 +295,6 @@ class OpKernel {

class OperatorWithKernel : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)

struct OpKernelKey {
platform::Place place_;

Expand All @@ -335,6 +318,10 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

OperatorWithKernel(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope));
}
Expand Down
12 changes: 7 additions & 5 deletions paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ namespace framework {
static int op_run_num = 0;

class OpWithoutKernelTest : public OperatorBase {
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase)

public:
void Init() override { x = 1; }
OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
Expand All @@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase {
}

public:
float x = 0;
int x{0};
};

class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
Expand Down Expand Up @@ -104,7 +104,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0;

class OpWithKernelTest : public OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel)
public:
using OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext& ctx) const override {}
};
Expand Down
7 changes: 5 additions & 2 deletions paddle/operators/add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ namespace paddle {
namespace operators {

class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
Expand All @@ -45,7 +46,9 @@ The equation is: Out = X + Y
};

class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
Expand Down
9 changes: 6 additions & 3 deletions paddle/operators/cross_entropy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ namespace paddle {
namespace operators {

class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
Expand All @@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
};

class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Expand Down
3 changes: 2 additions & 1 deletion paddle/operators/fill_zeros_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ namespace paddle {
namespace operators {

class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
Expand Down
3 changes: 2 additions & 1 deletion paddle/operators/gaussian_random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
};

class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext& context) const override {
Expand Down
8 changes: 6 additions & 2 deletions paddle/operators/mean_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ namespace paddle {
namespace operators {

class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
Expand All @@ -38,7 +40,9 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
};

class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
Expand Down
Loading

0 comments on commit 11c3560

Please sign in to comment.