From 3e9711356ed9244a7e6c94f9daaa67c691a392ab Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 25 Dec 2018 07:05:12 +0000 Subject: [PATCH 1/3] tests --- python/mxnet/optimizer/optimizer.py | 68 ++++++++++++++++++- src/operator/optimizer_op-inl.h | 30 ++++++--- src/operator/optimizer_op.cc | 45 +++++++++++-- src/operator/optimizer_op.cu | 5 +- tests/python/unittest/test_optimizer.py | 89 ++++++++++++++++++++++++- 5 files changed, 220 insertions(+), 17 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index a085b6fe2ef6..2adac5bad965 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -27,14 +27,14 @@ from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, - signsgd_update, signum_update) + signsgd_update, signum_update, adamw_update) from ..ndarray import sparse from ..random import normal __all__ = [ 'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD', 'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', - 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register' + 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register', 'AdamW' ] @@ -1018,6 +1018,70 @@ class ccSGD(SGD): def __init__(self, *args, **kwargs): super(ccSGD, self).__init__(*args, **kwargs) +@register +class AdamW(Optimizer): + """The Adam optimizer with fixed weight decay regularization. + + This class implements the optimizer described in *Fixing Weight Decay + Regularization in Adam*, available at https://arxiv.org/abs/1711.05101. + + Note that this is different from the original Adam optimizer which adds L2 + regularization on the weights to the loss: it regularizes weights with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + Updates are applied by:: + + rescaled_grad = clip(grad * rescale_grad, clip_gradient) + m = beta1 * m + (1 - beta1) * rescaled_grad + v = beta2 * v + (1 - beta2) * (rescaled_grad**2) + w = w - learning_rate * (m / (sqrt(v) + epsilon) + wd * w) + + This optimizer accepts the following parameters in addition to those accepted + by :class:`.Optimizer`. + + For details of the update algorithm, see :class:`~mxnet.ndarray.adamw_update`. + + Parameters + ---------- + beta1 : float, optional + Exponential decay rate for the first moment estimates. + beta2 : float, optional + Exponential decay rate for the second moment estimates. + epsilon : float, optional + Small value to avoid division by 0. + """ + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, + **kwargs): + super(AdamW, self).__init__(learning_rate=learning_rate, **kwargs) + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + + def create_state(self, index, weight): + return (zeros(weight.shape, weight.context, dtype=weight.dtype), #mean + zeros(weight.shape, weight.context, dtype=weight.dtype)) #variance + + def update(self, index, weight, grad, state): + assert(isinstance(weight, NDArray)) + assert(isinstance(grad, NDArray)) + self._update_count(index) + lr = self._get_lr(index) + wd = self._get_wd(index) + + t = self._index_update_count[index] + coef1 = 1. - self.beta1**t + coef2 = 1. 
- self.beta2**t + lr *= math.sqrt(coef2)/coef1 + + kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, + 'rescale_grad': self.rescale_grad} + if self.clip_gradient: + kwargs['clip_gradient'] = self.clip_gradient + + mean, var = state + adamw_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs) + @register class Adam(Optimizer): """The Adam optimizer. diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 9251b8614806..46b477520bb5 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -837,7 +837,10 @@ struct AdamParam : public dmlc::Parameter { } }; -template +/* + * \brief adam and adam_w update. Set decoupled=True for adam_w. + */ +template inline void AdamUpdate(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, @@ -855,9 +858,12 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs, Tensor var = inputs[3].FlatTo2D(s); Tensor out = outputs[0].FlatTo2D(s); - grad = scalar(param.rescale_grad) * grad + - scalar(param.wd) * weight; - + if (decoupled) { + grad = scalar(param.rescale_grad) * grad; + } else { + grad = scalar(param.rescale_grad) * grad + + scalar(param.wd) * weight; + } if (param.clip_gradient >= 0.0f) { mean = scalar(param.beta1)*mean + scalar(1.f-param.beta1) * F(grad, DType(param.clip_gradient)); @@ -867,10 +873,18 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs, mean = scalar(param.beta1)*mean + scalar(1.f-param.beta1) * grad; var = scalar(param.beta2)*var + scalar(1.f-param.beta2) * F(grad); } - Assign(out, req[0], - weight - - scalar(param.lr) * mean / - (F(var) + scalar(param.epsilon))); + if (decoupled) { + Assign(out, req[0], + weight - + scalar(param.lr) * (mean / + (F(var) + scalar(param.epsilon)) + + (scalar(param.wd) * weight))); + } else { + Assign(out, req[0], + weight - + scalar(param.lr) * mean / + (F(var) + scalar(param.epsilon))); + } }); } diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 6c44f99c1443..063adaed8697 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -472,15 +472,16 @@ are 1st and 2nd order moment estimates (mean and variance). .. math:: - g_t = \nabla J(W_{t-1})\\ + g_t = \nabla J(W_{t-1}) + wd W_{t-1}\\ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ W_t = W_{t-1} - \alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } It updates the weights using:: - m = beta1*m + (1-beta1)*grad - v = beta2*v + (1-beta2)*(grad**2) + g = grad + wd*w + m = beta1*m + (1-beta1)*g + v = beta2*v + (1-beta2)*(g**2) w += - learning_rate * m / (sqrt(v) + epsilon) However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and the storage @@ -507,7 +508,7 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a [](const nnvm::NodeAttrs& attrs) { return std::vector{2, 3}; }) -.set_attr("FCompute", AdamUpdate) +.set_attr("FCompute", AdamUpdate) .set_attr("FComputeEx", AdamUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") @@ -515,6 +516,42 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a .add_argument("var", "NDArray-or-Symbol", "Moving variance") .add_arguments(AdamParam::__FIELDS__()); +NNVM_REGISTER_OP(adamw_update) +.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of +Adam by decoupling the weight decay from the optimization steps taken w.r.t. 
the loss function. + +Adam update consists of the following steps, where g represents gradient and m, v +are 1st and 2nd order moment estimates (mean and variance). + +.. math:: + + g_t = \nabla J(W_{t-1})\\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ + W_t = W_{t-1} - \alpha (\frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) + +It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + v = beta2*v + (1-beta2)*(grad**2) + w += - learning_rate * (m / (sqrt(v) + epsilon) + w*wd) + +)code" ADD_FILELINE) +.set_num_inputs(4) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<4, 1>) +.set_attr("FInferType", ElemwiseType<4, 1>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3}; + }) +.set_attr("FCompute", AdamUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mean", "NDArray-or-Symbol", "Moving mean") +.add_argument("var", "NDArray-or-Symbol", "Moving variance") +.add_arguments(AdamParam::__FIELDS__()); NNVM_REGISTER_OP(rmsprop_update) .describe(R"code(Update function for `RMSProp` optimizer. diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 0fd2ca83fda4..8de746a48462 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -246,9 +246,12 @@ NNVM_REGISTER_OP(ftml_update) .set_attr("FCompute", FTMLUpdate); NNVM_REGISTER_OP(adam_update) -.set_attr("FCompute", AdamUpdate) +.set_attr("FCompute", AdamUpdate) .set_attr("FComputeEx", AdamUpdateEx); +NNVM_REGISTER_OP(adamw_update) +.set_attr("FCompute", AdamUpdate); + NNVM_REGISTER_OP(rmsprop_update) .set_attr("FCompute", RMSPropUpdate); diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index acf24ee1b794..d5f680c84fa0 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -506,12 +506,11 @@ def test_ftml(): class PyAdam(mx.optimizer.Optimizer): """python reference implemenation of adam""" def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, - decay_factor=(1 - 1e-8), lazy_update=True, **kwargs): + lazy_update=True, **kwargs): super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs) self.beta1 = beta1 self.beta2 = beta2 self.epsilon = epsilon - self.decay_factor = decay_factor self.lazy_update = lazy_update def create_state(self, index, weight): @@ -614,6 +613,92 @@ def test_adam(): dtype, w_stype='default', g_stype='row_sparse', rtol=1e-4, atol=2e-5) +# ADAMW +class PyAdamW(mx.optimizer.Optimizer): + """python reference implemenation of AdamW""" + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, + **kwargs): + super(PyAdamW, self).__init__(learning_rate=learning_rate, **kwargs) + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + + def create_state(self, index, weight): + """Create additional optimizer state: mean, variance + + Parameters + ---------- + weight : NDArray + The weight data + + """ + return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype), # mean + mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance + + def update(self, index, weight, grad, state): + """Update the parameters. 
+ + Parameters + ---------- + index : int + An unique integer key used to index the parameters + + weight : NDArray + weight ndarray + + grad : NDArray + grad ndarray + + state : NDArray or other objects returned by init_state + The auxiliary state used in optimization. + """ + lr = self._get_lr(index) + self._update_count(index) + + t = self._index_update_count[index] + mean, variance = state + + wd = self._get_wd(index) + coef1 = 1. - self.beta1**t + coef2 = 1. - self.beta2**t + lr *= math.sqrt(coef2)/coef1 + + grad *= self.rescale_grad + # clip gradients + if self.clip_gradient is not None: + mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient, out=grad) + # update mean + mean *= self.beta1 + mean += grad * (1. - self.beta1) + # update variance + variance *= self.beta2 + variance += (1 - self.beta2) * mx.nd.square(grad, out=grad) + # update weight + weight -= lr * (mean/(mx.nd.sqrt(variance) + self.epsilon) + wd * weight) + +@with_seed() +def test_adamw(): + opt1 = PyAdamW + opt2 = mx.optimizer.AdamW + shape = (3, 4, 5) + cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}] + rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] + wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] + mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}] + for dtype in [np.float16, np.float32, np.float64]: + for cg_option in cg_options: + for rg_option in rg_options: + for wd_option in wd_options: + for mp_option in mp_options: + kwarg = {} + kwarg.update(cg_option) + kwarg.update(rg_option) + kwarg.update(wd_option) + kwarg.update(mp_option) + if (dtype == np.float16 and + ('multi_precision' not in kwarg or not kwarg['multi_precision'])): + continue + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) # AdaMax class PyAdamax(mx.optimizer.Optimizer): From d091c5997342cba4c71bb14ef79bd0410f85061d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 27 Dec 2018 23:04:47 +0000 Subject: [PATCH 2/3] remove optimizer and move op to contrib --- python/mxnet/optimizer/optimizer.py | 68 +------------ src/operator/contrib/adamw-inl.h | 125 ++++++++++++++++++++++++ src/operator/contrib/adamw.cc | 71 ++++++++++++++ src/operator/contrib/adamw.cu | 35 +++++++ src/operator/optimizer_op.cc | 45 +-------- src/operator/optimizer_op.cu | 5 +- tests/python/unittest/test_optimizer.py | 87 ----------------- 7 files changed, 238 insertions(+), 198 deletions(-) create mode 100644 src/operator/contrib/adamw-inl.h create mode 100644 src/operator/contrib/adamw.cc create mode 100644 src/operator/contrib/adamw.cu diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index 2adac5bad965..a085b6fe2ef6 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -27,14 +27,14 @@ from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, - signsgd_update, signum_update, adamw_update) + signsgd_update, signum_update) from ..ndarray import sparse from ..random import normal __all__ = [ 'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD', 'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', - 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register', 'AdamW' + 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register' ] @@ -1018,70 +1018,6 @@ class 
ccSGD(SGD): def __init__(self, *args, **kwargs): super(ccSGD, self).__init__(*args, **kwargs) -@register -class AdamW(Optimizer): - """The Adam optimizer with fixed weight decay regularization. - - This class implements the optimizer described in *Fixing Weight Decay - Regularization in Adam*, available at https://arxiv.org/abs/1711.05101. - - Note that this is different from the original Adam optimizer which adds L2 - regularization on the weights to the loss: it regularizes weights with large - gradients more than L2 regularization would, which was shown to yield better - training loss and generalization error in the paper above. - - Updates are applied by:: - - rescaled_grad = clip(grad * rescale_grad, clip_gradient) - m = beta1 * m + (1 - beta1) * rescaled_grad - v = beta2 * v + (1 - beta2) * (rescaled_grad**2) - w = w - learning_rate * (m / (sqrt(v) + epsilon) + wd * w) - - This optimizer accepts the following parameters in addition to those accepted - by :class:`.Optimizer`. - - For details of the update algorithm, see :class:`~mxnet.ndarray.adamw_update`. - - Parameters - ---------- - beta1 : float, optional - Exponential decay rate for the first moment estimates. - beta2 : float, optional - Exponential decay rate for the second moment estimates. - epsilon : float, optional - Small value to avoid division by 0. - """ - def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, - **kwargs): - super(AdamW, self).__init__(learning_rate=learning_rate, **kwargs) - self.beta1 = beta1 - self.beta2 = beta2 - self.epsilon = epsilon - - def create_state(self, index, weight): - return (zeros(weight.shape, weight.context, dtype=weight.dtype), #mean - zeros(weight.shape, weight.context, dtype=weight.dtype)) #variance - - def update(self, index, weight, grad, state): - assert(isinstance(weight, NDArray)) - assert(isinstance(grad, NDArray)) - self._update_count(index) - lr = self._get_lr(index) - wd = self._get_wd(index) - - t = self._index_update_count[index] - coef1 = 1. - self.beta1**t - coef2 = 1. - self.beta2**t - lr *= math.sqrt(coef2)/coef1 - - kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, - 'rescale_grad': self.rescale_grad} - if self.clip_gradient: - kwargs['clip_gradient'] = self.clip_gradient - - mean, var = state - adamw_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs) - @register class Adam(Optimizer): """The Adam optimizer. diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h new file mode 100644 index 000000000000..b450a91e26b9 --- /dev/null +++ b/src/operator/contrib/adamw-inl.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2016 by Contributors + * \file optimizer_op-inl.h + * \brief Optimizer operators + * \author Haibin Lin + */ +#ifndef MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_ +#define MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "../operator_common.h" +#include "../mshadow_op.h" +#include "../elemwise_op_common.h" +#include "../mxnet_op.h" + +namespace mxnet { +namespace op { + +struct AdamWParam : public dmlc::Parameter { + float lr; + float beta1; + float beta2; + float epsilon; + float wd; + float sched_mult; + float rescale_grad; + float clip_gradient; + DMLC_DECLARE_PARAMETER(AdamWParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(beta1) + .set_default(0.9f) + .describe("The decay rate for the 1st moment estimates."); + DMLC_DECLARE_FIELD(beta2) + .set_default(0.999f) + .describe("The decay rate for the 2nd moment estimates."); + DMLC_DECLARE_FIELD(epsilon) + .set_default(1e-8f) + .describe("A small constant for numerical stability."); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(sched_mult) + .describe("Learning rate schedule multiplier"); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + } +}; + +/* + * \brief adam_w update. + */ +template +inline void AdamWUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + const AdamWParam& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mean = inputs[2].FlatTo2D(s); + Tensor var = inputs[3].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + + grad = scalar(param.rescale_grad) * grad; + if (param.clip_gradient >= 0.0f) { + mean = scalar(param.beta1)*mean + scalar(1.f-param.beta1) * + F(grad, DType(param.clip_gradient)); + var = scalar(param.beta2)*var + scalar(1.f-param.beta2)*F( + F(grad, DType(param.clip_gradient))); + } else { + mean = scalar(param.beta1)*mean + scalar(1.f-param.beta1) * grad; + var = scalar(param.beta2)*var + scalar(1.f-param.beta2) * F(grad); + } + Assign(out, req[0], + weight - + scalar(param.sched_mult) * (scalar(param.lr) * + mean / (F(var) + scalar(param.epsilon)) + + (scalar(param.wd) * weight))); + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_ADAMW_INL_H_ diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc new file mode 100644 index 000000000000..38f089c4108c --- /dev/null +++ b/src/operator/contrib/adamw.cc @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2016 by Contributors + * \file optimizer_op.cc + * \brief Optimizer operators + * \author Haibin Lin + */ +#include "./adamw-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(AdamWParam); + +NNVM_REGISTER_OP(_contrib_adamw_update) +.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of +Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function. + +Adam update consists of the following steps, where g represents gradient and m, v +are 1st and 2nd order moment estimates (mean and variance). + +.. math:: + + g_t = \nabla J(W_{t-1})\\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ + W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) + +It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + v = beta2*v + (1-beta2)*(grad**2) + w -= sched_mult * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) + +)code" ADD_FILELINE) +.set_num_inputs(4) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<4, 1>) +.set_attr("FInferType", ElemwiseType<4, 1>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3}; + }) +.set_attr("FCompute", AdamWUpdate) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mean", "NDArray-or-Symbol", "Moving mean") +.add_argument("var", "NDArray-or-Symbol", "Moving variance") +.add_arguments(AdamWParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu new file mode 100644 index 000000000000..b7452f861e2d --- /dev/null +++ b/src/operator/contrib/adamw.cu @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2018 by Contributors + * \file adamw.cu + * \brief Optimizer operators + * \author Haibin Lin + */ +#include "./adamw-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_contrib_adamw_update) +.set_attr("FCompute", AdamWUpdate); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 063adaed8697..6c44f99c1443 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -472,16 +472,15 @@ are 1st and 2nd order moment estimates (mean and variance). .. math:: - g_t = \nabla J(W_{t-1}) + wd W_{t-1}\\ + g_t = \nabla J(W_{t-1})\\ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ W_t = W_{t-1} - \alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon } It updates the weights using:: - g = grad + wd*w - m = beta1*m + (1-beta1)*g - v = beta2*v + (1-beta2)*(g**2) + m = beta1*m + (1-beta1)*grad + v = beta2*v + (1-beta2)*(grad**2) w += - learning_rate * m / (sqrt(v) + epsilon) However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and the storage @@ -508,7 +507,7 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a [](const nnvm::NodeAttrs& attrs) { return std::vector{2, 3}; }) -.set_attr("FCompute", AdamUpdate) +.set_attr("FCompute", AdamUpdate) .set_attr("FComputeEx", AdamUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") @@ -516,42 +515,6 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a .add_argument("var", "NDArray-or-Symbol", "Moving variance") .add_arguments(AdamParam::__FIELDS__()); -NNVM_REGISTER_OP(adamw_update) -.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of -Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function. - -Adam update consists of the following steps, where g represents gradient and m, v -are 1st and 2nd order moment estimates (mean and variance). - -.. math:: - - g_t = \nabla J(W_{t-1})\\ - m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ - v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\ - W_t = W_{t-1} - \alpha (\frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1}) - -It updates the weights using:: - - m = beta1*m + (1-beta1)*grad - v = beta2*v + (1-beta2)*(grad**2) - w += - learning_rate * (m / (sqrt(v) + epsilon) + w*wd) - -)code" ADD_FILELINE) -.set_num_inputs(4) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", ElemwiseShape<4, 1>) -.set_attr("FInferType", ElemwiseType<4, 1>) -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - return std::vector{2, 3}; - }) -.set_attr("FCompute", AdamUpdate) -.add_argument("weight", "NDArray-or-Symbol", "Weight") -.add_argument("grad", "NDArray-or-Symbol", "Gradient") -.add_argument("mean", "NDArray-or-Symbol", "Moving mean") -.add_argument("var", "NDArray-or-Symbol", "Moving variance") -.add_arguments(AdamParam::__FIELDS__()); NNVM_REGISTER_OP(rmsprop_update) .describe(R"code(Update function for `RMSProp` optimizer. 
diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 8de746a48462..0fd2ca83fda4 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -246,12 +246,9 @@ NNVM_REGISTER_OP(ftml_update) .set_attr("FCompute", FTMLUpdate); NNVM_REGISTER_OP(adam_update) -.set_attr("FCompute", AdamUpdate) +.set_attr("FCompute", AdamUpdate) .set_attr("FComputeEx", AdamUpdateEx); -NNVM_REGISTER_OP(adamw_update) -.set_attr("FCompute", AdamUpdate); - NNVM_REGISTER_OP(rmsprop_update) .set_attr("FCompute", RMSPropUpdate); diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index d5f680c84fa0..eb33f9b5217e 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -613,93 +613,6 @@ def test_adam(): dtype, w_stype='default', g_stype='row_sparse', rtol=1e-4, atol=2e-5) -# ADAMW -class PyAdamW(mx.optimizer.Optimizer): - """python reference implemenation of AdamW""" - def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, - **kwargs): - super(PyAdamW, self).__init__(learning_rate=learning_rate, **kwargs) - self.beta1 = beta1 - self.beta2 = beta2 - self.epsilon = epsilon - - def create_state(self, index, weight): - """Create additional optimizer state: mean, variance - - Parameters - ---------- - weight : NDArray - The weight data - - """ - return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype), # mean - mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance - - def update(self, index, weight, grad, state): - """Update the parameters. - - Parameters - ---------- - index : int - An unique integer key used to index the parameters - - weight : NDArray - weight ndarray - - grad : NDArray - grad ndarray - - state : NDArray or other objects returned by init_state - The auxiliary state used in optimization. - """ - lr = self._get_lr(index) - self._update_count(index) - - t = self._index_update_count[index] - mean, variance = state - - wd = self._get_wd(index) - coef1 = 1. - self.beta1**t - coef2 = 1. - self.beta2**t - lr *= math.sqrt(coef2)/coef1 - - grad *= self.rescale_grad - # clip gradients - if self.clip_gradient is not None: - mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient, out=grad) - # update mean - mean *= self.beta1 - mean += grad * (1. - self.beta1) - # update variance - variance *= self.beta2 - variance += (1 - self.beta2) * mx.nd.square(grad, out=grad) - # update weight - weight -= lr * (mean/(mx.nd.sqrt(variance) + self.epsilon) + wd * weight) - -@with_seed() -def test_adamw(): - opt1 = PyAdamW - opt2 = mx.optimizer.AdamW - shape = (3, 4, 5) - cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}] - rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] - wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] - mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}] - for dtype in [np.float16, np.float32, np.float64]: - for cg_option in cg_options: - for rg_option in rg_options: - for wd_option in wd_options: - for mp_option in mp_options: - kwarg = {} - kwarg.update(cg_option) - kwarg.update(rg_option) - kwarg.update(wd_option) - kwarg.update(mp_option) - if (dtype == np.float16 and - ('multi_precision' not in kwarg or not kwarg['multi_precision'])): - continue - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) - # AdaMax class PyAdamax(mx.optimizer.Optimizer): """The python reference of AdaMax optimizer. 
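For reference, the contrast between the existing adam_update (weight decay folded into the gradient as an L2 penalty) and the new _contrib_adamw_update (weight decay decoupled from the moment estimates and applied directly to the weight) can be written out as a short NumPy sketch. This mirrors the pseudo-code in the two describe() blocks above, with gradient clipping omitted for brevity; it is an illustration of the documented update rules, not code from this patch::

    import numpy as np

    def adam_step(w, g, m, v, lr, beta1=0.9, beta2=0.999, eps=1e-8,
                  wd=0.0, rescale_grad=1.0):
        # adam_update: weight decay enters through the gradient (L2 penalty)
        g = rescale_grad * g + wd * w
        m[:] = beta1 * m + (1. - beta1) * g
        v[:] = beta2 * v + (1. - beta2) * g * g
        w[:] = w - lr * m / (np.sqrt(v) + eps)

    def adamw_step(w, g, m, v, lr, beta1=0.9, beta2=0.999, eps=1e-8,
                   wd=0.0, rescale_grad=1.0, sched_mult=1.0):
        # _contrib_adamw_update: weight decay is kept out of the moment
        # estimates and applied directly to the weight, scaled by sched_mult
        g = rescale_grad * g
        m[:] = beta1 * m + (1. - beta1) * g
        v[:] = beta2 * v + (1. - beta2) * g * g
        w[:] = w - sched_mult * (lr * m / (np.sqrt(v) + eps) + wd * w)
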
From 7b65e5f2da181302886dcc235dd6a20d8a8c064d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 27 Dec 2018 23:09:08 +0000 Subject: [PATCH 3/3] rename parameter --- src/operator/contrib/adamw-inl.h | 6 +++--- src/operator/contrib/adamw.cc | 2 +- src/operator/optimizer_op-inl.h | 30 ++++++++---------------------- 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h index b450a91e26b9..3d76b33ae765 100644 --- a/src/operator/contrib/adamw-inl.h +++ b/src/operator/contrib/adamw-inl.h @@ -47,7 +47,7 @@ struct AdamWParam : public dmlc::Parameter { float beta2; float epsilon; float wd; - float sched_mult; + float eta; float rescale_grad; float clip_gradient; DMLC_DECLARE_PARAMETER(AdamWParam) { @@ -67,7 +67,7 @@ struct AdamWParam : public dmlc::Parameter { .describe("Weight decay augments the objective function with a " "regularization term that penalizes large weights. " "The penalty scales with the square of the magnitude of each weight."); - DMLC_DECLARE_FIELD(sched_mult) + DMLC_DECLARE_FIELD(eta) .describe("Learning rate schedule multiplier"); DMLC_DECLARE_FIELD(rescale_grad) .set_default(1.0f) @@ -113,7 +113,7 @@ inline void AdamWUpdate(const nnvm::NodeAttrs& attrs, } Assign(out, req[0], weight - - scalar(param.sched_mult) * (scalar(param.lr) * + scalar(param.eta) * (scalar(param.lr) * mean / (F(var) + scalar(param.epsilon)) + (scalar(param.wd) * weight))); }); diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc index 38f089c4108c..94623fe08a9e 100644 --- a/src/operator/contrib/adamw.cc +++ b/src/operator/contrib/adamw.cc @@ -48,7 +48,7 @@ It updates the weights using:: m = beta1*m + (1-beta1)*grad v = beta2*v + (1-beta2)*(grad**2) - w -= sched_mult * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) + w -= eta * (learning_rate * m / (sqrt(v) + epsilon) + w * wd) )code" ADD_FILELINE) .set_num_inputs(4) diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 46b477520bb5..9251b8614806 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -837,10 +837,7 @@ struct AdamParam : public dmlc::Parameter { } }; -/* - * \brief adam and adam_w update. Set decoupled=True for adam_w. - */ -template +template inline void AdamUpdate(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, @@ -858,12 +855,9 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs, Tensor var = inputs[3].FlatTo2D(s); Tensor out = outputs[0].FlatTo2D(s); - if (decoupled) { - grad = scalar(param.rescale_grad) * grad; - } else { - grad = scalar(param.rescale_grad) * grad + - scalar(param.wd) * weight; - } + grad = scalar(param.rescale_grad) * grad + + scalar(param.wd) * weight; + if (param.clip_gradient >= 0.0f) { mean = scalar(param.beta1)*mean + scalar(1.f-param.beta1) * F(grad, DType(param.clip_gradient)); @@ -873,18 +867,10 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs, mean = scalar(param.beta1)*mean + scalar(1.f-param.beta1) * grad; var = scalar(param.beta2)*var + scalar(1.f-param.beta2) * F(grad); } - if (decoupled) { - Assign(out, req[0], - weight - - scalar(param.lr) * (mean / - (F(var) + scalar(param.epsilon)) + - (scalar(param.wd) * weight))); - } else { - Assign(out, req[0], - weight - - scalar(param.lr) * mean / - (F(var) + scalar(param.epsilon))); - } + Assign(out, req[0], + weight - + scalar(param.lr) * mean / + (F(var) + scalar(param.epsilon))); }); }
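With the AdamW optimizer class removed, bias correction (the sqrt(1 - beta2**t)/(1 - beta1**t) factor that the deleted AdamW and PyAdamW classes folded into the learning rate) becomes the caller's responsibility; the contrib op applies lr and eta exactly as given. A rough sketch of such a caller follows. It assumes the operator registered as _contrib_adamw_update is exposed on the Python side as mx.nd.contrib.adamw_update (the usual mapping for _contrib_* ops); the exact binding and signature are assumptions, not something this patch series defines::

    import math
    import mxnet as mx

    # Hypothetical caller, not part of this patch. Assumes the op is
    # exposed as mx.nd.contrib.adamw_update with the parameters
    # registered in adamw-inl.h (lr, eta, wd, beta1, beta2, epsilon, ...).
    def apply_adamw(weight, grad, mean, var, t, lr=0.001, eta=1.0, wd=0.01,
                    beta1=0.9, beta2=0.999, epsilon=1e-8, rescale_grad=1.0):
        # Fold bias correction into the learning rate, as the removed
        # AdamW optimizer class did before calling the op.
        lr_t = lr * math.sqrt(1. - beta2**t) / (1. - beta1**t)
        mx.nd.contrib.adamw_update(weight, grad, mean, var, out=weight,
                                   lr=lr_t, eta=eta, wd=wd, beta1=beta1,
                                   beta2=beta2, epsilon=epsilon,
                                   rescale_grad=rescale_grad)

    # Usage: mean/var state starts as zeros with the weight's shape/dtype.
    w = mx.nd.random.uniform(shape=(3, 4))
    g = mx.nd.random.uniform(shape=(3, 4))
    m, v = mx.nd.zeros_like(w), mx.nd.zeros_like(w)
    apply_adamw(w, g, m, v, t=1)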