From 50495d7b948c813222597fb9ffaf45cfcea29d2e Mon Sep 17 00:00:00 2001 From: Anirudh Date: Thu, 30 May 2019 15:13:47 -0700 Subject: [PATCH] NAG Optimizer with multi-precision support (#14568) * nag_mp * doc * reuse sgd updates where convenient --- python/mxnet/optimizer/optimizer.py | 58 ++++++--- src/operator/optimizer_op-inl.h | 163 +++++++++++++++++++++++- src/operator/optimizer_op.cc | 57 ++++++++- src/operator/optimizer_op.cu | 6 + tests/python/unittest/test_optimizer.py | 28 ++-- 5 files changed, 275 insertions(+), 37 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index 613ae8985aca..c2c1aa6a76f4 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -28,7 +28,7 @@ from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, - signsgd_update, signum_update, + signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update, multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update, multi_mp_sgd_mom_update) from ..ndarray import sparse @@ -1029,7 +1029,7 @@ def update(self, index, weight, grad, state): @register class NAG(Optimizer): - """Nesterov accelerated SGD. + """Nesterov accelerated gradient. This optimizer updates each weight by:: @@ -1051,33 +1051,59 @@ def __init__(self, momentum=0.0, **kwargs): super(NAG, self).__init__(**kwargs) self.momentum = momentum + def create_state_multi_precision(self, index, weight): + weight_master_copy = None + if self.multi_precision and weight.dtype == numpy.float16: + weight_master_copy = weight.astype(numpy.float32) + return (self.create_state(index, weight_master_copy), weight_master_copy) + if weight.dtype == numpy.float16 and not self.multi_precision: + warnings.warn("Accumulating with float16 in optimizer can lead to " + "poor accuracy or slow convergence. " + "Consider using multi_precision=True option of the " + "NAG optimizer") + return self.create_state(index, weight) + def create_state(self, index, weight): momentum = None if self.momentum != 0.0: momentum = zeros(weight.shape, weight.context, dtype=weight.dtype) return momentum - def update(self, index, weight, grad, state): + def _update_impl(self, index, weight, grad, state, multi_precision=False): assert(isinstance(weight, NDArray)) assert(isinstance(grad, NDArray)) self._update_count(index) lr = self._get_lr(index) wd = self._get_wd(index) - grad = grad * self.rescale_grad - if self.clip_gradient is not None: - grad = clip(grad, -self.clip_gradient, self.clip_gradient) + kwargs = {'rescale_grad': self.rescale_grad} + if self.momentum > 0: + kwargs['momentum'] = self.momentum + if self.clip_gradient: + kwargs['clip_gradient'] = self.clip_gradient - if state is not None: - mom = state - mom[:] *= self.momentum - mom[:] += grad - mom[:] += wd * weight - grad[:] += self.momentum * mom - weight[:] -= lr * grad + if not multi_precision: + if state is not None: + nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs) + else: + sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) else: - assert self.momentum == 0.0 - weight[:] += -lr * (grad + wd * weight) + if state[0] is not None: + mp_nag_mom_update(weight, grad, state[0], state[1], out=weight, + lr=lr, wd=wd, **kwargs) + else: + mp_sgd_update(weight, grad, state[1], out=weight, + lr=lr, wd=wd, **kwargs) + + def update(self, index, weight, grad, state): + self._update_impl(index, weight, grad, state, multi_precision=False) + + def update_multi_precision(self, index, weight, grad, state): + use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 \ + and isinstance(state, (tuple, list)) + self._update_impl(index, weight, grad, state, + multi_precision=use_multi_precision) + @register class SGLD(Optimizer): @@ -1380,7 +1406,7 @@ def update(self, index, weight, grad, state): # preprocess grad grad *= self.rescale_grad if self.clip_gradient is not None: - grad = clip(grad, -self.clip_gradient, self.clip_gradient) + grad = clip(grad, - self.clip_gradient, self.clip_gradient) # accumulated g and delta initlization acc_g, acc_delta = state diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index bd923aebbb80..50637a8e7b42 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -140,6 +140,7 @@ struct MultiSGDMomParam : public dmlc::Parameter { } }; + template inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, @@ -639,7 +640,7 @@ inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs, } template -inline bool MP_SGD_InferType(const nnvm::NodeAttrs& attrs, +inline bool MP_InferType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; @@ -1003,6 +1004,166 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, } +struct NAGParam : public dmlc::Parameter { + float lr; + float wd; + float rescale_grad; + float clip_gradient; + DMLC_DECLARE_PARAMETER(NAGParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude " + "of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + } +}; + +struct NAGMomParam : public dmlc::Parameter { + float lr; + float momentum; + float wd; + float rescale_grad; + float clip_gradient; + DMLC_DECLARE_PARAMETER(NAGMomParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(momentum) + .set_default(0.0f) + .describe("The decay rate of momentum estimates at each epoch."); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude " + "of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + } +}; + +struct NAGMomKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, + const DType* weight_data, const DType* grad_data, + const DType param_clip_gradient, const DType param_momentum, + const DType param_lr, const DType param_wd, + const DType param_rescale_grad, const OpReqType req) { + if (param_clip_gradient >= 0.0f) { + mom_data[i] = param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient) + + (param_wd*weight_data[i]); + KERNEL_ASSIGN(out_data[i], req, weight_data[i] + - param_lr*(param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad*grad_data[i], + param_clip_gradient))); + } else { + mom_data[i] = param_momentum*mom_data[i] + + param_rescale_grad*grad_data[i] + + (param_wd*weight_data[i]); + KERNEL_ASSIGN(out_data[i], req, weight_data[i] + - param_lr*(param_momentum*mom_data[i] + + param_rescale_grad*grad_data[i])); + } + } +}; + +template +inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + NAGMomParam param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mom = inputs[2].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + mom.dptr_, weight.dptr_, grad.dptr_, + static_cast(param.clip_gradient), + static_cast(param.momentum), static_cast(param.lr), + static_cast(param.wd), static_cast(param.rescale_grad), + req[0]); + }); +} + +struct MP_NAGMomKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, + float* mom_data, const DType* weight_data, + const DType* grad_data, float* weight32, + const float param_clip_gradient, + const float param_momentum, const float param_lr, + const float param_wd, const float param_rescale_grad, + const OpReqType req) { + float w = weight32[i]; + if (param_clip_gradient >= 0.0f) { + mom_data[i] = param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad + *static_cast(grad_data[i]), param_clip_gradient) + + (param_wd*w); + w = w - param_lr*(param_momentum*mom_data[i] + + mshadow_op::clip::Map(param_rescale_grad + *static_cast(grad_data[i]), + param_clip_gradient)); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, w); + } else { + mom_data[i] = param_momentum*mom_data[i] + + param_rescale_grad*static_cast(grad_data[i]) + + (param_wd*w); + w = w - param_lr*(param_momentum*mom_data[i] + + param_rescale_grad*static_cast(grad_data[i])); + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, w); + } + } +}; + +template +inline void MP_NAGMomUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + NAGMomParam param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mom = inputs[2].FlatTo2D(s); + Tensor weight32 = inputs[3].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, + mom.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, + param.clip_gradient, param.momentum, param.lr, param.wd, + param.rescale_grad, req[0]); + }); +} + + struct FTMLParam : public dmlc::Parameter { float lr; float beta1; diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 367b91b2646c..01410863640f 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -35,6 +35,8 @@ DMLC_REGISTER_PARAMETER(MultiSGDParam); DMLC_REGISTER_PARAMETER(MultiSGDMomParam); DMLC_REGISTER_PARAMETER(FTMLParam); DMLC_REGISTER_PARAMETER(AdamParam); +DMLC_REGISTER_PARAMETER(NAGParam); +DMLC_REGISTER_PARAMETER(NAGMomParam); DMLC_REGISTER_PARAMETER(RMSPropParam); DMLC_REGISTER_PARAMETER(RMSPropAlexParam); DMLC_REGISTER_PARAMETER(FtrlParam); @@ -590,7 +592,7 @@ NNVM_REGISTER_OP(mp_sgd_update) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<3, 1>) -.set_attr("FInferType", MP_SGD_InferType<2, 1, 3>) +.set_attr("FInferType", MP_InferType<2, 1, 3>) .set_attr("FCompute", MP_SGDUpdate) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { @@ -607,7 +609,7 @@ NNVM_REGISTER_OP(mp_sgd_mom_update) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<4, 1>) -.set_attr("FInferType", MP_SGD_InferType<2, 1, 4>) +.set_attr("FInferType", MP_InferType<2, 1, 4>) .set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { return std::vector{2, 3}; @@ -705,6 +707,57 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a .add_arguments(AdamParam::__FIELDS__()); +NNVM_REGISTER_OP(nag_mom_update) +.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. +It updates the weights using the following formula, + +.. math:: + v_t = \gamma v_{t-1} + \eta * \nabla J(W_{t-1} - \gamma v_{t-1})\\ + W_t = W_{t-1} - v_t + +Where +:math:`\eta` is the learning rate of the optimizer +:math:`\gamma` is the decay rate of the momentum estimate +:math:`\v_t` is the update vector at time step `t` +:math:`\W_t` is the weight vector at time step `t` + +)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<3, 1>) +.set_attr("FInferType", ElemwiseType<3, 1>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2}; + }) +.set_attr("FCompute", NAGMomUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mom", "NDArray-or-Symbol", "Momentum") +.add_arguments(NAGMomParam::__FIELDS__()); + + +NNVM_REGISTER_OP(mp_nag_mom_update) +.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer. +)code" ADD_FILELINE) +.set_num_inputs(4) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<4, 1>) +.set_attr("FInferType", MP_InferType<2, 1, 4>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3}; + }) +.set_attr("FCompute", MP_NAGMomUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mom", "NDArray-or-Symbol", "Momentum") +.add_argument("weight32", "NDArray-or-Symbol", "Weight32") +.add_arguments(NAGMomParam::__FIELDS__()); + + NNVM_REGISTER_OP(rmsprop_update) .describe(R"code(Update function for `RMSProp` optimizer. diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index c42cf1831c43..2c72462de016 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -251,6 +251,12 @@ NNVM_REGISTER_OP(multi_mp_sgd_update) NNVM_REGISTER_OP(multi_mp_sgd_mom_update) .set_attr("FCompute", MultiSGDMomUpdate); +NNVM_REGISTER_OP(nag_mom_update) +.set_attr("FCompute", NAGMomUpdate); + +NNVM_REGISTER_OP(mp_nag_mom_update) +.set_attr("FCompute", MP_NAGMomUpdate); + NNVM_REGISTER_OP(ftml_update) .set_attr("FCompute", FTMLUpdate); diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index d5aabcb4b1e5..3e6cdd0997ce 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -346,7 +346,7 @@ def create_state(self, index, weight): if self.momentum != 0.0: momentum = mx.nd.zeros(weight.shape, weight.context, dtype=np.float32) weight_master_copy = array(weight, ctx=weight.context, dtype=np.float32) - return (weight_master_copy, momentum) + return (momentum, weight_master_copy) else: if self.momentum != 0.0: momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) @@ -394,8 +394,8 @@ def update(self, index, weight, grad, state): grad32 = grad32 * self.rescale_grad if self.clip_gradient is not None: grad32 = mx.nd.clip(grad32, -self.clip_gradient, self.clip_gradient) - mom = state[1] - weight32 = state[0] + mom = state[0] + weight32 = state[1] if self.momentum == 0.0: weight32[:] += -lr * (grad32 + wd * weight32) else: @@ -417,23 +417,15 @@ def test_nag(): rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}] + for dtype in [np.float16, np.float32, np.float64]: - for mom_option in mom_options: - for cg_option in cg_options: - for rg_option in rg_options: - for wd_option in wd_options: - for mp_option in mp_options: - kwarg = {} - kwarg.update(mom_option) - kwarg.update(cg_option) - kwarg.update(rg_option) - kwarg.update(wd_option) - kwarg.update(mp_option) - if (dtype == np.float16 and - ('multi_precision' not in kwarg or + for params in itertools.product(mom_options, cg_options, rg_options, + wd_options, mp_options): + kwarg = {k: v for param in params for k, v in param.items()} + if (dtype == np.float16 and ('multi_precision' not in kwarg or not kwarg['multi_precision'])): - continue - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) + continue + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, rtol=1e-3, atol=1e-4) #SGLD class PySGLD(mx.optimizer.Optimizer):