This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Signum optimizer #9220

Merged
merged 6 commits into from
Jan 12, 2018

Conversation

yuxiangw
Contributor

@yuxiangw yuxiangw commented Dec 28, 2017

Description

Added the C++ implementation of the Signum optimizer.

Bernstein, Wang, Azizzadenesheli and Anandkumar (2017) "The Signum optimiser: a theory of momentum in quantised stochastic optimisation"
Link to pdf

Also included is an option, 'wd_lh', that applies the alternative version of weight decay regularization due to Loshchilov and Hutter.
"Fixing Weight Decay Regularization in Adam"
Link to arxiv

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Added the Signum optimizer to mxnet's list of optimizers
  • A special case is SignSGD optimizer, a stand-alone implementation whenever "momentum" is set to 0.

Comments

  • TODO1: add sparse matrix support for this optimizer
  • TODO2: Take advantage of the 1-bit gradient compression interpretation of SignSGD and Signum.
  • TODO3: Add 'wd_lh' support for Adam and other adaptive gradient optimizers.

Member

@eric-haibin-lin eric-haibin-lin left a comment

Thanks for the contribution!! Please see detailed comments in code.

@@ -28,6 +28,14 @@
namespace mxnet {
namespace op {

NNVM_REGISTER_OP(signsgd_update)
.set_attr<FCompute>("FCompute<gpu>", SignSGDUpdate<gpu>);
// .set_attr<FComputeEx>("FComputeEx<gpu>", SignSGDUpdateEx<gpu>);
Member

Could you remove unused lines?

Contributor Author

Done.

return std::vector<uint32_t>{2};
})
.set_attr<FCompute>("FCompute<cpu>", SignumUpdate<cpu>)
// .set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
Member

Please remove unused lines (also the ones in line 42, 65)

Contributor Author

Done.

@@ -57,6 +58,10 @@ class Optimizer(object):
The weight decay (or L2 regularization) coefficient. Modifies objective
by adding a penalty for having large weights.

wd_lh: float, optional
Member

I don't see a change in the Optimizer class constructor. Why is this changed?

Contributor Author

I added that to the constructor at some point, cuz wd_lh is something more generally applicable to other algorithms too (in particular, Adam).

Contributor Author

removed that line.


** Sparse matrix not supported for this optimizer yet.

If weight and momentum are both of ``row_sparse`` storage type,
Member

I'd rather remove the line 81-87 since sparse update is not supported anyway.

Contributor Author

Done.


Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

** Sparse matrix not supported for this optimizer yet.
Member

Not sure if sentence starting with ** renders well in API doc. What about adding a "note" section like rint?
https://github.com/apache/incubator-mxnet/blob/ae70769c8e35cc178bf7dd9dba35386c13394394/src/operator/tensor/elemwise_unary_op_basic.cc#L432-L434
Also, term "sparse ndarray" instead of "sparse matrix" is preferred :)

Contributor Author

Done.


NNVM_REGISTER_OP(signsgd_update)
// MXNET_ADD_SPARSE_OP_ALIAS(signsgd_update)
.describe(R"code(Update function for SignSGDoptimizer.
Member

nit: SignSGD optimizer

Contributor Author

done. and added the math description block similar to other optimizers.

weight = weight - learning_rate * sign(gradient)


** Sparse matrix not supported for this optimizer yet.
Member

Same comment for documentation rendering and FInferStorageType in signum_update

Contributor Author

done.
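
For readers following the thread, the SignSGD rule quoted in the docstring above (weight = weight - learning_rate * sign(gradient)) can be sketched in plain Python. This is only an illustration, not the C++ kernel from this PR: the helper names `sign` and `signsgd_step` are hypothetical, parameters are held in a flat list, and weight decay, rescaling and clipping are left out.

```python
def sign(x):
    """Return -1.0, 0.0 or 1.0 according to the sign of x."""
    return float((x > 0) - (x < 0))


def signsgd_step(weight, grad, lr):
    """One SignSGD step on a flat list of parameters (hypothetical helper).

    Only the sign of each gradient entry is used, so every non-zero
    gradient moves its weight by exactly lr.
    """
    return [w - lr * sign(g) for w, g in zip(weight, grad)]


weights = [0.5, -0.25, 1.0]
grads = [120.0, -0.003, 0.0]  # very different magnitudes, one exact zero
new_weights = signsgd_step(weights, grads, lr=0.5)
print(new_weights)
```

Note that the large and the tiny gradient produce steps of the same size, and a zero gradient leaves its weight unchanged.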


@register
class Signum(Optimizer):
"""The SGD optimizer with momentum and weight decay.
Member

The one line summary should also mention it only takes the sign. Otherwise the readers don't know it until they see line 547

Contributor Author

added details to the doc accordingly.

float lr;
float wd;
float rescale_grad;
float clip_gradient;
Member

If the clip_gradient param has no effect on both SignSGD and Signum, can we just remove this param from signsgd_update and signum_update? That would also simplify the c++ kernels

Contributor Author

It has an effect on Signum, because the momentum comes out differently depending on whether it is computed from the raw gradient or the clipped gradient.

Member

Ah, I see. Thanks for the explanation!
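
A small worked example may make the point about clipping concrete. The sketch below assumes the momentum recurrence quoted later in this conversation (state = momentum * state + (1-momentum) * gradient) with rescale_grad = 1 and wd = 0; the helper names `clip`, `sign` and `momentum_after` are hypothetical and are not part of the PR.

```python
def clip(x, bound):
    """Clamp x to the interval [-bound, bound]."""
    return max(-bound, min(bound, x))


def sign(x):
    """Return -1.0, 0.0 or 1.0 according to the sign of x."""
    return float((x > 0) - (x < 0))


def momentum_after(grads, momentum, clip_gradient=None):
    """Run the momentum recurrence over a sequence of scalar gradients.

    Assumes rescale_grad = 1 and wd = 0, so the (optionally clipped)
    gradient feeds the momentum directly.
    """
    state = 0.0
    for g in grads:
        if clip_gradient is not None:
            g = clip(g, clip_gradient)
        state = momentum * state + (1 - momentum) * g
    return state


grads = [8.0, -1.0]  # one large outlier followed by a small opposite gradient
raw = momentum_after(grads, momentum=0.5)
clipped = momentum_after(grads, momentum=0.5, clip_gradient=1.0)
print(raw, sign(raw))          # 1.5 1.0
print(clipped, sign(clipped))  # -0.25 -1.0
```

Without clipping the outlier dominates the momentum and the update direction stays positive; with clipping at 1.0 the momentum changes sign after the second gradient, so the weight would move the other way. For SignSGD (no momentum) a symmetric clip cannot change the sign of a single gradient, which is presumably why the question arose.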

momentum : float, optional
The momentum value.
wd_lh : float, optional
The amount of decoupled weight decay regularization.
Member

Let's also add a reference/link to the original paper

Contributor Author

added the temp link to pdf hosted on jeremy's site. will update to arxiv or a published version when they are ready.

@eric-haibin-lin
Member

There are new conflicts now. Do you mind resolving them again?
BTW - the files under cpp-package are only needed if you use cpp as front end to train networks. Do you actually need it?

@yuxiangw
Contributor Author

yuxiangw commented Jan 8, 2018

Done fixing the conflicts.

@eric-haibin-lin
Member

@lx75249 could you help review the code for cpp-package?

@conopt
Contributor

conopt commented Jan 9, 2018

@eric-haibin-lin LGTM

signum_update(weight, grad, state, out=weight,
lr=lr, wd=wd, **kwargs)
else:
signsgd_update(weight, grad, out=weight,
Contributor

what's this?

Contributor Author

well, signsgd takes the sign of stochastic gradient.


rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
state = momentum * state + (1-momentum)*rescaled_grad
weight = (1 - lr * wd_lh) * weight - lr * sign(state)
Contributor

what's wd_lh? Is it from the original paper?

Contributor Author

It is an alternative weight decay. See the descriptions.

Member

Since wd_lh is new, I suggest adding a reference link to the original paper by Loshchilov and Hutter in the documentation
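
To illustrate how wd and wd_lh differ, the three-line update quoted in this thread can be transcribed into plain Python for a single scalar parameter. This is a sketch under stated assumptions, not the operator added by the PR: `signum_step`, `sign` and `clip` are hypothetical helpers, and treating clip_gradient=None as "no clipping" is a choice made for this example.

```python
def sign(x):
    """Return -1.0, 0.0 or 1.0 according to the sign of x."""
    return float((x > 0) - (x < 0))


def clip(x, bound):
    """Clamp x to the interval [-bound, bound]."""
    return max(-bound, min(bound, x))


def signum_step(weight, grad, state, lr, momentum, wd=0.0, wd_lh=0.0,
                rescale_grad=1.0, clip_gradient=None):
    """One Signum step for a scalar weight, following the quoted update rule."""
    g = grad if clip_gradient is None else clip(grad, clip_gradient)
    rescaled_grad = rescale_grad * g + wd * weight
    state = momentum * state + (1 - momentum) * rescaled_grad
    weight = (1 - lr * wd_lh) * weight - lr * sign(state)
    return weight, state


common = dict(weight=2.0, grad=1.0, state=0.0, lr=0.5, momentum=0.5)
w_plain, _ = signum_step(**common)           # no weight decay at all
w_wd, _ = signum_step(wd=0.5, **common)      # decay folded into the gradient
w_lh, _ = signum_step(wd_lh=0.5, **common)   # decoupled decay
print(w_plain, w_wd, w_lh)                   # 1.5 1.5 1.0
```

In this example wd changes the momentum but not the step, because the decay term passes through sign() and only its direction survives. wd_lh multiplies the weight directly, outside the sign, so it shrinks the weight regardless of the gradient.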

kwargs['wd_lh'] = self.wd_lh

if state is not None:
signum_update(weight, grad, state, out=weight,
Contributor

call these signum_momentum_update and signum_update to be consistent with others

Contributor Author

RE: naming.

  • signum means SIGN momentUM. So the semantics of the momentum is already in there.
  • SignSGD is the special case of Signum that goes without momentum. And it has been used before.

Unless we change the names in our paper, let's keep them the way they are.

Member

@eric-haibin-lin eric-haibin-lin left a comment

One final comment. Otherwise LGTM. Thanks for the contribution!


rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
state = momentum * state + (1-momentum)*rescaled_grad
weight = (1 - lr * wd_lh) * weight - lr * sign(state)
Member

Since wd_lh is new, I suggest adding a reference link to the original paper by Loshchilov and Hutter in the documentation

@yuxiangw
Contributor Author

Added the reference to the documentation as suggested. Thanks guys for reviewing the PR!

@piiswrong piiswrong merged commit 5251b86 into apache:master Jan 12, 2018
@piiswrong
Contributor

Thanks

CodingCat pushed a commit to CodingCat/mxnet that referenced this pull request Jan 16, 2018
* the c++ version of signum and signsgd optimizer

* optimizer signum, tested working with mac on cpu using mnist

* unit test for signum

* fix lint and incorporate haibin's code review

* rerun jenkins

* adding link to the Loshchilov and Hutter paper to the documentation
yuxiangw added a commit to yuxiangw/incubator-mxnet that referenced this pull request Jan 25, 2018
@yuxiangw yuxiangw mentioned this pull request Jan 25, 2018
7 tasks
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018