Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored GradNodeAccumulation data structure and behaviour #39526

Merged
merged 15 commits into from
Feb 24, 2022

Conversation

jim19930609
Copy link
Contributor

@jim19930609 jim19930609 commented Feb 14, 2022

PR types

New features

PR changes

Others

Describe

We used to register RetainGrad hook for leaf tensor so as to update grad tensor's value. After this patch, we asked GradNodeAccumulation to track a weak_ptr of grad_tensor, the value of which will get updated upon calling GradNodeAccumulation::operator(), instead of having to register a RetainGrad hook.

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@jim19930609 jim19930609 force-pushed the fix_accumulation_node branch from 352c972 to aab36af Compare February 16, 2022 07:44
@jim19930609 jim19930609 force-pushed the fix_accumulation_node branch from 490a1dc to 200806d Compare February 20, 2022 05:21
@jim19930609 jim19930609 force-pushed the fix_accumulation_node branch from 051f10a to d342316 Compare February 21, 2022 07:33
Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

报错误检,细节后续可以完善下

std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
grad = accumulation_grad_node->Grad();
grad = egr::EagerUtils::mutable_grad(self->tensor);
PADDLE_ENFORCE(grad != nullptr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

推荐直接用PADDLE_ENFORCE_NOT_NULL

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

多谢!

grad = egr::EagerUtils::mutable_grad(self->tensor);
PADDLE_ENFORCE(grad != nullptr,
paddle::platform::errors::Fatal(
"Detected NULL grad"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Detected NULL grad后面需要标点符号断下句吗?报错结尾建议用句点结束

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

Copy link
Contributor

@JiabinYang JiabinYang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -52,9 +52,15 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
}
}

void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
// TODO(jiabin): Support More Tensor type here
static void RetainGradForRegularNode(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge this with RegisterGradientHookForTensor

@JiabinYang JiabinYang merged commit 1abfc8d into PaddlePaddle:develop Feb 24, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants