This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Signum optimizer #9220

Merged: 6 commits, Jan 12, 2018
Changes from 4 commits
14 changes: 14 additions & 0 deletions cpp-package/include/mxnet-cpp/optimizer.h
@@ -146,6 +146,20 @@ class SGDOptimizer : public Optimizer {
AtomicSymbolCreator mom_update_handle_;
};

class SignumOptimizer : public Optimizer {
public:
explicit SignumOptimizer(unsigned begin_num_update = 0);
std::string GetType() const override;
void Update(int index, NDArray weight, NDArray grad) override;
private:
virtual ~SignumOptimizer();
void CreateState_(int index, NDArray weight) override;
std::map<int, NDArray*> states_;
AtomicSymbolCreator update_handle_;
AtomicSymbolCreator mom_update_handle_;
};


class RMSPropOptimizer : public Optimizer {
public:
explicit RMSPropOptimizer(unsigned begin_num_update = 0);
64 changes: 64 additions & 0 deletions cpp-package/include/mxnet-cpp/optimizer.hpp
@@ -131,6 +131,7 @@ inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(adadelta, AdaDeltaOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(signum, SignumOptimizer);
auto it = cmap().find(name);
if (it == cmap().end())
return nullptr;
@@ -200,6 +201,69 @@ inline void SGDOptimizer::CreateState_(int index, NDArray weight) {
}
}

// Implementing the Signum optimizer

inline SignumOptimizer::SignumOptimizer(unsigned begin_num_update)
: Optimizer(begin_num_update) {
update_handle_ = op_map()->GetSymbolCreator("signsgd_update");
mom_update_handle_ = op_map()->GetSymbolCreator("signum_update");
}

inline std::string SignumOptimizer::GetType() const {
return "signum";
}

inline SignumOptimizer::~SignumOptimizer() {
for (auto &it : states_) {
delete it.second;
}
}

inline void SignumOptimizer::Update(int index, NDArray weight, NDArray grad) {
if (states_.count(index) == 0) {
CreateState_(index, weight);
}

params_["lr"] = std::to_string(GetLR_(index));
params_["wd"] = std::to_string(GetWD_(index));
UpdateCount_(index);
auto keys = GetParamKeys_();
auto values = GetParamValues_();
CHECK_EQ(keys.size(), values.size());

NDArrayHandle inputs[3];
inputs[0] = weight.GetHandle();
inputs[1] = grad.GetHandle();

int num_outputs = 1;
NDArrayHandle output = weight.GetHandle();
NDArrayHandle *outputs = &output;

if (states_[index] == nullptr) {
MXImperativeInvoke(update_handle_, 2, inputs,
&num_outputs, &outputs,
keys.size(), keys.data(), values.data());
} else {
inputs[2] = states_[index]->GetHandle();
MXImperativeInvoke(mom_update_handle_, 3, inputs,
&num_outputs, &outputs,
keys.size(), keys.data(), values.data());
}
}

inline void SignumOptimizer::CreateState_(int index, NDArray weight) {
if (params_.count("momentum") == 0) {
states_[index] = nullptr;
} else {
states_[index] = new NDArray(weight.GetShape(), weight.GetContext());
*states_[index] = 0;
}
}

// End of the Signum optimizer implementation



inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
: Optimizer(begin_num_update) {
update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
66 changes: 63 additions & 3 deletions python/mxnet/optimizer.py
@@ -25,7 +25,8 @@
from .base import py_str
from .ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs)
from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
-                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update)
+                      mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
+                      signsgd_update, signum_update)
from .ndarray import _internal
from .ndarray import op
from .ndarray import sparse
@@ -534,6 +535,66 @@ def update_multi_precision(self, index, weight, grad, state):
        self._update_impl(index, weight, grad, state,
                          multi_precision=use_multi_precision)

@register
class Signum(Optimizer):
    """The Signum optimizer that takes the sign of gradient or momentum.

    The optimizer updates the weight by::

        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
        state = momentum * state + (1 - momentum) * rescaled_grad
        weight = (1 - lr * wd_lh) * weight - lr * sign(state)

Review comment (Contributor): What's wd_lh? Is it from the original paper?
Reply (Author): It is an alternative weight decay. See the descriptions.
Review comment (Member): Since wd_lh is new, I suggest putting a reference link to the original paper by Loshchilov and Hutter in the documentation.

    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf

    For details of the update algorithm see
    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.

    This optimizer accepts the following parameters in addition to those accepted
    by :class:`.Optimizer`.

    Parameters
    ----------
    momentum : float, optional
        The momentum value.
    wd_lh : float, optional
        The amount of decoupled weight decay regularization.

Review comment (Member): Let's also add a reference/link to the original paper.
Reply (Author): Added the temp link to the PDF hosted on Jeremy's site; will update to the arXiv or a published version when they are ready.

"""
def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
self.momentum = momentum
self.wd_lh = wd_lh

def create_state(self, index, weight):
momentum = None
if self.momentum != 0.0:
momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
return momentum

def _update_impl(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)

kwargs = {'rescale_grad': self.rescale_grad}
if self.momentum > 0:
kwargs['momentum'] = self.momentum
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
if self.wd_lh:
kwargs['wd_lh'] = self.wd_lh

if state is not None:
            signum_update(weight, grad, state, out=weight,
                          lr=lr, wd=wd, **kwargs)
        else:
            signsgd_update(weight, grad, out=weight,
                           lr=lr, wd=wd, **kwargs)

Review comment (Contributor, on signum_update): call these signum_momentum_update and signum_update to be consistent with the others
Reply (Author): RE: naming.
  • Signum means SIGN momentUM, so the semantics of the momentum is already in the name.
  • SignSGD is the special case of Signum that goes without momentum, and that name has been used before.
Unless we change the names in our paper, let's keep them the way they are.

Review comment (Contributor, on signsgd_update): What's this?
Reply (Author): Well, signsgd takes the sign of the stochastic gradient.

    def update(self, index, weight, grad, state):
        self._update_impl(index, weight, grad, state)
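
As a usage illustration (not part of this diff): once MXNet is built with these changes, the registered optimizer can be created by its name; the hyperparameter values below are arbitrary, and the sketch assumes the standard mxnet Python package.

import mxnet as mx

# Create the optimizer through the registry; the name 'signum' comes from @register above.
opt = mx.optimizer.create('signum', learning_rate=0.01, momentum=0.9, wd_lh=1e-5)

# Drive one update step by hand on a toy parameter (slot index 0).
weight = mx.nd.ones((3,))
grad = mx.nd.array([0.5, -2.0, 0.0])
state = opt.create_state(0, weight)   # momentum buffer, or None when momentum == 0
opt.update(0, weight, grad, state)    # weight is modified in place via out=weight
print(weight.asnumpy())

The same optimizer can be selected through a Gluon Trainer, for example gluon.Trainer(net.collect_params(), 'signum', {'learning_rate': 0.01, 'momentum': 0.9, 'wd_lh': 1e-5}). Note that wd passes through the momentum and the sign, while wd_lh scales the weight directly.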

@register
class FTML(Optimizer):
@@ -702,8 +763,7 @@ def update(self, index, weight, grad, state):
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
        weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
-                                                           shape=weight.shape,
-                                                           ctx=weight.context)
+                                                           weight.shape, weight.context)


@register # pylint: disable=invalid-name
142 changes: 142 additions & 0 deletions src/operator/optimizer_op-inl.h
@@ -66,6 +66,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
}
};


struct SGDKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
@@ -228,6 +229,7 @@ struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
}
};


struct SGDMomKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
@@ -1281,6 +1283,146 @@ inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs,
}
}


// Implementation for signSGD and Signum

struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {
float lr;
float wd;
float rescale_grad;
float clip_gradient;

Review comment (Member): If the clip_gradient param has no effect on both SignSGD and Signum, can we just remove this param from signsgd_update and signum_update? That would also simplify the C++ kernels.
Reply (Author): It has an effect on Signum, because the result differs depending on whether the raw gradient or the clipped gradient is used when calculating the momentum.
Review comment (Member): Ah, I see. Thanks for the explanation!
(A numeric sketch of this point follows the struct below.)

DMLC_DECLARE_PARAMETER(SignSGDParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(wd)
.set_default(0.0f)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
}
};
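
A small numeric sketch of the point made in the clip_gradient thread above (plain NumPy, illustrative only; the helper name is ours): a symmetric clip can never flip the sign of an individual gradient, so SignSGD is unaffected, but it changes what gets accumulated into the momentum and can therefore flip the sign that Signum steps along.

import numpy as np

def clipped(g, c):
    # Follow the clip_gradient convention above: clipping is disabled when c < 0.
    return float(np.clip(g, -c, c)) if c >= 0 else g

# SignSGD: the update direction is sign(grad); clipping to [-1, 1] cannot change it.
g = -3.0
assert np.sign(g) == np.sign(clipped(g, 1.0))

# Signum: the direction is sign(momentum), accumulated from (possibly clipped) gradients.
# Two steps with momentum = 0.9 and clip_gradient = 1.0:
mom_raw, mom_clip = 0.0, 0.0
for g in (6.0, -3.0):
    mom_raw = 0.9 * mom_raw - 0.1 * g                    # momentum from raw gradients
    mom_clip = 0.9 * mom_clip - 0.1 * clipped(g, 1.0)    # momentum from clipped gradients
print(np.sign(mom_raw), np.sign(mom_clip))               # -1.0 1.0: the Signum step flips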


struct SignSGDKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* weight_data,
const DType* grad_data, const DType param_clip_gradient,
const DType param_lr, const DType param_wd, const DType param_rescale_grad,
const OpReqType req) {

// param_clip_gradient has no effect for SignSGD
KERNEL_ASSIGN(out_data[i], req,
(1.f-param_lr*param_wd)*weight_data[i]
- (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0)));
}
};
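
For reference, a NumPy transcription of what this kernel assigns per element (illustrative only; the helper name is ours). rescale_grad and clip_gradient are omitted because the kernel takes the sign of the raw gradient, which a positive rescale or a symmetric clip cannot change.

import numpy as np

def signsgd_update_ref(weight, grad, lr, wd):
    # out = (1 - lr * wd) * weight - lr * sign(grad), elementwise
    return (1.0 - lr * wd) * weight - lr * np.sign(grad)

print(signsgd_update_ref(np.array([1.0, -0.5]), np.array([0.3, -10.0]), lr=0.01, wd=1e-4))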

template<typename xpu>
inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
grad.dptr_, static_cast<DType>(param.clip_gradient),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad), req[0]);
});
}


struct SignumParam : public dmlc::Parameter<SignumParam> {
float lr;
float momentum;
float wd;
float rescale_grad;
float clip_gradient;
float wd_lh; // the amount of algorithmic (decoupled) weight decay, as in Loshchilov and Hutter
DMLC_DECLARE_PARAMETER(SignumParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(momentum)
.set_default(0.0f)
.describe("The decay rate of momentum estimates at each epoch.");
DMLC_DECLARE_FIELD(wd)
.set_default(0.0f)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
DMLC_DECLARE_FIELD(wd_lh)
.set_default(0.0f)
.describe("The amount of weight decay that does not go into gradient/momentum calculations"
"otherwise do weight decay algorithmically only.");
}
};

struct SignumKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mom_data, const DType* weight_data,
const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
const DType param_lr, const DType param_wd, const DType param_rescale_grad,
const DType param_wd_lh, const OpReqType req) {
if (param_clip_gradient >= 0.0f) {
mom_data[i] = param_momentum*mom_data[i]
- (1-param_momentum)*param_wd*weight_data[i]
- (1-param_momentum)
*mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
} else {
mom_data[i] = param_momentum*mom_data[i]
- (1-param_momentum)*param_wd*weight_data[i]
- (1-param_momentum)*param_rescale_grad*grad_data[i];
}
KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i]
+ (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0)));
}
};
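
Likewise, a NumPy transcription of this kernel for sanity checking (illustrative only; the helper name is ours). Note that wd enters the momentum, so its effect passes through the sign, whereas wd_lh decays the weight directly, i.e. the decoupled weight decay of Loshchilov and Hutter mentioned above.

import numpy as np

def signum_update_ref(weight, grad, mom, lr, momentum, wd,
                      rescale_grad=1.0, clip_gradient=-1.0, wd_lh=0.0):
    g = rescale_grad * grad
    if clip_gradient >= 0.0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    # mom <- momentum * mom - (1 - momentum) * (wd * weight + g)
    mom = momentum * mom - (1.0 - momentum) * (wd * weight + g)
    # out <- (1 - lr * wd_lh) * weight + lr * sign(mom)
    weight = (1.0 - lr * wd_lh) * weight + lr * np.sign(mom)
    return weight, mom

w_new, m_new = signum_update_ref(np.array([1.0, -0.5]), np.array([0.3, -10.0]),
                                 np.zeros(2), lr=0.01, momentum=0.9, wd=1e-4, wd_lh=1e-5)
print(w_new, m_new)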

template<typename xpu>
inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
SignumParam param = nnvm::get<SignumParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]);
});
}
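
The NNVM operator registration for these compute functions is not shown in this section; assuming it follows the existing sgd_update pattern, the parameters above surface as keyword arguments of the NDArray ops, roughly as sketched below (values are arbitrary).

import mxnet as mx

w = mx.nd.ones((4,))
g = mx.nd.array([1.0, -3.0, 0.0, 0.2])
m = mx.nd.zeros_like(w)

# Stateless sign-of-gradient step.
mx.nd.signsgd_update(w, g, lr=0.1, wd=0.0, out=w)

# Momentum variant; clip_gradient and wd_lh map to the fields declared above.
mx.nd.signum_update(w, g, m, lr=0.1, momentum=0.9, wd=0.0,
                    clip_gradient=1.0, wd_lh=1e-5, out=w)
print(w.asnumpy())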



} // namespace op
} // namespace mxnet
