Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using constructor to create an operator. #3444

Merged
merged 2 commits into from
Aug 14, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions paddle/framework/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext;

class EmptyOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
Expand Down Expand Up @@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {

class FcOp : public operators::NetOp {
public:
DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp)
void Init() override {
FcOp(const std::string &type, const VarNameMap &inputs,
const VarNameMap &outputs, const AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
AddOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
{{"Out", {Output("mul_result")}}}, {}));
Expand Down
34 changes: 20 additions & 14 deletions paddle/framework/grad_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ class OpRegistry;

enum class OpArgType { IN, OUT };

static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
const OpArgType& src_type, const OpArgType& dst_type,
bool is_grad) {
static void TransOpArg(const OperatorBase* src_op,
OperatorBase::VarNameMap* vars,
const OpArgType& src_type, bool is_grad) {
const auto& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
auto& dst_inout =
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
auto& dst_inout = *vars;

const OpProto& proto = OpProtos().at(src_op->type_);
const auto& src_arg_list =
Expand All @@ -47,15 +46,22 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
}

OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
grad_op->attrs_ = op->attrs_;
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, false); // I
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, false); // O
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, true); // OG
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, true); // IG
return grad_op;
auto gop_type_it = OpRegistry::grad_ops().find(op->type_);
PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(),
"Operator %s do not register gradient type", op->type_);
auto& grad_op_type = gop_type_it->second;
OperatorBase::VarNameMap inputs;
OperatorBase::VarNameMap outputs;
TransOpArg(op, &inputs, OpArgType::IN, false); // I
TransOpArg(op, &inputs, OpArgType::OUT, false); // O
TransOpArg(op, &inputs, OpArgType::OUT, true); // OG
TransOpArg(op, &outputs, OpArgType::IN, true); // IG
auto gop_it = OpRegistry::op_creators().find(grad_op_type);
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
"Operator %s 's Gradient %s's creator cannot be found",
op->type_, grad_op_type);

return gop_it->second(grad_op_type, inputs, outputs, op->attrs_);
}

} // namespace framework
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/grad_op_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace framework {

class NOP : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
Expand Down
46 changes: 17 additions & 29 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,19 @@ class OpProtoAndCheckerMaker {
};

class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarNameMap = std::map<std::string, std::vector<std::string>>;
using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VarNameMap& /*inputs*/,
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;

public:
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
op_creators()[op_type] = [] { return new OpType; };
op_creators()[op_type] = [](
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs);
};
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = OpProtos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
Expand All @@ -138,29 +144,25 @@ class OpRegistry {
template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; };
op_creators()[grad_op_type] = [](
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) {
return new GradOpType(type, inputs, outputs, attrs);
};
grad_ops()[op_type] = grad_op_type;
}

static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
const AttributeMap& attrs) {
AttributeMap attrs) {
auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);
op_checkers().at(type).Check(attrs);

auto op = op_create_it->second();
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;

op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);

GenerateTempVariableName(op);
auto op = op_create_it->second(type, inputs, outputs, attrs);

op->Init();
return std::shared_ptr<OperatorBase>(op);
}

Expand Down Expand Up @@ -195,7 +197,6 @@ class OpRegistry {
PADDLE_ENFORCE(!op.IsNetOp(),
"Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
grad_op->Init();
return grad_op;
}

Expand All @@ -214,19 +215,6 @@ class OpRegistry {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
}

static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : op->outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += op->type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
};

class Registrar {
Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase);
using OperatorBase::OperatorBase;
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {}
Expand All @@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

class MyTestOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase);
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
Expand Down
16 changes: 16 additions & 0 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,21 @@ void OperatorBase::Rename(const std::string& old_name,
}
}

OperatorBase::OperatorBase(const std::string& type,
const OperatorBase::VarNameMap& inputs,
const OperatorBase::VarNameMap& outputs,
const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
} // namespace framework
} // namespace paddle
27 changes: 7 additions & 20 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,8 @@ class OperatorBase {
public:
using VarNameMap = std::map<std::string, std::vector<std::string>>;

OperatorBase() = default;
OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
const VarNameMap& outputs, const AttributeMap& attrs);

OperatorBase(const OperatorBase& o) = delete;
OperatorBase& operator=(const OperatorBase& o) = delete;
Expand All @@ -86,10 +84,6 @@ class OperatorBase {

virtual std::string DebugString() const;

/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}

/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const Scope& scope) const = 0;
Expand Down Expand Up @@ -154,23 +148,14 @@ class OperatorBase {
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
std::map<std::string, std::vector<std::string>> inputs_;
VarNameMap inputs_;

// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
std::map<std::string, std::vector<std::string>> outputs_;
VarNameMap outputs_;
AttributeMap attrs_;
};

#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: ParentClass(type, inputs, outputs, attrs) {}

class InferShapeContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
Expand Down Expand Up @@ -310,8 +295,6 @@ class OpKernel {

class OperatorWithKernel : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)

struct OpKernelKey {
platform::Place place_;

Expand All @@ -335,6 +318,10 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

OperatorWithKernel(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope));
}
Expand Down
12 changes: 7 additions & 5 deletions paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ namespace framework {
static int op_run_num = 0;

class OpWithoutKernelTest : public OperatorBase {
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase)

public:
void Init() override { x = 1; }
OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
Expand All @@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase {
}

public:
float x = 0;
int x{0};
};

class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
Expand Down Expand Up @@ -104,7 +104,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0;

class OpWithKernelTest : public OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel)
public:
using OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext& ctx) const override {}
};
Expand Down
7 changes: 5 additions & 2 deletions paddle/operators/add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ namespace paddle {
namespace operators {

class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
Expand All @@ -45,7 +46,9 @@ The equation is: Out = X + Y
};

class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
Expand Down
9 changes: 6 additions & 3 deletions paddle/operators/cross_entropy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ namespace paddle {
namespace operators {

class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
Expand All @@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
};

class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Expand Down
3 changes: 2 additions & 1 deletion paddle/operators/fill_zeros_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ namespace paddle {
namespace operators {

class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
Expand Down
3 changes: 2 additions & 1 deletion paddle/operators/gaussian_random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
};

class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel);
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext& context) const override {
Expand Down
8 changes: 6 additions & 2 deletions paddle/operators/mean_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ namespace paddle {
namespace operators {

class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
Expand All @@ -38,7 +40,9 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
};

class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
Expand Down
Loading