Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
doc
Browse files Browse the repository at this point in the history
  • Loading branch information
Anirudh Acharya committed May 20, 2019
1 parent 4ac083b commit 11895e5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
6 changes: 2 additions & 4 deletions python/mxnet/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,10 +1099,8 @@ def update(self, index, weight, grad, state):
self._update_impl(index, weight, grad, state, multi_precision=False)

def update_multi_precision(self, index, weight, grad, state):
if not isinstance(index, (tuple, list)):
use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
else:
use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 \
and isinstance(state, (tuple, list))
self._update_impl(index, weight, grad, state,
multi_precision=use_multi_precision)

Expand Down
26 changes: 17 additions & 9 deletions src/operator/optimizer_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -708,12 +708,11 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a


NNVM_REGISTER_OP(nag_update)
MXNET_ADD_SPARSE_OP_ALIAS(nag_update)
.describe(R"code(Update function for Nesterov Accelerated Gradient (NAG) optimizer.
NAG update consists of the following steps,
It updates the weights using the following formula,
weight = weight - (lr * (grad + wd * weight))
state = momentum * state + grad + wd * weight
weight = weight - (lr * (grad + momentum * state))
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
Expand All @@ -727,8 +726,19 @@ weight = weight - (lr * (grad + momentum * state))


NNVM_REGISTER_OP(nag_mom_update)
MXNET_ADD_SPARSE_OP_ALIAS(nag_mom_update)
.describe(R"code(Update function for Nesterov Accelerated Gradient (NAG) optimizer.
It updates the weights using the following formula,
.. math::
v_t = \gamma v_{t-1} + \eta * \nabla J(W_{t-1} - \gamma v_{t-1})\\
W_t = W_{t-1} - v_t
Where
:math:`\eta` is the learning rate of the optimizer
:math:`\gamma` is the decay rate of the momentum estimate
:math:`v_t` is the update vector at time step `t`
:math:`W_t` is the weight vector at time step `t`
)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
Expand All @@ -747,8 +757,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(nag_mom_update)


NNVM_REGISTER_OP(mp_nag_update)
MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_update)
.describe(R"code(Multi-precision NAG update.
.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient (NAG) optimizer.
)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
Expand All @@ -767,8 +776,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_update)


NNVM_REGISTER_OP(mp_nag_mom_update)
MXNET_ADD_SPARSE_OP_ALIAS(mp_nag_mom_update)
.describe(R"code(Multi-precision NAG update.
.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient (NAG) optimizer.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
Expand Down

0 comments on commit 11895e5

Please sign in to comment.