-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merge maps in OpRegistry and simplify register macros #3436
Changes from 5 commits
ab08575
f784741
3e11e4c
6768b31
2ea2fbe
7a31d72
19dfe1f
fb6bec6
3e6e5c9
edb541f
39c986c
914a2f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ limitations under the License. */ | |
#include <algorithm> | ||
#include <atomic> | ||
#include <type_traits> | ||
#include <typeinfo> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
#include "paddle/framework/attribute.h" | ||
|
@@ -174,59 +175,78 @@ Add a mark to which output is temporary is helpful for future optimization. | |
bool has_temporary_output_{false}; | ||
}; | ||
|
||
class NOPMaker : public OpProtoAndCheckerMaker { | ||
public: | ||
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) {} | ||
}; | ||
|
||
struct OpInfo { | ||
std::function<OperatorBase*()> creator_; | ||
std::string grad_op_type_; | ||
OpProto* proto_; | ||
OpAttrChecker* checker_; | ||
}; | ||
|
||
class OpRegistry { | ||
using OpCreator = std::function<OperatorBase*()>; | ||
using VarIndexMap = std::unordered_map<std::string, int>; | ||
using VarNameList = std::vector<std::string>; | ||
|
||
public: | ||
template <typename OpType, typename ProtoMakerType> | ||
static void RegisterOp(const std::string& op_type) { | ||
op_creators()[op_type] = [] { return new OpType; }; | ||
OpAttrChecker& op_checker = op_checkers()[op_type]; | ||
OpProto& op_proto = protos()[op_type]; | ||
auto maker = ProtoMakerType(&op_proto, &op_checker); | ||
maker.Validate(); | ||
*op_proto.mutable_type() = op_type; | ||
PADDLE_ENFORCE( | ||
op_proto.IsInitialized(), | ||
"Fail to initialize %s's OpProto, because %s is not initialized", | ||
op_type, op_proto.InitializationErrorString()); | ||
|
||
VarIndexMaps()[op_type].reset(new VarIndexMap()); | ||
auto& varmap = *VarIndexMaps()[op_type]; | ||
int idx = 0; | ||
for (auto& var : op_proto.inputs()) { | ||
varmap[var.name()] = idx++; | ||
template <typename OpType, typename ProtoMakerType, typename GradOpType> | ||
static void RegisterOp(const std::string& op_type, | ||
const std::string& grad_op_type) { | ||
PADDLE_ENFORCE(op_info_map().count(op_type) == 0, | ||
"'%s' is registered more than once.", op_type); | ||
OpInfo op_info; | ||
op_info.creator_ = [] { return new OpType; }; | ||
op_info.grad_op_type_ = grad_op_type; | ||
if (std::type_index(typeid(ProtoMakerType)) != | ||
std::type_index(typeid(NOPMaker))) { | ||
op_info.proto_ = new OpProto; | ||
op_info.checker_ = new OpAttrChecker; | ||
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_); | ||
maker.Validate(); | ||
*op_info.proto_->mutable_type() = op_type; | ||
PADDLE_ENFORCE( | ||
op_info.proto_->IsInitialized(), | ||
"Fail to initialize %s's OpProto, because %s is not initialized", | ||
op_type, op_info.proto_->InitializationErrorString()); | ||
// ======will be refactored in following PRs============ // | ||
VarIndexMaps()[op_type].reset(new VarIndexMap()); | ||
auto& varmap = *VarIndexMaps()[op_type]; | ||
int idx = 0; | ||
for (auto& var : op_info.proto_->inputs()) { | ||
varmap[var.name()] = idx++; | ||
} | ||
idx = 0; | ||
for (auto& var : op_info.proto_->outputs()) { | ||
varmap[var.name()] = idx++; | ||
} | ||
// ================================================ // | ||
} | ||
idx = 0; | ||
for (auto& var : op_proto.outputs()) { | ||
varmap[var.name()] = idx++; | ||
op_info_map().insert(std::make_pair(op_type, op_info)); | ||
// register gradient op | ||
if (!grad_op_type.empty()) { | ||
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, ""); | ||
} | ||
} | ||
|
||
template <typename GradOpType> | ||
static void RegisterGradOp(const std::string& op_type, | ||
const std::string& grad_op_type) { | ||
op_creators()[grad_op_type] = [] { return new GradOpType; }; | ||
grad_ops()[op_type] = grad_op_type; | ||
} | ||
|
||
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type, | ||
const VarNameList& inputs, | ||
const VarNameList& outputs, | ||
const AttributeMap& attrs) { | ||
auto op_create_it = op_creators().find(type); | ||
PADDLE_ENFORCE(op_create_it != op_creators().end(), | ||
"Operator %s cannot be found.", type); | ||
auto it = op_info_map().find(type); | ||
PADDLE_ENFORCE(it != op_info_map().end(), "'%s' has not been registered.", | ||
type); | ||
|
||
auto op = op_create_it->second(); | ||
auto op = it->second.creator_(); | ||
op->type_ = type; | ||
op->inputs_ = inputs; | ||
op->outputs_ = outputs; | ||
|
||
op->attrs_ = attrs; | ||
op_checkers().at(type).Check(op->attrs_); | ||
it->second.checker_->Check(op->attrs_); | ||
|
||
GenerateTempVariableName(op); | ||
|
||
|
@@ -268,14 +288,9 @@ class OpRegistry { | |
return grad_op; | ||
} | ||
|
||
static std::unordered_map<std::string, OpProto>& protos() { | ||
static std::unordered_map<std::string, OpProto> protos_; | ||
return protos_; | ||
} | ||
|
||
static std::unordered_map<std::string, std::string>& grad_ops() { | ||
static std::unordered_map<std::string, std::string> grad_ops_; | ||
return grad_ops_; | ||
static std::unordered_map<std::string, const OpInfo>& op_info_map() { | ||
static std::unordered_map<std::string, const OpInfo> op_info_map_; | ||
return op_info_map_; | ||
} | ||
|
||
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>& | ||
|
@@ -284,17 +299,7 @@ class OpRegistry { | |
return maps_; | ||
} | ||
|
||
static std::unordered_map<std::string, OpCreator>& op_creators() { | ||
static std::unordered_map<std::string, OpCreator> op_creators_; | ||
return op_creators_; | ||
} | ||
|
||
private: | ||
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() { | ||
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; | ||
return op_checkers_; | ||
} | ||
|
||
static void GenerateTempVariableName(OperatorBase* op) { | ||
static std::atomic<size_t> gUniqId(0UL); | ||
for (auto& outname : op->outputs_) { | ||
|
@@ -320,19 +325,13 @@ class Registrar { | |
void Touch() {} | ||
}; | ||
|
||
template <typename OpType, typename ProtoMakerType> | ||
template <typename OpType, typename ProtoMakerType, typename GradOpType> | ||
class OpRegistrar : public Registrar { | ||
public: | ||
explicit OpRegistrar(const char* op_type) { | ||
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type); | ||
} | ||
}; | ||
|
||
template <typename GradOpType> | ||
class GradOpRegistrar : public Registrar { | ||
public: | ||
GradOpRegistrar(const char* op_type, const char* grad_op_type) { | ||
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type); | ||
explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); } | ||
OpRegistrar(const char* op_type, const char* grad_op_type) { | ||
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type, | ||
grad_op_type); | ||
} | ||
}; | ||
|
||
|
@@ -358,30 +357,20 @@ class OpKernelRegistrar : public Registrar { | |
/** | ||
* Macro to register Operator. | ||
*/ | ||
#define REGISTER_OP(op_type, op_class, op_maker_class) \ | ||
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ | ||
grad_op_class) \ | ||
STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ | ||
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \ | ||
__op_registrar_##op_type##__(#op_type); \ | ||
static ::paddle::framework::OpRegistrar<op_class, op_maker_class, \ | ||
grad_op_class> \ | ||
__op_registrar_##op_type##__(#op_type, #grad_op_type); \ | ||
int TouchOpRegistrar_##op_type() { \ | ||
__op_registrar_##op_type##__.Touch(); \ | ||
return 0; \ | ||
} | ||
|
||
/** | ||
* Macro to register Gradient Operator. | ||
*/ | ||
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \ | ||
STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
__reg_gradient_op__##op_type##_##grad_op_type, \ | ||
"REGISTER_GRADIENT_OP must be called in global namespace"); \ | ||
static ::paddle::framework::GradOpRegistrar<grad_op_class> \ | ||
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \ | ||
#grad_op_type); \ | ||
int TouchOpGradientRegistrar_##op_type() { \ | ||
__op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \ | ||
return 0; \ | ||
} | ||
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ | ||
REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP) | ||
|
||
/** | ||
* Macro to register OperatorKernel. | ||
|
@@ -400,10 +389,12 @@ class OpKernelRegistrar : public Registrar { | |
/** | ||
* Macro to Forbid user register Gradient Operator. | ||
*/ | ||
/* | ||
#define NO_GRADIENT(op_type) \ | ||
STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
__reg_gradient_op__##op_type##_##op_type##_grad, \ | ||
"NO_GRADIENT must be called in global namespace") | ||
STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we already have a unified registrar There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean macro There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes. Is it redundant? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed |
||
__reg_gradient_op__##op_type##_##op_type##_grad, \ | ||
"NO_GRADIENT must be called in global namespace") | ||
*/ | ||
|
||
#define REGISTER_OP_GPU_KERNEL(op_type, ...) \ | ||
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) | ||
|
@@ -423,23 +414,6 @@ class OpKernelRegistrar : public Registrar { | |
static int use_op_itself_##op_type##_ __attribute__((unused)) = \ | ||
TouchOpRegistrar_##op_type() | ||
|
||
// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use | ||
// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't | ||
// be compiled. `NO_GRAD` should be removed after all gradient ops are | ||
// compeleted. | ||
#define NO_GRAD | ||
#ifndef NO_GRAD | ||
#define USE_OP_GRADIENT(op_type) \ | ||
STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
__use_op_gradient_##op_type, \ | ||
"USE_OP_GRADIENT must be called in global namespace"); \ | ||
extern int TouchOpGradientRegistrar_##op_type(); \ | ||
static int use_op_gradient_##op_type##_ __attribute__((unused)) = \ | ||
TouchOpGradientRegistrar_##op_type() | ||
#else | ||
#define USE_OP_GRADIENT(op_type) | ||
#endif | ||
|
||
#define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \ | ||
STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ | ||
|
@@ -459,18 +433,13 @@ class OpKernelRegistrar : public Registrar { | |
USE_OP_DEVICE_KERNEL(op_type, GPU) | ||
#endif | ||
|
||
#define USE_NO_GRAD_OP(op_type) \ | ||
USE_OP_ITSELF(op_type); \ | ||
USE_OP_KERNEL(op_type) | ||
#define USE_CPU_ONLY_OP(op_type) \ | ||
USE_OP_ITSELF(op_type); \ | ||
USE_OP_DEVICE_KERNEL(op_type, CPU); | ||
|
||
#define USE_CPU_OP(op_type) \ | ||
USE_OP_ITSELF(op_type); \ | ||
USE_OP_DEVICE_KERNEL(op_type, CPU); \ | ||
USE_OP_GRADIENT(op_type) | ||
|
||
#define USE_OP(op_type) \ | ||
USE_NO_GRAD_OP(op_type); \ | ||
USE_OP_GRADIENT(op_type) | ||
#define USE_OP(op_type) \ | ||
USE_OP_ITSELF(op_type); \ | ||
USE_OP_KERNEL(op_type) | ||
|
||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when we change the registration style, we need to discuss a previous problem. The user's code needs to write
USE_OP(OP)
,USE_OP(OP_GRAD)
twice. It is annoying.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just
USE_OP(OP)
is OK. because we always register forward operator and backward operator in same.cc/cu
file.USE_OP(OP)
will force C++ linker to include all symbols in that.cc/cu
file including forwarding operator and backwarding operator.