Signum optimizer #9220
Changes from 4 commits
@@ -25,7 +25,8 @@
from .base import py_str
from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
-                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update)
+                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
+                      signsgd_update, signum_update)
from .ndarray import _internal
from .ndarray import op
from .ndarray import sparse
@@ -534,6 +535,66 @@ def update_multi_precision(self, index, weight, grad, state):
        self._update_impl(index, weight, grad, state,
                          multi_precision=use_multi_precision)

@register
class Signum(Optimizer):
    """The Signum optimizer that takes the sign of gradient or momentum.

    The optimizer updates the weight by:

        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
        state = momentum * state + (1-momentum)*rescaled_grad
        weight = (1 - lr * wd_lh) * weight - lr * sign(state)

    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf

    For details of the update algorithm see
    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.Optimizer`.

    Parameters
    ----------
    momentum : float, optional
       The momentum value.
    wd_lh : float, optional
       The amount of decoupled weight decay regularization.
    """

Review comment (on wd_lh): Let's also add a reference/link to the original paper.
Reply: Added the temp link to the PDF hosted on Jeremy's site; will update to the arXiv or a published version when they are ready.

    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
        super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
        self.momentum = momentum
        self.wd_lh = wd_lh

    def create_state(self, index, weight):
        momentum = None
        if self.momentum != 0.0:
            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
        return momentum

    def _update_impl(self, index, weight, grad, state):
        assert(isinstance(weight, NDArray))
        assert(isinstance(grad, NDArray))
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        kwargs = {'rescale_grad': self.rescale_grad}
        if self.momentum > 0:
            kwargs['momentum'] = self.momentum
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient
        if self.wd_lh:
            kwargs['wd_lh'] = self.wd_lh

        if state is not None:
            signum_update(weight, grad, state, out=weight,
                          lr=lr, wd=wd, **kwargs)
        else:
            signsgd_update(weight, grad, out=weight,
                           lr=lr, wd=wd, **kwargs)

Review comment (on signum_update): Call these signum_momentum_update and signum_update to be consistent with the others.
Reply: RE: naming. Unless we change the names in our paper, let's keep them the way they are.

Review comment (on signsgd_update): What's this?
Reply: Well, signsgd takes the sign of the stochastic gradient.

    def update(self, index, weight, grad, state):
        self._update_impl(index, weight, grad, state)

@register
class FTML(Optimizer):
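
As a quick illustration (not part of the diff), here is a minimal sketch of driving the new Signum optimizer from the standard mx.optimizer API; the shapes, hyperparameter values, and batch size below are arbitrary assumptions.

# Usage sketch (assumed values). The @register decorator registers the class
# under its lowercased name, so 'signum' should resolve once this change is in.
import mxnet as mx

opt = mx.optimizer.create('signum', learning_rate=0.01, momentum=0.9,
                          wd=1e-4, wd_lh=1e-5, rescale_grad=1.0/128)  # 128 = assumed batch size
updater = mx.optimizer.get_updater(opt)

weight = mx.nd.ones((10, 4))        # toy parameter
grad = 0.1 * mx.nd.ones((10, 4))    # toy gradient
updater(index=0, grad=grad, weight=weight)  # applies signum_update to `weight` in place
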
@@ -702,8 +763,7 @@ def update(self, index, weight, grad, state):
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
        weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
-                                                           shape=weight.shape,
-                                                           ctx=weight.context)
+                                                           weight.shape, weight.context)


@register  # pylint: disable=invalid-name
@@ -66,6 +66,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
  }
};


struct SGDKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,

@@ -228,6 +229,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
  }
};


struct SGDMomKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
@@ -1281,6 +1283,146 @@ inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs,
  }
}


// Implementation for signSGD and Signum

struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {
  float lr;
  float wd;
  float rescale_grad;
  float clip_gradient;
  DMLC_DECLARE_PARAMETER(SignSGDParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};

Review comment (on clip_gradient): If the …
Reply: It has an effect on Signum, because it leads to a different result depending on whether we use the gradient or the clipped gradient when calculating the momentum.
Reply: Ah, I see. Thanks for the explanation!

struct SignSGDKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
    const DType* grad_data, const DType param_clip_gradient,
    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
    const OpReqType req) {

    // param_clip_gradient has no effect for SignSGD
    KERNEL_ASSIGN(out_data[i], req,
             (1.f-param_lr*param_wd)*weight_data[i]
               - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0)));
  }
};

template<typename xpu>
inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs,
                          const OpContext &ctx,
                          const std::vector<TBlob> &inputs,
                          const std::vector<OpReqType> &req,
                          const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
      grad.dptr_, static_cast<DType>(param.clip_gradient),
      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
      static_cast<DType>(param.rescale_grad), req[0]);
  });
}
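
For cross-checking SignSGDKernel above (for example in a unit test), a small NumPy sketch of the same element-wise update may be useful; it is not part of the PR, and the function name and shapes are made up.

import numpy as np

def signsgd_reference(weight, grad, lr, wd):
    # Mirrors SignSGDKernel: only the sign of the gradient is used, so a
    # positive rescale_grad and the clip_gradient value do not change the result.
    return (1.0 - lr * wd) * weight - lr * np.sign(grad)

w = np.random.randn(5).astype(np.float32)
g = np.random.randn(5).astype(np.float32)
w_new = signsgd_reference(w, g, lr=0.01, wd=1e-4)
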
struct SignumParam : public dmlc::Parameter<SignumParam> {
  float lr;
  float momentum;
  float wd;
  float rescale_grad;
  float clip_gradient;
  float wd_lh;  // the amount of decoupled weight decay, as in Loshchilov and Hutter
  DMLC_DECLARE_PARAMETER(SignumParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(momentum)
    .set_default(0.0f)
    .describe("The decay rate of momentum estimates at each epoch.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
    DMLC_DECLARE_FIELD(wd_lh)
    .set_default(0.0f)
    .describe("The amount of weight decay that does not go into the gradient/momentum "
              "calculations and is instead applied directly to the weight "
              "(decoupled weight decay).");
  }
};

struct SignumKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
    const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
    const DType param_lr, const DType param_wd, const DType param_rescale_grad,
    const DType param_wd_lh, const OpReqType req) {
    if (param_clip_gradient >= 0.0f) {
      mom_data[i] = param_momentum*mom_data[i]
        - (1-param_momentum)*param_wd*weight_data[i]
        - (1-param_momentum)
          *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
    } else {
      mom_data[i] = param_momentum*mom_data[i]
        - (1-param_momentum)*param_wd*weight_data[i]
        - (1-param_momentum)*param_rescale_grad*grad_data[i];
    }
    KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i]
      + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0)));
  }
};

template<typename xpu>
inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
                         const OpContext &ctx,
                         const std::vector<TBlob> &inputs,
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  SignumParam param = nnvm::get<SignumParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
      grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
      static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]);
  });
}

}  // namespace op
}  // namespace mxnet
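
Likewise, here is a NumPy sketch of the element-wise math in SignumKernel (my own reference, not part of the diff; names and shapes are arbitrary). Note that the kernel folds wd into the negated momentum buffer, while wd_lh decays the weight directly, matching the docstring of the Python Signum class.

import numpy as np

def signum_reference(weight, grad, mom, lr, wd, momentum,
                     rescale_grad=1.0, clip_gradient=-1.0, wd_lh=0.0):
    # Mirrors SignumKernel: `mom` accumulates the negated, rescaled (and
    # optionally clipped) gradient plus the wd term; the weight then moves
    # by lr in the direction of sign(mom), with wd_lh applied directly.
    g = rescale_grad * grad
    if clip_gradient >= 0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mom[:] = momentum * mom - (1 - momentum) * (wd * weight + g)
    return (1 - lr * wd_lh) * weight + lr * np.sign(mom)

w = np.random.randn(8).astype(np.float32)
g = np.random.randn(8).astype(np.float32)
m = np.zeros_like(w)
w_new = signum_reference(w, g, m, lr=0.01, wd=1e-4, momentum=0.9, wd_lh=1e-5)
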
Review comment: What's wd_lh? Is it from the original paper?
Reply: It is an alternative weight decay. See the descriptions.
Review comment: Since wd_lh is new, I suggest putting a reference link to the original paper by Loshchilov and Hutter in the documentation.