Merge maps in OpRegistry and simplify register macros #3436

Merged (12 commits, Aug 14, 2017)
23 changes: 10 additions & 13 deletions paddle/framework/backward_test.cc
@@ -150,19 +150,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework;
namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);
REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker, rowwise_add_grad,
f::EmptyOp);
REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker, mul_grad, f::EmptyOp);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker, sigmoid_grad, f::EmptyOp);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::EmptyOp, f::NoGradOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker, add_grad, f::EmptyOp);
REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker,
many_output_op_grad, f::EmptyOp);

TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
14 changes: 11 additions & 3 deletions paddle/framework/grad_op_builder.cc
@@ -50,7 +50,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
std::vector<std::string>& dst_inout =
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
const OpProto& proto = OpRegistry::protos().at(src_op->type_);
const OpProto& proto = *(OpRegistry::op_info_map().at(src_op->type_).proto_);
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();

@@ -76,8 +76,16 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
}

OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
auto it = OpRegistry::op_info_map().find(op->type_);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", op->type_);
std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->type_);
it = OpRegistry::op_info_map().find(grad_op_type);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", grad_op_type);
OperatorBase* grad_op = it->second.creator_();
grad_op->type_ = grad_op_type;
grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format");
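Note on the change above: with the separate protos(), grad_ops(), op_creators(), and op_checkers() maps merged into one op_info_map, building a gradient op is now two lookups in the same map. A minimal caller-side sketch of the flow (hedged: the CreateGradOp wrapper is assumed from the existing registry, and the op names come from the tests in this PR):

namespace f = paddle::framework;

// "mul" is registered with grad_op_type "mul_grad" in the tests below.
auto fwd = f::OpRegistry::CreateOp("mul", {"A", "B"}, {"Out"}, {});
// Internally: op_info_map().at("mul").grad_op_type_ yields "mul_grad",
// then op_info_map().at("mul_grad").creator_() instantiates the grad op.
auto grad = f::OpRegistry::CreateGradOp(*fwd);
// An op registered with an empty grad_op_type_ would instead trip the
// PADDLE_ENFORCE message "'%s' has no gradient operator."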
13 changes: 2 additions & 11 deletions paddle/framework/grad_op_builder_test.cc
@@ -8,13 +8,6 @@ USE_OP(add_two);
namespace paddle {
namespace framework {

class NOP : public OperatorBase {
public:
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
};

class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
public:
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -61,10 +54,8 @@ TEST(GradOpBuilder, AddTwo) {
EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
}

REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker);
REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP);
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);

TEST(GradOpBuilder, MutiInOut) {
f::AttributeMap attrs{{"input_format", std::vector<int>{0, 1, 4, 5}},
191 changes: 80 additions & 111 deletions paddle/framework/op_registry.h
@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attribute.h"
@@ -174,59 +175,78 @@ Add a mark to which output is temporary is helpful for future optimization.
bool has_temporary_output_{false};
};

class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};

struct OpInfo {
std::function<OperatorBase*()> creator_;
std::string grad_op_type_;
OpProto* proto_;
OpAttrChecker* checker_;
};

class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarIndexMap = std::unordered_map<std::string, int>;
using VarNameList = std::vector<std::string>;

public:
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
op_creators()[op_type] = [] { return new OpType; };
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
maker.Validate();
*op_proto.mutable_type() = op_type;
PADDLE_ENFORCE(
op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString());

VarIndexMaps()[op_type].reset(new VarIndexMap());
auto& varmap = *VarIndexMaps()[op_type];
int idx = 0;
for (auto& var : op_proto.inputs()) {
varmap[var.name()] = idx++;
template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) {
PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
"'%s' is registered more than once.", op_type);
OpInfo op_info;
op_info.creator_ = [] { return new OpType; };
op_info.grad_op_type_ = grad_op_type;
if (std::type_index(typeid(ProtoMakerType)) !=
std::type_index(typeid(NOPMaker))) {
op_info.proto_ = new OpProto;
op_info.checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
maker.Validate();
*op_info.proto_->mutable_type() = op_type;
PADDLE_ENFORCE(
op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_info.proto_->InitializationErrorString());
// ======will be refactored in following PRs============ //
VarIndexMaps()[op_type].reset(new VarIndexMap());
auto& varmap = *VarIndexMaps()[op_type];
int idx = 0;
for (auto& var : op_info.proto_->inputs()) {
varmap[var.name()] = idx++;
}
idx = 0;
for (auto& var : op_info.proto_->outputs()) {
varmap[var.name()] = idx++;
}
// ================================================ //
}
idx = 0;
for (auto& var : op_proto.outputs()) {
varmap[var.name()] = idx++;
op_info_map().insert(std::make_pair(op_type, op_info));
// register gradient op
if (!grad_op_type.empty()) {
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
}
}

template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; };
grad_ops()[op_type] = grad_op_type;
}

static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);
auto it = op_info_map().find(type);
PADDLE_ENFORCE(it != op_info_map().end(), "'%s' has not been registered.",
type);

auto op = op_create_it->second();
auto op = it->second.creator_();
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;

op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);
it->second.checker_->Check(op->attrs_);

GenerateTempVariableName(op);

@@ -268,14 +288,9 @@ class OpRegistry {
return grad_op;
}

static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_;
return protos_;
}

static std::unordered_map<std::string, std::string>& grad_ops() {
static std::unordered_map<std::string, std::string> grad_ops_;
return grad_ops_;
static std::unordered_map<std::string, const OpInfo>& op_info_map() {
static std::unordered_map<std::string, const OpInfo> op_info_map_;
return op_info_map_;
}

static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
@@ -284,17 +299,7 @@
return maps_;
}

static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator> op_creators_;
return op_creators_;
}

private:
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
}

static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) {
@@ -320,19 +325,13 @@ class Registrar {
void Touch() {}
};

template <typename OpType, typename ProtoMakerType>
template <typename OpType, typename ProtoMakerType, typename GradOpType>
class OpRegistrar : public Registrar {
public:
explicit OpRegistrar(const char* op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
}
};

template <typename GradOpType>
class GradOpRegistrar : public Registrar {
public:
GradOpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
OpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
grad_op_type);
}
};

@@ -358,30 +357,20 @@ class OpKernelRegistrar : public Registrar {
/**
* Macro to register Operator.
*/
#define REGISTER_OP(op_type, op_class, op_maker_class) \
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
[Inline review]
Contributor: When we change the registration style, we need to discuss a previous problem: the user's code needs to write USE_OP(OP) and USE_OP(OP_GRAD) twice. It is annoying.
Collaborator: Just USE_OP(OP) is OK, because we always register the forward operator and the backward operator in the same .cc/.cu file. USE_OP(OP) will force the C++ linker to include all symbols in that .cc/.cu file, including the forward and backward operators.
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
__op_registrar_##op_type##__(#op_type); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class, \
grad_op_class> \
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
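This macro change is what settles the review thread above: the forward op and its gradient are now registered by one macro in one translation unit, so a single USE_OP suffices on the user side. A sketch of the intended effect (hypothetical call site, not part of this diff):

// In user code linking against the op library:
USE_OP(mul);  // touches TouchOpRegistrar_mul, so the linker keeps the whole
              // object file, including the "mul_grad" registration; no
              // separate USE_OP(mul_grad) is needed.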

/**
* Macro to register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##grad_op_type, \
"REGISTER_GRADIENT_OP must be called in global namespace"); \
static ::paddle::framework::GradOpRegistrar<grad_op_class> \
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
#grad_op_type); \
int TouchOpGradientRegistrar_##op_type() { \
__op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \
return 0; \
}
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
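The no-gradient form simply forwards to REGISTER_OP with an empty gradient type and the NOP placeholder, so RegisterOp's gradient-registration branch never fires. For example, a registration from the tests above (expansion shown as a sketch):

REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
// expands to, approximately:
// REGISTER_OP(fc, f::FcOp, f::FcOpMaker,
//             /* empty grad_op_type */, ::paddle::framework::NOP)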

/**
* Macro to register OperatorKernel.
@@ -400,10 +389,12 @@ class OpKernelRegistrar : public Registrar {
/**
* Macro to Forbid user register Gradient Operator.
*/
/*
#define NO_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##op_type##_grad, \
"NO_GRADIENT must be called in global namespace")
STATIC_ASSERT_GLOBAL_NAMESPACE( \
[Inline review]
Contributor: We already have the unified registrar REGISTER_OP_WITHOUT_GRADIENT. I think we do not need this one any more.
Collaborator (author): You mean the macro NO_GRADIENT(op_type)?
Contributor: Yes. Is it redundant?
Collaborator (author): Removed.
__reg_gradient_op__##op_type##_##op_type##_grad, \
"NO_GRADIENT must be called in global namespace")
*/

#define REGISTER_OP_GPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
@@ -423,23 +414,6 @@ class OpKernelRegistrar : public Registrar {
static int use_op_itself_##op_type##_ __attribute__((unused)) = \
TouchOpRegistrar_##op_type()

// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use
// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't
// be compiled. `NO_GRAD` should be removed after all gradient ops are
// compeleted.
#define NO_GRAD
#ifndef NO_GRAD
#define USE_OP_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_gradient_##op_type, \
"USE_OP_GRADIENT must be called in global namespace"); \
extern int TouchOpGradientRegistrar_##op_type(); \
static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
TouchOpGradientRegistrar_##op_type()
#else
#define USE_OP_GRADIENT(op_type)
#endif

#define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
Expand All @@ -459,18 +433,13 @@ class OpKernelRegistrar : public Registrar {
USE_OP_DEVICE_KERNEL(op_type, GPU)
#endif

#define USE_NO_GRAD_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)
#define USE_CPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU);

#define USE_CPU_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU); \
USE_OP_GRADIENT(op_type)

#define USE_OP(op_type) \
USE_NO_GRAD_OP(op_type); \
USE_OP_GRADIENT(op_type)
#define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)

} // namespace framework
} // namespace paddle
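Taken together, the USE_ family shrinks: USE_OP_GRADIENT and USE_NO_GRAD_OP are gone, USE_CPU_OP becomes USE_CPU_ONLY_OP, and USE_OP now means the op plus its kernels. A quick reference sketch (op names are illustrative, except add_two, which appears in the tests above):

USE_OP_ITSELF(some_op);              // operator definition only, no kernels
USE_OP_DEVICE_KERNEL(add_two, CPU);  // one device's kernel registration
USE_CPU_ONLY_OP(some_cpu_op);        // op itself plus its CPU kernel only
USE_OP(add_two);                     // op plus kernels for every compiled device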
8 changes: 4 additions & 4 deletions paddle/framework/op_registry_test.cc
@@ -49,10 +49,10 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework
} // namespace paddle

REGISTER_OP(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);
REGISTER_OP_WITHOUT_GRADIENT(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);

TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
7 changes: 7 additions & 0 deletions paddle/framework/operator.h
@@ -125,6 +125,13 @@ class OperatorBase {
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
};

class NOP : public OperatorBase {
public:
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};

class InferShapeContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
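Moving NOP here lets the registry and the tests share one definition (grad_op_builder_test.cc drops its local copy above). It also underpins the registration recursion: gradient ops are registered via RegisterOp<GradOpType, NOPMaker, NOP>, and the typeid comparison against NOPMaker skips proto and checker construction, since gradient ops expose no user-facing proto. Condensed from op_registry.h above, not new logic:

if (std::type_index(typeid(ProtoMakerType)) !=
    std::type_index(typeid(NOPMaker))) {
  // only user-facing forward ops build an OpProto and OpAttrChecker
}
op_info_map().insert(std::make_pair(op_type, op_info));
if (!grad_op_type.empty()) {
  // the empty string stops the recursion after one step
  RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
}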