
Merge maps in OpRegistry and simplify register macros #3436

Merged: 12 commits, Aug 14, 2017
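Summary of the change: this PR collapses OpRegistry's parallel lookup tables — grad_ops() (op type → gradient op type), op_creators() (op type → creator function), op_checkers() (op type → attribute checker), and the global OpProtos() map — into a single op_info_map() whose OpInfo value carries all four pieces. Below is a minimal sketch of the resulting shape; the OperatorBase/OpProto/OpAttrChecker stubs and the simplified OpCreator signature are stand-ins for illustration, not the real framework types.

#include <functional>
#include <string>
#include <unordered_map>

struct OpProto {};        // stand-in for the real proto type
struct OpAttrChecker {};  // stand-in for the real checker type
struct OperatorBase {};   // stand-in for the real operator base

using OpCreator = std::function<OperatorBase*()>;  // simplified signature

struct OpInfo {
  OpCreator creator_;         // formerly OpRegistry::op_creators()[type]
  std::string grad_op_type_;  // formerly OpRegistry::grad_ops()[type]
  OpProto* proto_;            // formerly OpProtos()[type]
  OpAttrChecker* checker_;    // formerly OpRegistry::op_checkers()[type]
};

// One function-local static map now answers every registry query.
std::unordered_map<std::string, OpInfo>& op_info_map() {
  static std::unordered_map<std::string, OpInfo> info_map;
  return info_map;
}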
23 changes: 10 additions & 13 deletions paddle/framework/backward_test.cc
@@ -155,19 +155,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework;
namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet;
-REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
-REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);
-REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
-REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
-REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
-REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
-REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker);
-REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
-REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
-REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
-REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
-REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
-REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
+REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker, rowwise_add_grad,
+            f::EmptyOp);
+REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker, mul_grad, f::EmptyOp);
+REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker, sigmoid_grad, f::EmptyOp);
+REGISTER_OP_WITHOUT_GRADIENT(nograd, f::EmptyOp, f::NoGradOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
+REGISTER_OP(add, f::EmptyOp, f::AddOpMaker, add_grad, f::EmptyOp);
+REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
+REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker,
+            many_output_op_grad, f::EmptyOp);

TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp(
40 changes: 21 additions & 19 deletions paddle/framework/grad_op_builder.cc
@@ -20,16 +20,14 @@ namespace paddle {
namespace framework {
enum class OpArgType { IN, OUT };

-static void TransOpArg(const OperatorBase* src_op,
-                       OperatorBase::VarNameMap* vars,
-                       const OpArgType& src_type, bool is_grad) {
+static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
+                       bool is_grad, OperatorBase::VarNameMap* vars) {
  const auto& src_inout =
      src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
  auto& dst_inout = *vars;

-  const OpProto& proto = OpProtos().at(src_op->type_);
+  const OpProto* proto = OpRegistry::op_info_map().at(src_op->type_).proto_;
  const auto& src_arg_list =
-      src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
+      src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) {
if (arg.no_gradient() && !is_grad) continue;
const std::string src_name = arg.name();
@@ -43,22 +41,26 @@ static void TransOpArg(const OperatorBase* src_op,
}

OperatorBase* BuildGradOp(const OperatorBase* op) {
-  auto gop_type_it = OpRegistry::grad_ops().find(op->type_);
-  PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(),
-                 "Operator %s do not register gradient type", op->type_);
-  auto& grad_op_type = gop_type_it->second;
+  auto it = OpRegistry::op_info_map().find(op->type_);
+  PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
+                 "'%s' has not been registered.", op->type_);
+  PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
+                 op->type_);
+  std::string grad_op_type = it->second.grad_op_type_;
+  PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
+                 op->type_);

  OperatorBase::VarNameMap inputs;
  OperatorBase::VarNameMap outputs;
-  TransOpArg(op, &inputs, OpArgType::IN, false);   // I
-  TransOpArg(op, &inputs, OpArgType::OUT, false);  // O
-  TransOpArg(op, &inputs, OpArgType::OUT, true);   // OG
-  TransOpArg(op, &outputs, OpArgType::IN, true);   // IG
-  auto gop_it = OpRegistry::op_creators().find(grad_op_type);
-  PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
-                 "Operator %s 's Gradient %s's creator cannot be found",
-                 op->type_, grad_op_type);
+  TransOpArg(op, OpArgType::IN, false, &inputs);   // I
+  TransOpArg(op, OpArgType::OUT, false, &inputs);  // O
+  TransOpArg(op, OpArgType::OUT, true, &inputs);   // OG
+  TransOpArg(op, OpArgType::IN, true, &outputs);   // IG

-  return gop_it->second(grad_op_type, inputs, outputs, op->attrs_);
+  it = OpRegistry::op_info_map().find(grad_op_type);
+  PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
+                 "'%s' has not been registered.", grad_op_type);
+  return it->second.creator_(grad_op_type, inputs, outputs, op->attrs_);
}

} // namespace framework
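As an aside on what BuildGradOp assembles above: the gradient op's inputs are the forward op's inputs (I), its outputs (O), and the gradients of its outputs (OG); its outputs are the gradients of the forward inputs (IG). Below is a hedged illustration for a hypothetical mul(X, Y) -> Out, assuming GradVarName simply appends an "@GRAD" suffix — an assumption inferred from the tests' f::GradVarName usage, not confirmed by this diff.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Assumption: gradient variable names are formed by appending "@GRAD".
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

int main() {
  using VarNameMap = std::map<std::string, std::vector<std::string>>;
  // Gradient op of a hypothetical forward op mul(X, Y) -> Out:
  VarNameMap grad_inputs = {
      {"X", {"x"}},                                 // I : forward inputs
      {"Y", {"y"}},
      {"Out", {"out"}},                             // O : forward outputs
      {GradVarName("Out"), {GradVarName("out")}}};  // OG: output gradients
  VarNameMap grad_outputs = {
      {GradVarName("X"), {GradVarName("x")}},       // IG: input gradients
      {GradVarName("Y"), {GradVarName("y")}}};
  std::cout << "mul_grad consumes " << grad_inputs.size()
            << " input slots and produces " << grad_outputs.size()
            << " gradient outputs\n";
  return 0;
}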
14 changes: 2 additions & 12 deletions paddle/framework/grad_op_builder_test.cc
@@ -8,14 +8,6 @@ USE_OP(add_two);
namespace paddle {
namespace framework {

-class NOP : public OperatorBase {
- public:
-  using OperatorBase::OperatorBase;
-  void InferShape(const Scope &scope) const override {}
-  void Run(const Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {}
-};

class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
public:
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -62,10 +54,8 @@ TEST(GradOpBuilder, AddTwo) {
EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
}

-REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker);
-REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP);
-REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
-REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP);
+REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
+REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);

TEST(GradOpBuilder, MutiInOut) {
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
194 changes: 77 additions & 117 deletions paddle/framework/op_registry.h
@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>
#include <type_traits>
+#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attribute.h"
@@ -119,53 +120,69 @@ class OpProtoAndCheckerMaker {
bool validated_{false};
};

+class NOPMaker : public OpProtoAndCheckerMaker {
+ public:
+  NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {}
+};

class OpRegistry {
using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VarNameMap& /*inputs*/,
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;

public:
-  template <typename OpType, typename ProtoMakerType>
-  static void RegisterOp(const std::string& op_type) {
-    op_creators()[op_type] = [](
-        const std::string& type, const VarNameMap& inputs,
-        const VarNameMap& outputs, const AttributeMap& attrs) {
-      return new OpType(type, inputs, outputs, attrs);
-    };
-    OpAttrChecker& op_checker = op_checkers()[op_type];
-    OpProto& op_proto = OpProtos()[op_type];
-    auto maker = ProtoMakerType(&op_proto, &op_checker);
-    maker.Validate();
-    op_proto.set_type(op_type);
-    PADDLE_ENFORCE(
-        op_proto.IsInitialized(),
-        "Fail to initialize %s's OpProto, because %s is not initialized",
-        op_type, op_proto.InitializationErrorString());
-  }
+  struct OpInfo {
+    OpCreator creator_;
+    std::string grad_op_type_;
+    OpProto* proto_;
+    OpAttrChecker* checker_;
+  };

-  template <typename GradOpType>
-  static void RegisterGradOp(const std::string& op_type,
-                             const std::string& grad_op_type) {
-    op_creators()[grad_op_type] = [](
-        const std::string& type, const VarNameMap& inputs,
-        const VarNameMap& outputs, const AttributeMap& attrs) {
-      return new GradOpType(type, inputs, outputs, attrs);
+  template <typename OpType, typename ProtoMakerType, typename GradOpType>
+  static void RegisterOp(const std::string& op_type,
+                         const std::string& grad_op_type) {
+    PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
+                   "'%s' is registered more than once.", op_type);
+    OpInfo op_info;
+    op_info.creator_ = [](const std::string& type, const VarNameMap& inputs,
+                          const VarNameMap& outputs,
+                          const AttributeMap& attrs) {
+      return new OpType(type, inputs, outputs, attrs);
    };
-    grad_ops()[op_type] = grad_op_type;
+    op_info.grad_op_type_ = grad_op_type;
+    if (std::type_index(typeid(ProtoMakerType)) !=
+        std::type_index(typeid(NOPMaker))) {
+      op_info.proto_ = new OpProto;
+      op_info.checker_ = new OpAttrChecker;
+      auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
+      maker.Validate();
+      op_info.proto_->set_type(op_type);
+      PADDLE_ENFORCE(
+          op_info.proto_->IsInitialized(),
+          "Fail to initialize %s's OpProto, because %s is not initialized",
+          op_type, op_info.proto_->InitializationErrorString());
+    } else {
+      op_info.proto_ = nullptr;
+      op_info.checker_ = nullptr;
+    }
+    op_info_map().insert(std::make_pair(op_type, op_info));
+    // register gradient op
+    if (!grad_op_type.empty()) {
+      RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
+    }
}

static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
AttributeMap attrs) {
-    auto op_create_it = op_creators().find(type);
-    PADDLE_ENFORCE(op_create_it != op_creators().end(),
-                   "Operator %s cannot be found.", type);
-    op_checkers().at(type).Check(attrs);
-
-    auto op = op_create_it->second(type, inputs, outputs, attrs);
-
+    auto it = op_info_map().find(type);
+    PADDLE_ENFORCE(it != op_info_map().end(),
+                   "Operator '%s' has not been registered.", type);
+    it->second.checker_->Check(attrs);
+    auto op = it->second.creator_(type, inputs, outputs, attrs);
return std::shared_ptr<OperatorBase>(op);
}

@@ -200,49 +217,32 @@ class OpRegistry {
return grad_op;
}

-  static std::unordered_map<std::string, std::string>& grad_ops() {
-    static std::unordered_map<std::string, std::string> grad_ops_;
-    return grad_ops_;
-  }
-
-  static std::unordered_map<std::string, OpCreator>& op_creators() {
-    static std::unordered_map<std::string, OpCreator> op_creators_;
-    return op_creators_;
-  }
-
- private:
-  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
-    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
-    return op_checkers_;
+  static std::unordered_map<std::string, const OpInfo>& op_info_map() {
+    static std::unordered_map<std::string, const OpInfo> op_info_map_;
+    return op_info_map_;
}
};

class Registrar {
public:
-  // In our design, various kinds of classes, e.g., operators and kernels, have
-  // their corresponding registry and registrar. The action of registration is
-  // in the constructor of a global registrar variable, which, however, are not
-  // used in the code that calls package framework, and would be removed from
-  // the generated binary file by the linker. To avoid such removal, we add
-  // Touch to all registrar classes and make USE_OP macros to call this
-  // method. So, as long as the callee code calls USE_OP, the global
+  // In our design, various kinds of classes, e.g., operators and kernels,
+  // have their corresponding registry and registrar. The action of
+  // registration is in the constructor of a global registrar variable, which,
+  // however, are not used in the code that calls package framework, and would
+  // be removed from the generated binary file by the linker. To avoid such
+  // removal, we add Touch to all registrar classes and make USE_OP macros to
+  // call this method. So, as long as the callee code calls USE_OP, the global
void Touch() {}
};

-template <typename OpType, typename ProtoMakerType>
+template <typename OpType, typename ProtoMakerType, typename GradOpType>
class OpRegistrar : public Registrar {
 public:
-  explicit OpRegistrar(const char* op_type) {
-    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
-  }
-};
-
-template <typename GradOpType>
-class GradOpRegistrar : public Registrar {
- public:
-  GradOpRegistrar(const char* op_type, const char* grad_op_type) {
-    OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
+  explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
+  OpRegistrar(const char* op_type, const char* grad_op_type) {
+    OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
+                                                               grad_op_type);
}
};

@@ -268,30 +268,20 @@ class OpKernelRegistrar : public Registrar {
/**
* Macro to register Operator.
*/
-#define REGISTER_OP(op_type, op_class, op_maker_class)               \
+#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
+                    grad_op_class)                                    \
[Review comment — Contributor] When we change the registration style, we need
to discuss a previous problem: the user's code has to write USE_OP(OP) and
USE_OP(OP_GRAD) twice. It is annoying.

[Reply — Collaborator] Just USE_OP(OP) is OK, because we always register the
forward operator and the backward operator in the same .cc/.cu file.
USE_OP(OP) will force the C++ linker to include all symbols in that .cc/.cu
file, including the forward operator and the backward operator. (A sketch of
this linker mechanism appears at the end of this diff.)

STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
-  static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
-      __op_registrar_##op_type##__(#op_type);                       \
+  static ::paddle::framework::OpRegistrar<op_class, op_maker_class, \
+                                          grad_op_class>            \
+      __op_registrar_##op_type##__(#op_type, #grad_op_type);        \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}

-/**
- * Macro to register Gradient Operator.
- */
-#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class)           \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                            \
-      __reg_gradient_op__##op_type##_##grad_op_type,                         \
-      "REGISTER_GRADIENT_OP must be called in global namespace");            \
-  static ::paddle::framework::GradOpRegistrar<grad_op_class>                 \
-      __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type,       \
-                                                             #grad_op_type); \
-  int TouchOpGradientRegistrar_##op_type() {                                 \
-    __op_gradient_registrar_##op_type##_##grad_op_type##__.Touch();          \
-    return 0;                                                                \
-  }
+#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
+  REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)

/**
* Macro to register OperatorKernel.
@@ -307,14 +297,6 @@ class OpKernelRegistrar : public Registrar {
return 0; \
}

-/**
- * Macro to Forbid user register Gradient Operator.
- */
-#define NO_GRADIENT(op_type)                           \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                      \
-      __reg_gradient_op__##op_type##_##op_type##_grad, \
-      "NO_GRADIENT must be called in global namespace")

#define REGISTER_OP_GPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)

@@ -333,23 +315,6 @@ class OpKernelRegistrar : public Registrar {
static int use_op_itself_##op_type##_ __attribute__((unused)) = \
TouchOpRegistrar_##op_type()

-// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use
-// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code
-// can't be compiled. `NO_GRAD` should be removed after all gradient ops are
-// compeleted.
-#define NO_GRAD
-#ifndef NO_GRAD
-#define USE_OP_GRADIENT(op_type)                                    \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                   \
-      __use_op_gradient_##op_type,                                  \
-      "USE_OP_GRADIENT must be called in global namespace");        \
-  extern int TouchOpGradientRegistrar_##op_type();                  \
-  static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
-      TouchOpGradientRegistrar_##op_type()
-#else
-#define USE_OP_GRADIENT(op_type)
-#endif

#define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
@@ -369,18 +334,13 @@ class OpKernelRegistrar : public Registrar {
USE_OP_DEVICE_KERNEL(op_type, GPU)
#endif

-#define USE_NO_GRAD_OP(op_type) \
-  USE_OP_ITSELF(op_type);       \
-  USE_OP_KERNEL(op_type)
-
-#define USE_CPU_OP(op_type)           \
-  USE_OP_ITSELF(op_type);             \
-  USE_OP_DEVICE_KERNEL(op_type, CPU); \
-  USE_OP_GRADIENT(op_type)
+#define USE_CPU_ONLY_OP(op_type) \
+  USE_OP_ITSELF(op_type);        \
+  USE_OP_DEVICE_KERNEL(op_type, CPU);

-#define USE_OP(op_type)    \
-  USE_NO_GRAD_OP(op_type); \
-  USE_OP_GRADIENT(op_type)
+#define USE_OP(op_type)   \
+  USE_OP_ITSELF(op_type); \
+  USE_OP_KERNEL(op_type)

} // namespace framework
} // namespace paddle
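As a footnote to the review thread above about needing only USE_OP(OP): the Touch/USE_OP machinery exists purely to defeat linker dead-stripping. Below is a hedged sketch of the mechanism with hypothetical names (my_op and the simplified Registrar are illustrative, not the real framework code); it mirrors what REGISTER_OP and USE_OP_ITSELF expand to in this header.

// --- op_lib.cc: compiled into a library ---
struct Registrar {
  Registrar() { /* registration side effect runs here */ }
  void Touch() {}  // no-op; exists only to be referenced
};
static Registrar my_op_registrar;  // global whose ctor registers the op
int TouchOpRegistrar_my_op() {     // anchor symbol for USE_OP to reference
  my_op_registrar.Touch();
  return 0;
}

// --- user.cc: roughly what USE_OP(my_op) expands to ---
extern int TouchOpRegistrar_my_op();
static int use_op_itself_my_op __attribute__((unused)) =
    TouchOpRegistrar_my_op();

Because the extern reference forces the linker to pull in op_lib.o, every registrar in that object file comes along with it — which is why registering the forward and gradient operators in the same .cc/.cu file makes a single USE_OP(OP) sufficient.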