This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
AdamW operator (Fixing Weight Decay Regularization in Adam) #13728
Merged
Merged
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,14 +27,14 @@ | |
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) | ||
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, | ||
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, | ||
signsgd_update, signum_update) | ||
signsgd_update, signum_update, adamw_update) | ||
from ..ndarray import sparse | ||
from ..random import normal | ||
|
||
__all__ = [ | ||
'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD', | ||
'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', | ||
'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register' | ||
'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register', 'AdamW' | ||
] | ||
|
||
|
||
|
@@ -1018,6 +1018,70 @@ class ccSGD(SGD): | |
def __init__(self, *args, **kwargs): | ||
super(ccSGD, self).__init__(*args, **kwargs) | ||
|
||
@register | ||
class AdamW(Optimizer): | ||
"""The Adam optimizer with fixed weight decay regularization. | ||
|
||
This class implements the optimizer described in *Fixing Weight Decay | ||
Regularization in Adam*, available at https://arxiv.org/abs/1711.05101. | ||
|
||
Note that this is different from the original Adam optimizer which adds L2 | ||
regularization on the weights to the loss: it regularizes weights with large | ||
gradients more than L2 regularization would, which was shown to yield better | ||
training loss and generalization error in the paper above. | ||
|
||
Updates are applied by:: | ||
|
||
rescaled_grad = clip(grad * rescale_grad, clip_gradient) | ||
m = beta1 * m + (1 - beta1) * rescaled_grad | ||
v = beta2 * v + (1 - beta2) * (rescaled_grad**2) | ||
w = w - learning_rate * (m / (sqrt(v) + epsilon) + wd * w) | ||
|
||
This optimizer accepts the following parameters in addition to those accepted | ||
by :class:`.Optimizer`. | ||
|
||
For details of the update algorithm, see :class:`~mxnet.ndarray.adamw_update`. | ||
|
||
Parameters | ||
---------- | ||
beta1 : float, optional | ||
Exponential decay rate for the first moment estimates. | ||
beta2 : float, optional | ||
Exponential decay rate for the second moment estimates. | ||
epsilon : float, optional | ||
Small value to avoid division by 0. | ||
""" | ||
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, | ||
**kwargs): | ||
super(AdamW, self).__init__(learning_rate=learning_rate, **kwargs) | ||
self.beta1 = beta1 | ||
self.beta2 = beta2 | ||
self.epsilon = epsilon | ||
|
||
def create_state(self, index, weight): | ||
return (zeros(weight.shape, weight.context, dtype=weight.dtype), #mean | ||
zeros(weight.shape, weight.context, dtype=weight.dtype)) #variance | ||
|
||
def update(self, index, weight, grad, state): | ||
assert(isinstance(weight, NDArray)) | ||
assert(isinstance(grad, NDArray)) | ||
self._update_count(index) | ||
lr = self._get_lr(index) | ||
wd = self._get_wd(index) | ||
|
||
t = self._index_update_count[index] | ||
coef1 = 1. - self.beta1**t | ||
coef2 = 1. - self.beta2**t | ||
lr *= math.sqrt(coef2)/coef1 | ||
|
||
kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, | ||
'rescale_grad': self.rescale_grad} | ||
if self.clip_gradient: | ||
kwargs['clip_gradient'] = self.clip_gradient | ||
|
||
mean, var = state | ||
adamw_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we set |
||
|
||
@register | ||
class Adam(Optimizer): | ||
"""The Adam optimizer. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to the paper, it has two learning rates. An alpha before m / (sqrt(v) + epsilon).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. The issue is that the learning rate and schedule multiplier is not decoupled in MXNet. Here
learning_rate
is effectivelyeta_t * alpha
in the paper andwd
actually needs to be set asw / alpha
. In another wordwd
can be rescaled properly so that it does exactly the same thing in the paper. Would this be acceptable? Is so maybe I can move this to contrib for the momentThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's acceptable as long as the
wd
is set correctly.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On second thought I think it's better to keep it consistent with the paper