[Dy2St]Get grad names when call append backward to fix high order gradient #53250
Conversation
Your PR was submitted successfully. Thank you for contributing to this open source project!
❌ The PR is not created using PR's template. You can refer to this Demo. |
LGTM for backward.py
lgtm for test_dropout_op
LGTM
LGTM
LGTM for atol
…dient (PaddlePaddle#53250) [Dy2St]Get grad names when call append backward to fix high order gradient (PaddlePaddle#53250)
PR types
Bug fixes
PR changes
Others
Description
PCard-66972
1 Background
In higher-order cases, the grad var names corresponding to the input, param, and out of a dynamic-to-static (dy2st) program are no longer simply x@GRAD; they may take forms such as x@GRAD@GRAD, grad/grad/x@GRAD, or x@GRAD_0. The previous approach of concatenating x.name + '@GRAD' therefore cannot produce the correct grad var name, so the module that obtains grad var names for dy2st needs to be upgraded to support higher-order cases.
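The failure mode above can be illustrated with a minimal, framework-free sketch (the name forms are the ones quoted in this PR description; `naive_grad_name` is a hypothetical stand-in for the old concatenation logic):

```python
# Hypothetical illustration: the old first-order assumption builds a grad
# var name by appending one fixed suffix to the forward var name.
def naive_grad_name(var_name):
    return var_name + "@GRAD"

# Grad var names actually seen for the input `x` in higher-order
# dynamic-to-static programs (examples from the PR description):
higher_order_names = ["x@GRAD@GRAD", "grad/grad/x@GRAD", "x@GRAD_0"]

guess = naive_grad_name("x")            # "x@GRAD"
# Every higher-order name differs from the naive guess, so lookups
# based on pure concatenation miss all of them.
mismatches = [n for n in higher_order_names if n != guess]
print(guess, mismatches)
```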
2 Changes in this PR
Previously, grad var names were obtained by iterating over the vars in the program and matching them with regular expressions, but this approach is wrong in some cases and hard to maintain.
This PR upgrades the way grad var names are obtained: the corresponding grad vars are now retrieved through the grad_info_map returned by append_backward, and the grad var names are read from those vars. As long as append_backward's logic is correct in static graph mode, this is guaranteed to yield the correct grad var names.
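The lookup-based approach can be sketched without Paddle as follows. This is a simplified model, not Paddle's actual API: here `grad_info_map` is assumed to map a forward var name to a `(grad_var_name, block_idx)` pair, and `get_grad_names` is a hypothetical helper:

```python
# Hypothetical sketch: resolve grad var names through a map produced by
# append_backward, instead of splicing strings like name + "@GRAD".
def get_grad_names(forward_names, grad_info_map):
    """Return the grad var name for each forward var, or None when the
    var has no gradient (e.g. stop_gradient=True)."""
    names = []
    for name in forward_names:
        info = grad_info_map.get(name)
        names.append(info[0] if info is not None else None)
    return names

# Simulated map for a second-order program: the grad names are exactly
# whatever the backward pass created, however they are spelled.
grad_info_map = {"x": ("x@GRAD@GRAD", 0), "w": ("grad/grad/w@GRAD", 0)}
print(get_grad_names(["x", "w", "b"], grad_info_map))
# ['x@GRAD@GRAD', 'grad/grad/w@GRAD', None]
```

Because the names come from the same data structure the backward pass built, no pattern matching over the program's vars is needed.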
As for calc_gradient, this PR only extracts its main logic into a calc_gradient_helper function, which is made available to the dy2st module.
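The refactor pattern can be sketched as below. All names and signatures here are illustrative only; the real calc_gradient lives in Paddle's backward.py and operates on program blocks rather than plain dicts:

```python
# Hypothetical sketch of the extraction: the core logic moves into a
# helper that returns the grad variables themselves, so other callers
# (e.g. the dy2st module) can read grad var names from it directly.
def calc_gradient_helper(targets, inputs):
    # Simulated core logic: build a forward-name -> grad-name mapping.
    return {x: x + "@GRAD" for x in inputs}

def calc_gradient(targets, inputs):
    # The public API keeps its old contract and delegates to the helper.
    grad_vars = calc_gradient_helper(targets, inputs)
    return [grad_vars[x] for x in inputs]
```

Keeping calc_gradient as a thin wrapper preserves its existing behavior for current callers while exposing the reusable piece.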