-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[opt] Add regularization and Nesterov for merged_momentum op #37527
Changes from 3 commits
d07149f
f27e863
e27ce50
ea8cb54
c049ae0
cf953a3
46e67d0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -18,6 +18,7 @@ | |||||
#include "paddle/fluid/framework/operator.h" | ||||||
#include "paddle/fluid/framework/tensor.h" | ||||||
#include "paddle/fluid/operators/amp/fp16_type_traits.h" | ||||||
#include "paddle/fluid/operators/optimizers/momentum_op.h" | ||||||
#include "paddle/fluid/platform/for_range.h" | ||||||
#include "paddle/fluid/platform/macros.h" | ||||||
|
||||||
|
@@ -147,20 +148,45 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> { | |||||
master_params_out.clear(); | ||||||
} | ||||||
|
||||||
auto lr = ctx.Input<framework::Tensor>("LearningRate"); | ||||||
auto mu = ctx.Attr<float>("mu"); | ||||||
auto rescale_grad = ctx.Attr<float>("rescale_grad"); | ||||||
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate"); | ||||||
if (lrs.size() != 1) { | ||||||
PADDLE_ENFORCE_EQ(n, lrs.size(), platform::errors::InvalidArgument( | ||||||
"Input(LearningRate) number must be " | ||||||
"equal to Input(Param) number.")); | ||||||
} | ||||||
auto use_nesterov = ctx.Attr<bool>("use_nesterov"); | ||||||
auto regularization_methods = | ||||||
ctx.Attr<std::vector<std::string>>("regularization_method"); | ||||||
auto regularization_coeffs = | ||||||
ctx.Attr<std::vector<float>>("regularization_coeff"); | ||||||
if (regularization_methods.size() != 0) { | ||||||
PADDLE_ENFORCE_EQ(n, regularization_methods.size(), | ||||||
platform::errors::InvalidArgument( | ||||||
"Attr(regularization_method) number must be equal " | ||||||
"to Input(Param) number.")); | ||||||
PADDLE_ENFORCE_EQ(n, regularization_coeffs.size(), | ||||||
platform::errors::InvalidArgument( | ||||||
"Attr(regularization_coeff) number must be equal " | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The size of Attr(regularization_coeff) must be equal to the size of Input(Param), but got the size of Attr(regularization_coeff) is %d, the size of Input(Param) is %d There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try to make the error message helpful, same for others. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, tks! |
||||||
"to Input(Param) number.")); | ||||||
} | ||||||
VLOG(1) << use_nesterov << regularization_methods.size() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, tkx! |
||||||
<< regularization_coeffs.size(); | ||||||
|
||||||
using MPType = typename operators::details::MPTypeTrait<T>::Type; | ||||||
|
||||||
auto &dev_ctx = ctx.template device_context<DeviceContext>(); | ||||||
|
||||||
if (lrs.size() == 1 && use_nesterov == false && | ||||||
regularization_methods.size() == 0) { | ||||||
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \ | ||||||
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \ | ||||||
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \ | ||||||
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \ | ||||||
kernel_params.mu = static_cast<MPType>(mu); \ | ||||||
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \ | ||||||
kernel_params.lr = lr->data<MPType>(); \ | ||||||
kernel_params.lr = lrs[0]->data<MPType>(); \ | ||||||
for (size_t i = 0; i < kernel_num; ++i) { \ | ||||||
size_t start = i * kMaxMergedNum; \ | ||||||
size_t end = std::min((i + 1) * kMaxMergedNum, n); \ | ||||||
|
@@ -182,14 +208,82 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> { | |||||
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \ | ||||||
<< kernel_params.param_num; \ | ||||||
} | ||||||
|
||||||
if (multi_precision) { | ||||||
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true); | ||||||
if (multi_precision) { | ||||||
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true); | ||||||
} else { | ||||||
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false); | ||||||
} | ||||||
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL | ||||||
} else { | ||||||
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false); | ||||||
} | ||||||
for (size_t idx = 0; idx < n; idx++) { | ||||||
std::string regularization_method = " "; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, tkx! |
||||||
if (regularization_methods.size() != 0) { | ||||||
regularization_method = regularization_methods[idx]; | ||||||
} | ||||||
RegularizationType regularization_flag{RegularizationType::kNONE}; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT. Not required to change. Maybe the following code would be simpler: RegularizationType regularization_flag = regularization_methods.size() > 0 && regularization_methods[idx] == "l2_decay" ? RegularizationType::kL2DECAY : RegularizationType::kNONE. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tks, this code has been modified according to the comments. |
||||||
if (regularization_method == "l2_decay") { | ||||||
regularization_flag = RegularizationType::kL2DECAY; | ||||||
} | ||||||
MPType regularization_coeff = static_cast<MPType>(0.0); | ||||||
if (regularization_coeffs.size() != 0) { | ||||||
regularization_coeff = | ||||||
static_cast<MPType>(regularization_coeffs[idx]); | ||||||
} | ||||||
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0]; | ||||||
|
||||||
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL | ||||||
params_out[idx]->data<T>(); | ||||||
velocitys_out[idx]->data<MPType>(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not know what is the purpose to write these 2 lines? Just check whether There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tks, this code has been deleted. |
||||||
const MPType *master_in_data = | ||||||
multi_precision ? master_params[idx]->data<MPType>() : nullptr; | ||||||
MPType *master_out_data = | ||||||
multi_precision ? master_params_out[idx]->data<MPType>() : nullptr; | ||||||
if (platform::is_cpu_place(ctx.GetPlace())) { | ||||||
CPUDenseMomentumFunctor<MPType> functor; | ||||||
functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu, | ||||||
use_nesterov, regularization_flag, regularization_coeff, | ||||||
params_out[idx], velocitys_out[idx]); | ||||||
VLOG(10) << "Launch MergedMomentum cpu kernel."; | ||||||
} else if (platform::is_gpu_place(ctx.GetPlace())) { | ||||||
platform::ForRange<DeviceContext> for_range( | ||||||
static_cast<const DeviceContext &>(ctx.device_context()), | ||||||
params[idx]->numel()); | ||||||
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \ | ||||||
DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor( \ | ||||||
params[idx]->data<T>(), grads[idx]->data<T>(), \ | ||||||
velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \ | ||||||
mu, rescale_grad, params[idx]->numel(), regularization_coeff, \ | ||||||
params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(), \ | ||||||
master_out_data); \ | ||||||
for_range(functor); | ||||||
if (use_nesterov) { | ||||||
if (regularization_flag == RegularizationType::kL2DECAY) { | ||||||
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( | ||||||
UseNesterov, RegularizationType::kL2DECAY); | ||||||
VLOG(10) | ||||||
<< "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY."; | ||||||
} else { | ||||||
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(UseNesterov, | ||||||
RegularizationType::kNONE); | ||||||
VLOG(10) | ||||||
<< "Launch MergedMomentum gpu kernel use_nesterov kNONE."; | ||||||
} | ||||||
} else { | ||||||
if (regularization_flag == RegularizationType::kL2DECAY) { | ||||||
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( | ||||||
NoNesterov, RegularizationType::kL2DECAY); | ||||||
VLOG(10) | ||||||
<< "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY."; | ||||||
} else { | ||||||
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(NoNesterov, | ||||||
RegularizationType::kNONE); | ||||||
VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE."; | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems too many duplicate codes with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think these codes have reused the |
||||||
VLOG(10) | ||||||
<< "Launch MergedMomentum kernel with multi_lr and regularization."; | ||||||
} | ||||||
} | ||||||
}; | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, tkx!