[opt] Add regularization and Nesterov for merged_momentum op #37527

Merged: 7 commits merged on Nov 30, 2021
Changes from 3 commits
15 changes: 14 additions & 1 deletion paddle/fluid/operators/optimizers/merged_momentum_op.cc
@@ -50,7 +50,8 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
"Input learning rate")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();
@@ -68,6 +69,18 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum")
.SetDefault(false);

Contributor suggested change:
-        "Use Nesterov Momentum")
+        "Use Nesterov Momentum or not")

Contributor Author: Done, thanks!
AddAttr<std::vector<std::string>>(
"regularization_method",
"(string) regularization_method, right now only "
"support l2decay or none")
.SetDefault({});
AddAttr<std::vector<float>>("regularization_coeff",
"(float) regularization_coeff")
.SetDefault({});
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
110 changes: 102 additions & 8 deletions paddle/fluid/operators/optimizers/merged_momentum_op.h
@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"

@@ -147,20 +148,45 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
master_params_out.clear();
}

auto lr = ctx.Input<framework::Tensor>("LearningRate");
auto mu = ctx.Attr<float>("mu");
auto rescale_grad = ctx.Attr<float>("rescale_grad");
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(n, lrs.size(), platform::errors::InvalidArgument(
"Input(LearningRate) number must be "
"equal to Input(Param) number."));
}
auto use_nesterov = ctx.Attr<bool>("use_nesterov");
auto regularization_methods =
ctx.Attr<std::vector<std::string>>("regularization_method");
auto regularization_coeffs =
ctx.Attr<std::vector<float>>("regularization_coeff");
if (regularization_methods.size() != 0) {
PADDLE_ENFORCE_EQ(n, regularization_methods.size(),
platform::errors::InvalidArgument(
"Attr(regularization_method) number must be equal "
"to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, regularization_coeffs.size(),
platform::errors::InvalidArgument(
"Attr(regularization_coeff) number must be equal "
"to Input(Param) number."));
}

Contributor: A more descriptive message would be clearer, e.g. "The size of Attr(regularization_coeff) must be equal to the size of Input(Param), but got the size of Attr(regularization_coeff) is %d, the size of Input(Param) is %d".

Contributor: Try to make the error messages helpful, same for the others.

Contributor Author: Done, thanks!
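A sketch of the more descriptive check along the lines of the review comment (the exact wording landed in a later commit, so this is illustrative rather than the merged code):

    PADDLE_ENFORCE_EQ(
        n, regularization_coeffs.size(),
        platform::errors::InvalidArgument(
            "The size of Attr(regularization_coeff) must be equal to the "
            "size of Input(Param), but got the size of "
            "Attr(regularization_coeff) is %d, the size of Input(Param) is %d.",
            regularization_coeffs.size(), n));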
VLOG(1) << use_nesterov << regularization_methods.size()
<< regularization_coeffs.size();

Contributor suggested change:
-    VLOG(1) << use_nesterov << regularization_methods.size()
+    VLOG(5) << "use_nesterov: " << use_nesterov
+            << ", regularization_methods.size(): " << regularization_methods.size()

Contributor Author: Done, thanks!

using MPType = typename operators::details::MPTypeTrait<T>::Type;

auto &dev_ctx = ctx.template device_context<DeviceContext>();

if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) {
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MPType>(mu); \
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
kernel_params.lr = lr->data<MPType>(); \
kernel_params.lr = lrs[0]->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
@@ -182,14 +208,82 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}

if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
for (size_t idx = 0; idx < n; idx++) {
std::string regularization_method = " ";
Contributor suggested change:
-        std::string regularization_method = " ";
+        std::string regularization_method = "";

Contributor Author: Done, thanks!
if (regularization_methods.size() != 0) {
regularization_method = regularization_methods[idx];
}
RegularizationType regularization_flag{RegularizationType::kNONE};
Collaborator: NIT, not required to change. Maybe the following would be simpler:

    RegularizationType regularization_flag =
        regularization_methods.size() > 0 &&
                regularization_methods[idx] == "l2_decay"
            ? RegularizationType::kL2DECAY
            : RegularizationType::kNONE;

Contributor Author: Thanks, this code has been modified according to the comment.

if (regularization_method == "l2_decay") {
regularization_flag = RegularizationType::kL2DECAY;
}
MPType regularization_coeff = static_cast<MPType>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff =
static_cast<MPType>(regularization_coeffs[idx]);
}
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];

#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
params_out[idx]->data<T>();
velocitys_out[idx]->data<MPType>();
Collaborator: I do not know the purpose of these two lines. Are they just checking whether params_out[idx] and velocitys_out[idx] are properly initialized?

Contributor Author: Thanks, this code has been deleted.

const MPType *master_in_data =
multi_precision ? master_params[idx]->data<MPType>() : nullptr;
MPType *master_out_data =
multi_precision ? master_params_out[idx]->data<MPType>() : nullptr;
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<MPType> functor;
functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu,
use_nesterov, regularization_flag, regularization_coeff,
params_out[idx], velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext &>(ctx.device_context()),
params[idx]->numel());
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), grads[idx]->data<T>(), \
velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \
mu, rescale_grad, params[idx]->numel(), regularization_coeff, \
params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(), \
master_out_data); \
for_range(functor);
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
UseNesterov, RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(UseNesterov,
RegularizationType::kNONE);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
}
} else {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
NoNesterov, RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(NoNesterov,
RegularizationType::kNONE);
VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
}
}
}
}
Collaborator: This seems to duplicate a lot of code from momentum_op.h. Maybe we can use a common function defined in momentum_op.h?

Contributor Author: I think this code already reuses the DenseMomentumFunctor from momentum_op.h.

VLOG(10)
<< "Launch MergedMomentum kernel with multi_lr and regularization.";
}
}
};
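For reference, a sketch of the per-element update the CPU and GPU functors above apply (assuming the usual Paddle momentum semantics; illustrative names, not code copied from momentum_op.h):

    // Illustrative scalar version of the merged-momentum update for one element.
    // p, g, v are the current parameter, gradient and velocity values.
    template <typename T>
    void MomentumUpdate(T p, T g, T v, T lr, T mu, T rescale_grad,
                        bool use_nesterov, bool l2_decay, T regularization_coeff,
                        T* p_out, T* v_out) {
      g = rescale_grad * g;
      if (l2_decay) {
        g += regularization_coeff * p;  // L2 weight decay folded into the gradient
      }
      *v_out = mu * v + g;              // velocity update
      *p_out = use_nesterov ? p - (g + mu * (*v_out)) * lr  // Nesterov lookahead
                            : p - lr * (*v_out);            // plain momentum
    }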

2 changes: 1 addition & 1 deletion paddle/fluid/pybind/op_function.h
@@ -827,7 +827,7 @@ GetVarBaseListFromArgs(const std::string& op_type, const std::string& arg_name,
bool dispensable = false) {
PyObject* list = PyTuple_GET_ITEM(args, arg_idx);

if (list == nullptr) {
if (list == nullptr || list == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
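The added list == Py_None check treats an explicit Python None the same as an omitted argument for dispensable inputs (presumably so that optional list inputs of the merged op, such as MasterParam, can be passed as None). A minimal sketch of the pattern, assuming standard CPython semantics and not the surrounding function's full body:

    #include <Python.h>

    // Returns true when an argument slot should be treated as "not provided":
    // either the slot is absent (nullptr) or the caller explicitly passed None.
    // Py_None is a singleton, so a pointer comparison is sufficient.
    static bool IsMissingOrNone(PyObject* obj) {
      return obj == nullptr || obj == Py_None;
    }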
197 changes: 197 additions & 0 deletions python/paddle/fluid/tests/unittests/test_merged_momentum_op.py
@@ -130,6 +130,130 @@ def run_momentum_op(params,
return exe.run(main, feed=feed_dict, fetch_list=fetch_list)


def run_momentum_op2(params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
mu=0.9,
rescale_grad=0.01,
use_merged=False,
use_nesterov=True):
assert len(params) == len(grads)
assert len(params) == len(velocitys)
if multi_precision:
assert len(params) == len(master_params)
op_type = 'merged_momentum' if use_merged else 'momentum'
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
helper = LayerHelper(op_type, **locals())

param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype) for p in params
]
grad_vars = [
helper.create_variable(
shape=g.shape, dtype=g.dtype) for g in grads
]
velocity_vars = [
helper.create_variable(
persistable=True, shape=v.shape, dtype=v.dtype)
for v in velocitys
]
lr_var = helper.create_variable(
persistable=True,
shape=learning_rate.shape,
dtype=learning_rate.dtype)

feed_dict = OrderedDict()

feed_dict.update(
OrderedDict([(p_var.name, p_val)
for p_var, p_val in zip(param_vars, params)]))
feed_dict.update(
OrderedDict([(v_var.name, v_val)
for v_var, v_val in zip(velocity_vars, velocitys)]))
fetch_list = list(feed_dict.keys())

feed_dict.update(
OrderedDict([(g_var.name, g_val)
for g_var, g_val in zip(grad_vars, grads)]))
feed_dict.update({lr_var.name: learning_rate})

if multi_precision:
master_param_vars = [
helper.create_variable(
persistable=True, shape=p.shape, dtype=p.dtype)
for p in master_params
]
feed_dict.update(
OrderedDict([(mp_var.name, mp_val)
for mp_var, mp_val in zip(master_param_vars,
master_params)]))
# CPUPlace does not use MasterParam
if isinstance(place, paddle.CUDAPlace):
fetch_list = fetch_list + [
mp_var.name for mp_var in master_param_vars
]
else:
master_param_vars = None

if not use_merged:
for i, (p, g,
v) in enumerate(zip(param_vars, grad_vars, velocity_vars)):
inputs = {
'Param': p,
'Grad': g,
'Velocity': v,
'LearningRate': lr_var,
}
outputs = {'ParamOut': p, 'VelocityOut': v}
if multi_precision:
inputs['MasterParam'] = master_param_vars[i]
outputs['MasterParamOut'] = master_param_vars[i]
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
'use_nesterov': use_nesterov,
'regularization_method': 'l2_decay',
'regularization_coeff': 2.0,
}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)
else:
inputs = {
'Param': param_vars,
'Grad': grad_vars,
'Velocity': velocity_vars,
'LearningRate': lr_var,
}
outputs = {'ParamOut': param_vars, 'VelocityOut': velocity_vars}
if multi_precision:
inputs['MasterParam'] = master_param_vars
outputs['MasterParamOut'] = master_param_vars
attrs = {
'mu': mu,
'multi_precision': multi_precision,
'rescale_grad': rescale_grad,
'use_nesterov': use_nesterov,
'regularization_method':
['l2_decay' for i in range(len(param_vars))],
'regularization_coeff': [2.0 for i in range(len(param_vars))],
}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs)

exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
return exe.run(main, feed=feed_dict, fetch_list=fetch_list)


class TestMergedMomentum(unittest.TestCase):
def setUp(self):
paddle.enable_static()
@@ -193,5 +317,78 @@ def test_main(self):
self.check_with_place(place, multi_precision)


class TestMergedMomentum2(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10

def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]

def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
mp_dtype = np.float32
dtype = np.float16 if multi_precision and isinstance(
place, paddle.CUDAPlace) else np.float32
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
velocitys = self.gen_rand_data(shapes, mp_dtype)
learning_rate = self.gen_rand_data([[1]], mp_dtype)[0]
if multi_precision:
master_params = [p.astype(mp_dtype) for p in params]
else:
master_params = None
return params, grads, velocitys, master_params, learning_rate

def check_with_place(self, place, multi_precision):
params, grads, velocitys, master_params, learning_rate = self.prepare_data(
self.shapes, multi_precision, self.seed, place)

def run_op(use_nesterov, use_merged):
# FIXME(zengjinle): CPU Momentum Op does not support rescale_grad
rescale_grad = 1.0 if isinstance(place, paddle.CPUPlace) else 0.01
return run_momentum_op2(
params,
grads,
velocitys,
master_params,
learning_rate,
place,
multi_precision,
rescale_grad=rescale_grad,
use_merged=use_merged,
use_nesterov=use_nesterov)

outs1 = run_op(use_nesterov=True, use_merged=True)
outs2 = run_op(use_nesterov=True, use_merged=False)
self.assertEqual(len(outs1), len(outs2))
for i, (out1, out2) in enumerate(zip(outs1, outs2)):
if isinstance(place, paddle.CUDAPlace):
self.assertTrue(np.array_equal(out1, out2))
else:
self.assertTrue(np.allclose(out1, out2, atol=1e-7))

outs3 = run_op(use_nesterov=False, use_merged=True)
outs4 = run_op(use_nesterov=False, use_merged=False)
self.assertEqual(len(outs3), len(outs4))
for j, (out3, out4) in enumerate(zip(outs3, outs4)):
if isinstance(place, paddle.CUDAPlace):
self.assertTrue(np.array_equal(out3, out4))
else:
self.assertTrue(np.allclose(out3, out4, atol=1e-7))

def get_places(self):
places = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
return places

def test_main(self):
for multi_precision in [False, True]:
for place in self.get_places():
self.check_with_place(place, multi_precision)


if __name__ == "__main__":
unittest.main()