-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix inaccurate return of low precision op list #49391
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
f21b5ee
to
f6faa79
Compare
@@ -1189,6 +1189,7 @@ def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False): | |||
{code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( | |||
{code_indent} "{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}); | |||
{code_indent} const auto& kernel = kernel_result.kernel; | |||
{code_indent} phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉可以把 if 判断条件加在这一层以进一步减少额外的调用开销
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
@@ -221,6 +221,7 @@ def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False): | |||
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( | |||
"{kernel_name}", {{kernel_backend, kernel_layout, kernel_data_type}}); | |||
const auto& phi_kernel = kernel_result.kernel; | |||
phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
@@ -210,6 +210,7 @@ def gen_string_tensor_kernel_code(self, inplace_flag=False, code_indent=""): | |||
VLOG(6) << "{self.api} api strings kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]"; | |||
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( | |||
"{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}); | |||
phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{self.api}", kernel_data_type); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
f6faa79
to
9a5fc61
Compare
9a5fc61
to
6bf8c20
Compare
TODO:
(2)不需要环境变量通过1、2这种等级控制,默认打印出来模型中所有算子的列表。在AMP任务中,一些算子可能是FP16或者FP32 Kernel都会调用 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Bug fixes
PR changes
Others
Describe
修复低精度算子列表获取不准确的问题。
说明:
整个功能通过FLAGS_low_precision_op_list环境变量控制,环境变量默认为0
在模型运行结束调用: paddle.amp.low_precision_op_list()
用法:
FLAGS_low_precision_op_list=1:返回当前模型前向低精度算子列表,便于低精度训练加入黑名单,注意对于inplace的OP无法通过加入黑名单的方式接触低精度运算。
FLAGS_low_precision_op_list=2:返回当前模型前向算子列表,便于查看算子列表。