add bf16 for some ops in static mode #51582
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open source project!
python/paddle/nn/clip.py (outdated)

@@ -733,7 +733,7 @@ def _static_clip(self, params_grads):
     merge_grad = merge_selected_rows(g)
     merge_grad = get_tensor_from_selected_rows(merge_grad)
     sum_square = _squared_l2_norm(merge_grad)
-    if sum_square.dtype == core.VarDesc.VarType.FP16:
+    if sum_square.dtype == core.VarDesc.VarType.FP16 or sum_square.dtype == core.VarDesc.VarType.BF16:
Please store the fp16 and bf16 results separately, and add a check that fp16 and bf16 do not exist at the same time.
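A hedged sketch of what that suggestion could look like inside `_static_clip` (the `sum_square_list_*` names are hypothetical, modeled on the surrounding code, not the PR's final implementation):

```python
# Hypothetical sketch of the reviewer's suggestion: keep fp16 and bf16
# squared norms in separate lists and reject mixed-precision groups.
sum_square_list = []        # fp32 accumulators
sum_square_list_fp16 = []   # fp16 accumulators (hypothetical name)
sum_square_list_bf16 = []   # bf16 accumulators (hypothetical name)

for p, g in params_grads:
    merge_grad = merge_selected_rows(g)
    merge_grad = get_tensor_from_selected_rows(merge_grad)
    sum_square = _squared_l2_norm(merge_grad)
    if sum_square.dtype == core.VarDesc.VarType.FP16:
        sum_square_list_fp16.append(sum_square)
    elif sum_square.dtype == core.VarDesc.VarType.BF16:
        sum_square_list_bf16.append(sum_square)
    else:
        sum_square_list.append(sum_square)

# fp16 and bf16 results must not exist at the same time.
assert not (sum_square_list_fp16 and sum_square_list_bf16), (
    "fp16 and bf16 gradients cannot be clipped in the same group"
)
```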
This has been changed in the develop branch.
LGTM
* Cherry-pick the register of bfloat16 for amp_kernel, pull request #45541.
* Cherry-pick the master_grad support of adamw, pull request #51141.
* add bf16 for some ops in static mode (#51582)
  * Add bfloat16 support for some api in static mode.
  * Fix codestyle.
  * Revert the change of layer_function_generator.py.

Co-authored-by: Shaojie WANG <wsjmessi@163.com>
PR types: Others
PR changes: OPs
Describe: add bf16 datatype for some ops in static mode
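Since the PR only states that some ops gain bf16 support in static mode, here is a hedged end-to-end sketch of how such an op is typically exercised in a static graph (the choice of `relu` is illustrative and not taken from the PR's tests; bf16 kernel availability depends on the build and device):

```python
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
    x_bf16 = paddle.cast(x, "bfloat16")    # cast the input to bf16
    y = paddle.nn.functional.relu(x_bf16)  # run an op on the bf16 tensor
    out = paddle.cast(y, "float32")        # cast back so the result is easy to fetch

exe = paddle.static.Executor()
exe.run(startup_prog)
(res,) = exe.run(
    main_prog,
    feed={"x": np.random.rand(2, 3).astype("float32")},
    fetch_list=[out],
)
print(res)
```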