-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SOT][Faster Guard] add some basic guard #69313
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
paddle/fluid/pybind/jit.cc
Outdated
py::class_<GuardBase, std::shared_ptr<GuardBase>>( | ||
*m, "GuardBase", R"DOC(GuardBase Class.)DOC") | ||
.def("check", &GuardBase::check_pybind); | ||
py::class_<LambdaGuard, GuardBase, std::shared_ptr<LambdaGuard>>( | ||
*m, "LambdaGuard", R"DOC(LambdaGuard Class.)DOC") | ||
.def(py::init<const py::function &>(), py::arg("guard_check_fn")); | ||
py::class_<GuardGroup, GuardBase, std::shared_ptr<GuardGroup>>( | ||
*m, "GuardGroup", R"DOC(GuardGroup Class.)DOC") | ||
.def(py::init<std::vector<std::shared_ptr<GuardBase>>>(), | ||
py::arg("guards")); | ||
py::class_<TypeMatchGuard, GuardBase, std::shared_ptr<TypeMatchGuard>>( | ||
*m, "TypeMatchGuard", R"DOC(TypeMatchGuard Class.)DOC") | ||
.def(py::init<const py::type &>(), py::arg("py_type")); | ||
py::class_<LengthMatchGuard, GuardBase, std::shared_ptr<LengthMatchGuard>>( | ||
*m, "LengthMatchGuard", R"DOC(LengthMatchGuard Class.)DOC") | ||
.def(py::init<Py_ssize_t>(), py::arg("length")); | ||
py::class_<ValueMatchGuard, GuardBase, std::shared_ptr<ValueMatchGuard>>( | ||
*m, "ValueMatchGuard", R"DOC(ValueMatchGuard Class.)DOC") | ||
.def(py::init<const py::object &>(), py::arg("py_value")); | ||
py::class_<DtypeMatchGuard, GuardBase, std::shared_ptr<DtypeMatchGuard>>( | ||
*m, "DtypeMatchGuard", R"DOC(DtypeMatchGuard Class.)DOC") | ||
.def(py::init<const paddle::framework::proto::VarType &>(), | ||
py::arg("dtype")) | ||
.def(py::init<const phi::DataType &>(), py::arg("dtype")); | ||
py::class_<LayerMatchGuard, GuardBase, std::shared_ptr<LayerMatchGuard>>( | ||
*m, "LayerMatchGuard", R"DOC(LayerMatchGuard Class.)DOC") | ||
.def(py::init<const py::object &>(), py::arg("layer_obj")); | ||
|
||
m->def( | ||
"merge_guard", | ||
[](const std::vector<std::shared_ptr<GuardBase>> &py_guards) { | ||
return GuardGroup(py_guards); | ||
}, | ||
py::arg("py_guards")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块可以抽一个函数叫 BindGuard
,这里直接调用下就好了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -61,6 +61,43 @@ void BindJit(pybind11::module *m) { | |||
}); | |||
} | |||
|
|||
void BindGuard(pybind11::module *m) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块也要放到 SOT_IS_SUPPORTED
里,不然编译 Python 3.14 GuardGroup
等会找不到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
相关 PR 链接下 #69264 |
PR Category
Execute Infrastructure
PR Types
Performance
Description
在C++端添加了一些基础的 guard并添加了一些单元测试