Add FLAGS_low_precision_op_list to get amp list of current module #48843
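Your PR was submitted successfully. Thank you for contributing to this open-source project!
A switch is needed so that low-precision operators are recorded only in debug mode. -> Switch added.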
64a8e99 to b8743d3
paddle/fluid/pybind/pybind.cc
Outdated
@@ -2543,6 +2543,10 @@ All parameter, weight, gradient are variables in Paddle.
  m.def("update_autotune_status",
        [] { return phi::autotune::AutoTuneStatus::Instance().Update(); });

  m.def("get_low_pricision_op_list", [] {
pricision -> precision; there are several more spelling errors below. They can be fixed in the next PR.
Fixed.
# amp list elementwise_add, cast
with paddle.amp.auto_cast():
    c = a + b
paddle.fluid.dygraph.amp.auto_cast.low_pricision_op_list()
The unit test could also assert that the result is correct, so that later changes do not break this feature.
Fixed.
2581df9
67c6aef to 2581df9
LGTM
PR types
Others
PR changes
Others
Describe
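Add a new AMP list member that records the AMP operators run in the current model.
Add the environment variable `export FLAGS_low_precision_op_list=1` to control collection of the low-precision OP list. When FLAGS_low_precision_op_list=1, the operators in the current model that are computed with AMP are recorded; otherwise nothing is recorded. (Note: this FLAG will be removed once the official interface has been developed.)
Usage
1. Set the environment variable: FLAGS_low_precision_op_list=1
2. At the end of the model, add: paddle.fluid.dygraph.amp.auto_cast.low_pricision_op_list()
This prints the corresponding low-precision OP list.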
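
The flag-gated behaviour described above can be illustrated with a small standalone sketch. This is not Paddle's implementation: `low_precision_recording_enabled` and `LowPrecisionOpRecorder` are hypothetical names invented for illustration, and the assumption that the list holds unique operator names in first-seen order is ours. Only the flag name `FLAGS_low_precision_op_list` and the example operator names come from this PR.

```python
import os


def low_precision_recording_enabled(environ=None):
    """Return True only when FLAGS_low_precision_op_list is set to "1".

    Hypothetical helper: Paddle reads its FLAGS_* settings internally;
    this only mimics the on/off behaviour described in the PR.
    """
    environ = os.environ if environ is None else environ
    return environ.get("FLAGS_low_precision_op_list", "0") == "1"


class LowPrecisionOpRecorder:
    """Toy recorder that keeps operator names only while enabled."""

    def __init__(self, enabled):
        self.enabled = enabled
        # A dict preserves insertion order, so names stay in first-seen order.
        self._ops = {}

    def record(self, op_name):
        # When the flag is off, recording is a no-op.
        if self.enabled:
            self._ops[op_name] = True

    def op_list(self):
        return list(self._ops)


if __name__ == "__main__":
    recorder = LowPrecisionOpRecorder(low_precision_recording_enabled())
    for name in ("elementwise_add", "cast", "cast"):
        recorder.record(name)
    print(recorder.op_list())
```

Running the script with `FLAGS_low_precision_op_list=1` exported prints the recorded names; without the flag the list stays empty, which mirrors the "otherwise nothing is recorded" behaviour in the description.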