Conversation
Thanks for the contribution!! Please see detailed comments in code.
src/operator/optimizer_op.cu
Outdated
@@ -28,6 +28,14 @@
namespace mxnet {
namespace op {

NNVM_REGISTER_OP(signsgd_update)
.set_attr<FCompute>("FCompute<gpu>", SignSGDUpdate<gpu>);
// .set_attr<FComputeEx>("FComputeEx<gpu>", SignSGDUpdateEx<gpu>);
Could you remove unused lines?
Done.
src/operator/optimizer_op.cc
Outdated
return std::vector<uint32_t>{2};
})
.set_attr<FCompute>("FCompute<cpu>", SignumUpdate<cpu>)
// .set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
Please remove the unused lines (also the ones on lines 42 and 65).
Done.
python/mxnet/optimizer.py
Outdated
@@ -57,6 +58,10 @@ class Optimizer(object):
The weight decay (or L2 regularization) coefficient. Modifies objective
by adding a penalty for having large weights.

wd_lh: float, optional
I don't see a change in the Optimizer class constructor. Why is this changed?
I added that to the constructor at some point, because wd_lh is more generally applicable to other algorithms too (in particular, Adam).
Removed that line.
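For context, a minimal NumPy sketch (purely illustrative, not the MXNet implementation) of what a decoupled weight decay term like wd_lh looks like when applied to Adam, in the spirit of Loshchilov and Hutter: the decay acts directly on the weights instead of being folded into the gradient.

import numpy as np

def adam_step_decoupled_wd(weight, grad, m, v, t, lr=1e-3, beta1=0.9,
                           beta2=0.999, eps=1e-8, wd_lh=1e-2):
    """One Adam step with decoupled weight decay (illustration only)."""
    m = beta1 * m + (1 - beta1) * grad            # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad     # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                  # bias corrections
    v_hat = v / (1 - beta2 ** t)
    # decoupled decay: shrink the weights directly, not through the gradient
    weight = (1 - lr * wd_lh) * weight - lr * m_hat / (np.sqrt(v_hat) + eps)
    return weight, m, v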
src/operator/optimizer_op.cc
Outdated

** Sparse matrix not supported for this optimizer yet.

If weight and momentum are both of ``row_sparse`` storage type,
I'd rather remove lines 81-87 since sparse update is not supported anyway.
Done.
src/operator/optimizer_op.cc
Outdated

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

** Sparse matrix not supported for this optimizer yet.
Not sure if a sentence starting with ** renders well in the API doc. What about adding a "note" section like rint does?
/~https://github.com/apache/incubator-mxnet/blob/ae70769c8e35cc178bf7dd9dba35386c13394394/src/operator/tensor/elemwise_unary_op_basic.cc#L432-L434
Also, the term "sparse ndarray" is preferred over "sparse matrix" :)
Done.
src/operator/optimizer_op.cc
Outdated

NNVM_REGISTER_OP(signsgd_update)
// MXNET_ADD_SPARSE_OP_ALIAS(signsgd_update)
.describe(R"code(Update function for SignSGDoptimizer.
nit: SignSGD optimizer
Done. Also added a math description block similar to the other optimizers.
src/operator/optimizer_op.cc
Outdated
weight = weight - learning_rate * sign(gradient)

** Sparse matrix not supported for this optimizer yet.
Same comment about documentation rendering and FInferStorageType in signum_update.
Done.
python/mxnet/optimizer.py
Outdated

@register
class Signum(Optimizer):
"""The SGD optimizer with momentum and weight decay.
The one-line summary should also mention that it only takes the sign. Otherwise readers don't know it until they see line 547.
Added details to the doc accordingly.
float lr;
float wd;
float rescale_grad;
float clip_gradient;
If the clip_gradient param has no effect on either SignSGD or Signum, can we just remove this param from signsgd_update and signum_update? That would also simplify the C++ kernels.
It has an effect on Signum, because it leads to a different result depending on whether we use the gradient or the clipped gradient when calculating the momentum.
Ah, I see. Thanks for the explanation!
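To illustrate the point with a small NumPy sketch (following the documented pseudocode, not the C++ kernel): the clipped gradient is what feeds the momentum buffer, so clipping can flip the sign of the state and hence of the update.

import numpy as np

def signum_state(grads, momentum=0.9, clip_gradient=None):
    """Accumulate the Signum momentum buffer over a sequence of gradients."""
    state = np.zeros_like(grads[0])
    for g in grads:
        if clip_gradient is not None:
            g = np.clip(g, -clip_gradient, clip_gradient)  # clip before accumulating
        state = momentum * state + (1 - momentum) * g
    return state

grads = [np.array([100.0]), np.array([-2.0]), np.array([-2.0])]
print(np.sign(signum_state(grads)))                      # +1: the outlier dominates
print(np.sign(signum_state(grads, clip_gradient=1.0)))   # -1: clipping flips the sign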
python/mxnet/optimizer.py
Outdated
momentum : float, optional
The momentum value.
wd_lh : float, optitional
The amount of decoupled weight decay regularization.
Let's also add a reference/link to the original paper
Added a temporary link to the PDF hosted on Jeremy's site. Will update to an arXiv or published version when they are ready.
Force-pushed from 30d980c to 955c7f0.
There are new conflicts now. Do you mind resolving them again?
Force-pushed from 955c7f0 to dc6fb2d.
Done fixing the conflicts.
@lx75249 could you help review the code for cpp-package?
@eric-haibin-lin LGTM
    signum_update(weight, grad, state, out=weight,
                  lr=lr, wd=wd, **kwargs)
else:
    signsgd_update(weight, grad, out=weight,
what's this?
Well, signsgd takes the sign of the stochastic gradient.
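In other words, roughly this rule (a NumPy illustration of the documented update, not the actual kernel):

import numpy as np

def signsgd_step(weight, grad, lr=0.01):
    # documented rule: weight = weight - learning_rate * sign(gradient)
    return weight - lr * np.sign(grad)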

rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
state = momentum * state + (1-momentum)*rescaled_grad
weight = (1 - lr * wd_lh) * weight - lr * sign(state)
what's wd_lh? Is it from the original paper?
It is an alternative weight decay. See the descriptions.
Since wd_lh is new, I suggest putting a reference link to the original paper by Loshchilov and Frank Hutter in the documentation.
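For reference, the pseudocode above translates to roughly the following NumPy sketch (illustrative only; the real implementation is the C++ kernel behind signum_update):

import numpy as np

def signum_step(weight, grad, state, lr=0.01, momentum=0.9, wd=0.0,
                wd_lh=0.0, rescale_grad=1.0, clip_gradient=None):
    """One Signum step following the documented pseudocode."""
    g = grad if clip_gradient is None else np.clip(grad, -clip_gradient, clip_gradient)
    rescaled_grad = rescale_grad * g + wd * weight             # coupled weight decay (wd)
    state = momentum * state + (1 - momentum) * rescaled_grad  # momentum buffer
    weight = (1 - lr * wd_lh) * weight - lr * np.sign(state)   # decoupled decay (wd_lh)
    return weight, state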
kwargs['wd_lh'] = self.wd_lh

if state is not None:
    signum_update(weight, grad, state, out=weight,
Call these signum_momentum_update and signum_update to be consistent with the others.
RE: naming.
- Signum means SIGN momentUM, so the semantics of the momentum is already in the name.
- SignSGD is the special case of Signum without momentum, and the name has been used before.
Unless we change the names in our paper, let's keep them the way they are.
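To make the special case concrete (a quick NumPy check based on the pseudocode, not the kernels): with momentum = 0 the state is just the rescaled gradient, so the Signum direction reduces to the SignSGD rule.

import numpy as np

grad = np.array([0.3, -1.2, 0.0])
state = np.zeros_like(grad)
momentum = 0.0

state = momentum * state + (1 - momentum) * grad       # state == grad when momentum is 0
print(np.array_equal(np.sign(state), np.sign(grad)))   # True: same direction as SignSGD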
One final comment. Otherwise LGTM. Thanks for the contribution!

rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
state = momentum * state + (1-momentum)*rescaled_grad
weight = (1 - lr * wd_lh) * weight - lr * sign(state)
Since wd_lh is new, I suggest putting a reference link to the original paper by Loshchilov and Frank Hutter in the documentation.
Added the reference to the documentation as suggested. Thanks guys for reviewing the PR!
Thanks
* the c++ version of signum and signsgd optimizer
* optimizer signum, tested working with mac on cpu using mnist
* unit test for signum
* fix lint and incorporate haibin's code review
* rerun jenkins
* adding link to the Loshchilov and Hutter paper to the documentation
Description
Added the C++ implementation of the Signum optimizer.
Bernstein, Wang, Azizzadenesheli and Anandkumar (2017) "The Signum optimiser: a theory of momentum in quantised stochastic optimisation"
Link to pdf
Also included is an option, 'wd_lh', to apply the alternative (decoupled) weight decay regularization due to Loshchilov and Hutter.
"Fixing Weight Decay Regularization in Adam"
Link to arxiv
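Based on the diff above, Python usage should look roughly like this (the values are illustrative and the exact signature lives in python/mxnet/optimizer.py):

import mxnet as mx

# Optimizer added in this PR: 'wd_lh' is the decoupled (Loshchilov & Hutter)
# weight decay, 'wd' the usual coupled one.
opt = mx.optimizer.Signum(learning_rate=0.01, momentum=0.9, wd=0.0, wd_lh=1e-4)

# For example, with a Gluon trainer (assuming the class is registered under 'signum'):
# trainer = mx.gluon.Trainer(net.collect_params(), opt)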
Checklist
Essentials
- Passed code style checking (make lint)
Changes
Comments