From 5c17c79d4144c054c68e1b6793a7b22833cab8b9 Mon Sep 17 00:00:00 2001 From: Anirudh Acharya Date: Mon, 29 Apr 2019 08:54:28 -0700 Subject: [PATCH] reuse sgd updates where convenient --- python/mxnet/optimizer/optimizer.py | 10 +-- src/operator/optimizer_op-inl.h | 87 ------------------------- src/operator/optimizer_op.cc | 37 ----------- src/operator/optimizer_op.cu | 6 -- tests/python/unittest/test_optimizer.py | 2 +- 5 files changed, 6 insertions(+), 136 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index 60ee44d03b20..c2c1aa6a76f4 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -28,9 +28,9 @@ from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, - signsgd_update, signum_update, nag_update, nag_mom_update, mp_nag_update, - mp_nag_mom_update, multi_sgd_update, multi_sgd_mom_update, - multi_mp_sgd_update, multi_mp_sgd_mom_update) + signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update, + multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update, + multi_mp_sgd_mom_update) from ..ndarray import sparse from ..random import normal @@ -1086,13 +1086,13 @@ def _update_impl(self, index, weight, grad, state, multi_precision=False): if state is not None: nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs) else: - nag_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) + sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) else: if state[0] is not None: mp_nag_mom_update(weight, grad, state[0], state[1], out=weight, lr=lr, wd=wd, **kwargs) else: - mp_nag_update(weight, grad, state[1], out=weight, + mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs) def update(self, index, weight, grad, state): diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index bc736aa1f818..50637a8e7b42 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -1029,48 +1029,6 @@ struct NAGParam : public dmlc::Parameter { } }; -struct NAGKernel { - template - MSHADOW_XINLINE static void Map(int i, DType* out_data, - const DType* weight_data, const DType* grad_data, - const DType param_clip_gradient, const DType param_lr, - const DType param_wd, const DType param_rescale_grad, - const OpReqType req) { - if (param_clip_gradient >= 0.0f) { - KERNEL_ASSIGN(out_data[i], req, - weight_data[i] - - param_lr * (mshadow_op::clip::Map(param_rescale_grad*grad_data[i], - param_clip_gradient) - + param_wd*weight_data[i])); - } else { - KERNEL_ASSIGN(out_data[i], req, - weight_data[i] - - param_lr * (param_rescale_grad*grad_data[i] - + (param_wd*weight_data[i]))); - } - } -}; - -template -inline void NAGUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace mxnet_op; - const NAGParam& param = nnvm::get(attrs.parsed); - Stream* s = ctx.get_stream(); - MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Tensor weight = inputs[0].FlatTo2D(s); - Tensor grad = inputs[1].FlatTo2D(s); - Tensor out = outputs[0].FlatTo2D(s); - Kernel::Launch(s, weight.shape_.Size(), out.dptr_, - weight.dptr_, grad.dptr_, static_cast(param.clip_gradient), - static_cast(param.lr), static_cast(param.wd), - static_cast(param.rescale_grad), req[0]); - }); -} - struct NAGMomParam : public dmlc::Parameter { float lr; float momentum; @@ -1150,51 +1108,6 @@ inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs, }); } -struct MP_NAGKernel { - template - MSHADOW_XINLINE static void Map(int i, DType* out_data, - const DType* weight_data, const DType* grad_data, - float* weight32, const float param_clip_gradient, - const float param_lr, const float param_wd, - const float param_rescale_grad, - const OpReqType req) { - if (param_clip_gradient >= 0.0f) { - float w = weight32[i]; - w = w - param_lr * (mshadow_op::clip::Map(param_rescale_grad - *static_cast(grad_data[i]), param_clip_gradient) - + param_wd*w); - weight32[i] = w; - KERNEL_ASSIGN(out_data[i], req, (DType)w); - } else { - float w = weight32[i]; - w = w - param_lr * (param_rescale_grad - *static_cast(grad_data[i]) + (param_wd*w)); - weight32[i] = w; - KERNEL_ASSIGN(out_data[i], req, (DType)w); - } - } -}; - -template -inline void MP_NAGUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace mxnet_op; - const NAGParam& param = nnvm::get(attrs.parsed); - Stream* s = ctx.get_stream(); - MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Tensor weight = inputs[0].FlatTo2D(s); - Tensor grad = inputs[1].FlatTo2D(s); - Tensor weight32 = inputs[2].FlatTo2D(s); - Tensor out = outputs[0].FlatTo2D(s); - Kernel::Launch(s, weight.shape_.Size(), out.dptr_, - weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, - param.lr, param.wd, param.rescale_grad, req[0]); - }); -} - struct MP_NAGMomKernel { template MSHADOW_XINLINE static void Map(int i, DType* out_data, diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index e77bd416e37a..01410863640f 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -707,24 +707,6 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a .add_arguments(AdamParam::__FIELDS__()); -NNVM_REGISTER_OP(nag_update) -.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. -It updates the weights using the following formula, - -weight = weight - (lr * (grad + wd * weight)) - -)code" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", ElemwiseShape<2, 1>) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FCompute", NAGUpdate) -.add_argument("weight", "NDArray-or-Symbol", "Weight") -.add_argument("grad", "NDArray-or-Symbol", "Gradient") -.add_arguments(NAGParam::__FIELDS__()); - - NNVM_REGISTER_OP(nag_mom_update) .describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer. It updates the weights using the following formula, @@ -756,25 +738,6 @@ Where .add_arguments(NAGMomParam::__FIELDS__()); -NNVM_REGISTER_OP(mp_nag_update) -.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer. -)code" ADD_FILELINE) -.set_num_inputs(3) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", ElemwiseShape<3, 1>) -.set_attr("FInferType", MP_InferType<2, 1, 3>) -.set_attr("FCompute", MP_NAGUpdate) -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - return std::vector{2}; - }) -.add_argument("weight", "NDArray-or-Symbol", "Weight") -.add_argument("grad", "NDArray-or-Symbol", "gradient") -.add_argument("weight32", "NDArray-or-Symbol", "Weight32") -.add_arguments(NAGParam::__FIELDS__()); - - NNVM_REGISTER_OP(mp_nag_mom_update) .describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer. )code" ADD_FILELINE) diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 5d361b7e528d..2c72462de016 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -251,15 +251,9 @@ NNVM_REGISTER_OP(multi_mp_sgd_update) NNVM_REGISTER_OP(multi_mp_sgd_mom_update) .set_attr("FCompute", MultiSGDMomUpdate); -NNVM_REGISTER_OP(nag_update) -.set_attr("FCompute", NAGUpdate); - NNVM_REGISTER_OP(nag_mom_update) .set_attr("FCompute", NAGMomUpdate); -NNVM_REGISTER_OP(mp_nag_update) -.set_attr("FCompute", MP_NAGUpdate); - NNVM_REGISTER_OP(mp_nag_mom_update) .set_attr("FCompute", MP_NAGMomUpdate); diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index e151cfde2306..3e6cdd0997ce 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -425,7 +425,7 @@ def test_nag(): if (dtype == np.float16 and ('multi_precision' not in kwarg or not kwarg['multi_precision'])): continue - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, rtol=1e-3, atol=1e-4) #SGLD class PySGLD(mx.optimizer.Optimizer):