Refactored GradNodeAccumulation data structure and behaviour #39526
Conversation
Thanks for your contribution!
force-pushed from 352c972 to aab36af (… fix_accumulation_node)
force-pushed from 490a1dc to 200806d (… fix_accumulation_node)
force-pushed from 051f10a to d342316 (… fix_accumulation_node)
The error checking looks good; the details can be polished in a follow-up.
-        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
-    grad = accumulation_grad_node->Grad();
+    grad = egr::EagerUtils::mutable_grad(self->tensor);
+    PADDLE_ENFORCE(grad != nullptr,
Recommend using PADDLE_ENFORCE_NOT_NULL directly.
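A minimal sketch of the suggested replacement, assuming the diff context above; the exact error message text is illustrative, not the final wording:

grad = egr::EagerUtils::mutable_grad(self->tensor);
// PADDLE_ENFORCE_NOT_NULL performs the null check and builds the error
// report in one call, replacing the manual PADDLE_ENFORCE(grad != nullptr, ...).
PADDLE_ENFORCE_NOT_NULL(
    grad, paddle::platform::errors::Fatal("Detected NULL grad."));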
Thanks!
grad = egr::EagerUtils::mutable_grad(self->tensor);
PADDLE_ENFORCE(grad != nullptr,
               paddle::platform::errors::Fatal(
                   "Detected NULL grad"
Does "Detected NULL grad" need punctuation after it to break the sentence? It is recommended that error messages end with a period.
Will do
LGTM
@@ -52,9 +52,15 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
   }
 }
 
-void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
-  // TODO(jiabin): Support More Tensor type here
+static void RetainGradForRegularNode(
merge this with RegisterGradientHookForTensor
PR types
New features
PR changes
Others
Describe
We used to register a RetainGrad hook for leaf tensors in order to update the grad tensor's value. After this patch, GradNodeAccumulation tracks a weak_ptr to the grad tensor and updates its value when GradNodeAccumulation::operator() is called, so no RetainGrad hook needs to be registered.
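A rough, self-contained sketch of the mechanism described above; the Tensor struct and the SetWeakGrad name are simplified stand-ins for illustration, not Paddle's actual GradNodeAccumulation API:

#include <memory>
#include <vector>

// Simplified stand-in for a gradient tensor.
struct Tensor {
  std::vector<float> data;
};

class GradNodeAccumulation {
 public:
  // Track the leaf tensor's grad without owning it; this replaces the old
  // RetainGrad-hook registration.
  void SetWeakGrad(const std::shared_ptr<Tensor>& grad) { weak_grad_ = grad; }

  // Invoked during backward: accumulate the incoming gradient directly
  // into the tracked grad tensor, if it is still alive.
  void operator()(const Tensor& incoming) {
    if (auto grad = weak_grad_.lock()) {
      if (grad->data.size() < incoming.data.size()) {
        grad->data.resize(incoming.data.size(), 0.0f);
      }
      for (size_t i = 0; i < incoming.data.size(); ++i) {
        grad->data[i] += incoming.data[i];
      }
    }
  }

 private:
  // weak_ptr: the grad tensor's lifetime stays owned by the autograd metadata.
  std::weak_ptr<Tensor> weak_grad_;
};

Using a weak_ptr means the node never extends the grad tensor's lifetime: if the owner of the tensor has released it, operator() simply skips the write.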