[opt] Add regularization and Nesterov for merged_momentum op #37527
Thanks for your contribution!
"Attr(regularization_coeff) number must be equal "
"to Input(Param) number."));
}
VLOG(1) << use_nesterov << regularization_methods.size()
Suggested change:
- VLOG(1) << use_nesterov << regularization_methods.size()
+ VLOG(5) << "use_nesterov: " << use_nesterov << ", regularization_methods.size(): " << regularization_methods.size()
Done, thanks!
"to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, regularization_coeffs.size(),
platform::errors::InvalidArgument(
"Attr(regularization_coeff) number must be equal "
The size of Attr(regularization_coeff) must be equal to the size of Input(Param), but got the size of Attr(regularization_coeff) is %d, the size of Input(Param) is %d
Try to make the error message helpful; the same applies to the other error messages.
Done, thanks!
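To illustrate the point of the suggested message, here is a minimal standalone sketch that reports both sizes when they differ. It does not use Paddle's `PADDLE_ENFORCE_EQ` or `platform::errors::InvalidArgument`; the helper names `SizeMismatchMessage` and `CheckRegularizationCoeffSize` are hypothetical and exist only for this example.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical helper (not part of Paddle): builds the message proposed in
// the review, with the two actual sizes filled in.
inline std::string SizeMismatchMessage(std::size_t coeff_size,
                                       std::size_t param_size) {
  return "The size of Attr(regularization_coeff) must be equal to the size of "
         "Input(Param), but got the size of Attr(regularization_coeff) is " +
         std::to_string(coeff_size) + ", the size of Input(Param) is " +
         std::to_string(param_size) + ".";
}

// Hypothetical check: throws with the detailed message when the sizes differ,
// so the caller can see which values were actually passed.
inline void CheckRegularizationCoeffSize(std::size_t coeff_size,
                                         std::size_t param_size) {
  if (coeff_size != param_size) {
    throw std::invalid_argument(SizeMismatchMessage(coeff_size, param_size));
  }
}
```

Including the observed values means a failing check can usually be diagnosed from the log alone, without rerunning under a debugger.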
@@ -68,6 +69,18 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum")
Suggested change:
- "Use Nesterov Momentum")
+ "Use Nesterov Momentum or not")
Done, thanks!
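For readers unfamiliar with what the `use_nesterov` attribute toggles, the following is a rough sketch of a momentum step with and without the Nesterov variant, written as a plain loop over one tensor. It follows the commonly documented momentum update; that the merged kernel matches it term by term is an assumption, and `MomentumStep` with its parameter names is hypothetical, not the PR's code.

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: an element-wise momentum step for a single parameter
// tensor. MomentumStep is a hypothetical name, not a Paddle function.
inline void MomentumStep(std::vector<float>& param,
                         std::vector<float>& velocity,
                         const std::vector<float>& grad, float lr, float mu,
                         bool use_nesterov) {
  for (std::size_t i = 0; i < param.size(); ++i) {
    // The velocity accumulates gradients, decayed by the momentum factor mu.
    velocity[i] = mu * velocity[i] + grad[i];
    if (use_nesterov) {
      // Nesterov variant: step along the gradient plus the decayed,
      // already-updated velocity.
      param[i] -= (grad[i] + mu * velocity[i]) * lr;
    } else {
      // Plain momentum: step along the updated velocity.
      param[i] -= lr * velocity[i];
    }
  }
}
```

A merged op would apply the same step to each parameter tensor in its input list, which is why a single boolean attribute is enough to switch the behavior for all of them.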
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
for (size_t idx = 0; idx < n; idx++) {
std::string regularization_method = " ";
Suggested change:
- std::string regularization_method = " ";
+ std::string regularization_method = "";
Done, thanks!
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
params_out[idx]->data<T>();
velocitys_out[idx]->data<MPType>();
What is the purpose of these two lines? Do they only check whether params_out[idx] and velocitys_out[idx] are properly initialized?
Thanks, this code has been deleted.
if (regularization_methods.size() != 0) {
regularization_method = regularization_methods[idx];
}
RegularizationType regularization_flag{RegularizationType::kNONE};
Nit, no change required. The following might be simpler:
RegularizationType regularization_flag = regularization_methods.size() > 0 && regularization_methods[idx] == "l2_decay" ? RegularizationType::kL2DECAY : RegularizationType::kNONE;
Thanks, this code has been modified according to the comments.
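The one-expression form suggested above can be sketched in isolation as follows. The enum here is a stand-in that models only the two values named in the review, and `SelectRegularization` is a hypothetical helper introduced for this example; the real `RegularizationType` in Paddle may define more values.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Stand-in for the enum referenced in the PR; only the two values that
// appear in the review comment are modeled here.
enum class RegularizationType { kNONE, kL2DECAY };

// Hypothetical helper: picks the per-parameter flag in one expression.
// An empty method list means no parameter is regularized; the short-circuit
// && ensures methods[idx] is never read in that case.
inline RegularizationType SelectRegularization(
    const std::vector<std::string>& methods, std::size_t idx) {
  return (!methods.empty() && methods[idx] == "l2_decay")
             ? RegularizationType::kL2DECAY
             : RegularizationType::kNONE;
}
```

Compared with declaring a temporary string and then branching on it, this avoids the placeholder value (`" "` versus `""`) that an earlier comment in this thread had to correct.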
}
}
}
}
There seems to be a lot of code duplicated from momentum_op.h. Maybe we can use a common function defined in momentum_op.h?
I think this code already reuses the DenseMomentumFunctor defined in momentum_op.h.
LGTM
LGTM
…ddle#37527)
* add regularization and Nesterov for merged_momentum
* refine unittest for use_nesterov attr
* refine op check
* refine code
* fix bug
* refine code of regularization_flag
* delete useless code
PR types
Performance optimization
PR changes
OPs
Describe
Enhance the merged_momentum op, including: