From 059a05549e7892007a9490e9ca2987d22ddec816 Mon Sep 17 00:00:00 2001 From: khaotik Date: Fri, 30 Apr 2021 22:46:04 +0800 Subject: [PATCH] [FEATURE] AdaBelief operator (#20065) * copycat from adamw to adabelief * fix py lint * fix py lint #2 * fix cpp lint * add adabelief to amp list --- python/mxnet/amp/lists/symbol_fp16.py | 4 + python/mxnet/ndarray/contrib.py | 51 ++ python/mxnet/optimizer/__init__.py | 12 +- python/mxnet/optimizer/adabelief.py | 231 ++++++++ src/operator/contrib/adabelief-inl.h | 508 ++++++++++++++++++ src/operator/contrib/adabelief.cc | 261 +++++++++ src/operator/contrib/adabelief.cu | 57 ++ src/operator/contrib/adamw-inl.h | 10 +- src/operator/contrib/adamw.cc | 13 +- src/operator/contrib/adamw.cu | 10 +- .../python/unittest/test_contrib_optimizer.py | 147 +++-- tests/python/unittest/test_optimizer.py | 25 + 12 files changed, 1256 insertions(+), 73 deletions(-) create mode 100644 python/mxnet/optimizer/adabelief.py create mode 100644 src/operator/contrib/adabelief-inl.h create mode 100644 src/operator/contrib/adabelief.cc create mode 100644 src/operator/contrib/adabelief.cu diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py index f78d32d6193a..6359384489c4 100644 --- a/python/mxnet/amp/lists/symbol_fp16.py +++ b/python/mxnet/amp/lists/symbol_fp16.py @@ -82,6 +82,7 @@ '_FusedOpHelper', '_FusedOpOutHelper', '_NoGradient', + '_adabelief_update', '_adamw_update', '_arange', '_cond', @@ -153,11 +154,14 @@ '_minimum_scalar', '_minus_scalar', '_mod_scalar', + '_mp_adabelief_update', '_mp_adamw_update', '_mul_scalar', + '_multi_adabelief_update', '_multi_adamw_update', '_multi_lamb_update', '_multi_lans_update', + '_multi_mp_adabelief_update', '_multi_mp_adamw_update', '_multi_mp_lamb_update', '_multi_mp_lans_update', diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 204d69465a81..ed70f8ccfc6e 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -679,6 +679,57 @@ def multi_mp_lamb_update(weights, grads, mean, var, weights32, step_count, wds=wds, **kwargs) +def adabelief_update(weight, grad, mean, var, rescale_grad, lr, eta, beta1=0.9, beta2=0.999, + epsilon=1e-8, wd=0, clip_gradient=-1, out=None, name=None, **kwargs): + rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context) + return ndarray._internal._adabelief_update(weight=weight, grad=grad, mean=mean, var=var, + rescale_grad=rescale_grad, lr=lr, eta=eta, + beta1=beta1, beta2=beta2, epsilon=epsilon, + wd=wd, clip_gradient=clip_gradient, out=out, + name=name, **kwargs) + +def mp_adabelief_update(weight, grad, mean, var, weight32, rescale_grad, lr, eta, beta1=0.9, + beta2=0.999, epsilon=1e-8, wd=0, clip_gradient=-1, out=None, + name=None, **kwargs): + rescale_grad = _get_rescale_grad(rescale_grad, ctx=weight.context) + return ndarray._internal._mp_adabelief_update(weight=weight, grad=grad, mean=mean, var=var, + weight32=weight32, + rescale_grad=rescale_grad, lr=lr, eta=eta, + beta1=beta1, beta2=beta2, epsilon=epsilon, + wd=wd, clip_gradient=clip_gradient, out=out, + name=name, **kwargs) + +def multi_adabelief_update(weights, grads, mean, var, rescale_grad, lrs, wds, etas, + out=None, name=None, size=0, **kwargs): + if not size: + size = len(weights) + + rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context) + temp_list = _flatten_list(zip(weights, grads, mean, var)) + [rescale_grad] + return ndarray._internal._multi_adabelief_update(*temp_list, + out=out, + num_weights=size, + lrs=lrs, + wds=wds, + etas=etas, + name=name, + **kwargs) + +def multi_mp_adabelief_update(weights, grads, mean, var, weights32, rescale_grad, lrs, wds, etas, + out=None, name=None, size=0, **kwargs): + if not size: + size = len(weights) + + rescale_grad = _get_rescale_grad(rescale_grad, ctx=weights[0].context) + temp_list = _flatten_list(zip(weights, grads, mean, var, weights32)) + [rescale_grad] + return ndarray._internal._multi_mp_adabelief_update(*temp_list, + out=out, + num_weights=size, + lrs=lrs, + wds=wds, + etas=etas, + name=name, + **kwargs) def multi_lans_update(weights, grads, mean, var, step_count, lrs, wds, out=None, num_tensors=0, **kwargs): diff --git a/python/mxnet/optimizer/__init__.py b/python/mxnet/optimizer/__init__.py index 9bf0c1f72af4..fba34a3ed015 100644 --- a/python/mxnet/optimizer/__init__.py +++ b/python/mxnet/optimizer/__init__.py @@ -19,8 +19,11 @@ from . import (optimizer, contrib, updater, utils, sgd, sgld, signum, dcasgd, nag, adagrad, adadelta, adam, adamax, nadam, ftrl, - ftml, lars, lamb, rmsprop, lans, adamW) + ftml, lars, lamb, rmsprop, lans, adamW, + adabelief) # pylint: disable=wildcard-import +from .adabelief import * + from .adamW import * from .optimizer import * @@ -62,6 +65,7 @@ from .lans import * __all__ = optimizer.__all__ + updater.__all__ + ['contrib'] + sgd.__all__ + sgld.__all__ \ - + signum.__all__ + dcasgd.__all__ + nag.__all__ + adagrad.__all__ + adadelta.__all__ \ - + adam.__all__ + adamax.__all__ + nadam.__all__ + ftrl.__all__ + ftml.__all__ \ - + lars.__all__ + lamb.__all__ + rmsprop.__all__ + lans.__all__ + + signum.__all__ + dcasgd.__all__ + nag.__all__ + adabelief.__all__ \ + + adagrad.__all__ + adadelta.__all__ + adam.__all__ + adamax.__all__ \ + + nadam.__all__ + ftrl.__all__ + ftml.__all__ + lars.__all__ \ + + lamb.__all__ + rmsprop.__all__ + lans.__all__ diff --git a/python/mxnet/optimizer/adabelief.py b/python/mxnet/optimizer/adabelief.py new file mode 100644 index 000000000000..c224ebf46ee5 --- /dev/null +++ b/python/mxnet/optimizer/adabelief.py @@ -0,0 +1,231 @@ +# coding: utf-8 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""AdaBelief optimizer.""" +import math +import os +import numpy as np +from .optimizer import Optimizer, register +from ..ndarray import (zeros, clip, sqrt, square, full, NDArray) +from ..ndarray.contrib import mp_adabelief_update, adabelief_update,\ + multi_mp_adabelief_update, multi_adabelief_update + + +__all__ = ['AdaBelief'] + + +@register +class AdaBelief(Optimizer): + """The AdaBelief optimizer. + + This class implements the optimizer described in *Adapting Stepsizes by the Belief in Observed Gradients*, + available at https://arxiv.org/pdf/2010.07468.pdf. + + Updates are applied by:: + + grad = clip(grad * rescale_grad, clip_gradient) + wd * w + m = beta1 * m + (1 - beta1) * grad + s = beta2 * s + (1 - beta2) * ((grad - m)**2) + epsilon + lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t) + w = w - lr * (m / (sqrt(s) + epsilon)) + + + Also, we can turn off the bias correction term and the updates are as follows:: + + grad = clip(grad * rescale_grad, clip_gradient) + wd * w + m = beta1 * m + (1 - beta1) * grad + s = beta2 * s + (1 - beta2) * ((grad - m)**2) + epsilon + lr = learning_rate + w = w - lr * (m / (sqrt(s) + epsilon)) + + This optimizer accepts the following parameters in addition to those accepted + by :class:`.Optimizer`. + + Parameters + ---------- + learning_rate : float, default 0.001 + The initial learning rate. If None, the optimization will use the + learning rate from ``lr_scheduler``. If not None, it will overwrite + the learning rate in ``lr_scheduler``. If None and ``lr_scheduler`` + is also None, then it will be set to 0.01 by default. + beta1 : float, default 0.9 + Exponential decay rate for the first moment estimates. + beta2 : float, default 0.999 + Exponential decay rate for the second moment estimates. + epsilon : float, default 1e-6 + Small value to avoid division by 0. + correct_bias : bool, default True + Can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). + Default True. + use_fused_step : bool, default True + Whether or not to use fused kernels for optimizer. + When use_fused_step=False, step is called, + otherwise, fused_step is called. + """ + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6, + correct_bias=True, use_fused_step=True, **kwargs): + super().__init__(use_fused_step=use_fused_step, + learning_rate=learning_rate, + **kwargs) + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.correct_bias = correct_bias + self.aggregate_num = max(1, min(50, + int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', '4')))) + + def create_state(self, index, weight): + """state creation function.""" + return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean + zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance + + def step(self, indices, weights, grads, states): + """Perform an optimization step using gradients and states. + + Parameters + ---------- + indices : list of int + List of unique indices of the parameters into the individual learning rates + and weight decays. Learning rates and weight decay may be set via `set_lr_mult()` + and `set_wd_mult()`, respectively. + weights : list of NDArray + List of parameters to be updated. + grads : list of NDArray + List of gradients of the objective with respect to this parameter. + states : List of any obj + List of state returned by `create_state()`. + """ + for index, weight, grad, state in zip(indices, weights, grads, states): + self._update_count(index) + lr = self._get_lr(index) + wd = self._get_wd(index) + eps = self.epsilon + t = self._index_update_count[index] + + # preprocess grad + grad *= self.rescale_grad + grad += wd * weight + if self.clip_gradient is not None: + grad = clip(grad, -self.clip_gradient, self.clip_gradient) + if self.correct_bias: + coef1 = 1. - self.beta1**t + coef2 = 1. - self.beta2**t + lr *= math.sqrt(coef2) / coef1 + + # update mean and var + mean, var = state + mean[:] *= self.beta1 + mean[:] += (1. - self.beta1) * grad + var[:] *= self.beta2 + var[:] += (1. - self.beta2) * square(grad - mean) + var[:] += eps + + # update weight + d = mean / (sqrt(var) + eps) + weight[:] -= lr * d + + def fused_step(self, indices, weights, grads, states): + """Perform a fused optimization step using gradients and states. + Fused kernel is used for update. + + Parameters + ---------- + indices : list of int + List of unique indices of the parameters into the individual learning rates + and weight decays. Learning rates and weight decay may be set via `set_lr_mult()` + and `set_wd_mult()`, respectively. + weights : list of NDArray + List of parameters to be updated. + grads : list of NDArray + List of gradients of the objective with respect to this parameter. + states : List of any obj + List of state returned by `create_state()`. + """ + multi_precision = self.multi_precision and weights[0].dtype == np.float16 + aggregate = self.aggregate_num > 1 + if not isinstance(indices, (tuple, list)): + indices = [indices] + weights = [weights] + grads = [grads] + states = [states] + for w_i, g_i in zip(weights, grads): + assert(isinstance(w_i, NDArray)) + assert(isinstance(g_i, NDArray)) + aggregate = (aggregate and + w_i.stype == 'default' and + g_i.stype == 'default') + self._update_count(indices) + lrs = self._get_lrs(indices) + wds = self._get_wds(indices) + if self.correct_bias: + new_lrs = [] + for idx, lr in zip(indices, lrs): + t = self._index_update_count[idx] + coef1 = 1. - self.beta1 ** t + coef2 = 1. - self.beta2 ** t + new_lrs.append(lr * math.sqrt(coef2) / coef1) + lrs = new_lrs + if not isinstance(self.rescale_grad, NDArray): + self.rescale_grad = full(shape=(1,), val=self.rescale_grad, ctx=weights[0].context) + else: + self.rescale_grad = self.rescale_grad.as_in_context(weights[0].context) + kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, + 'rescale_grad': self.rescale_grad} + if self.clip_gradient: + kwargs['clip_gradient'] = self.clip_gradient + + if aggregate: + current_index = 0 + while current_index < len(indices): + sidx = current_index + eidx = min(current_index + self.aggregate_num, len(indices)) + if not multi_precision: + mean, var = list(zip(*states[sidx:eidx])) + multi_adabelief_update(weights[sidx:eidx], grads[sidx:eidx], + mean, var, + out=weights[sidx:eidx], + size=len(weights[sidx:eidx]), + lrs=list(np.ones(len(weights[sidx:eidx]))), + wds=wds[sidx:eidx], + etas=lrs[sidx:eidx], + **kwargs) + else: + mean_var = list(zip(*states[sidx:eidx]))[0] + tmean_var = list(zip(*mean_var)) + mean = tmean_var[0] + var = tmean_var[1] + multi_mp_adabelief_update(weights[sidx:eidx], + grads[sidx:eidx], + mean, var, + list(zip(*states[sidx:eidx]))[1], + out=weights[sidx:eidx], + size=len(weights[sidx:eidx]), + lrs=list(np.ones(len(weights[sidx:eidx]))), + wds=wds[sidx:eidx], + etas=lrs[sidx:eidx], + **kwargs) + current_index += self.aggregate_num + else: + for w_i, g_i, s_i, lr, wd in zip(weights, grads, states, lrs, wds): + if not multi_precision: + mean, var = s_i + adabelief_update(w_i, g_i, mean, var, out=w_i, + lr=1, wd=wd, eta=lr, **kwargs) + else: + mean, var = s_i[0] + mp_adabelief_update(w_i, g_i, mean, var, s_i[1], out=w_i, + lr=1, wd=wd, eta=lr, **kwargs) diff --git a/src/operator/contrib/adabelief-inl.h b/src/operator/contrib/adabelief-inl.h new file mode 100644 index 000000000000..2f282158e4ce --- /dev/null +++ b/src/operator/contrib/adabelief-inl.h @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2021 by Contributors + * \file adabelief-inl.h + * \brief Optimizer operators + * \author khaotik + */ +#ifndef MXNET_OPERATOR_CONTRIB_ADABELIEF_INL_H_ +#define MXNET_OPERATOR_CONTRIB_ADABELIEF_INL_H_ +#include +#include +#include "../mshadow_op.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { +namespace adabelief { + +struct AdaBeliefParam : public dmlc::Parameter { + float lr; + float beta1; + float beta2; + float epsilon; + float wd; + float eta; + float clip_gradient; + DMLC_DECLARE_PARAMETER(AdaBeliefParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(beta1) + .set_default(0.9f) + .describe("The decay rate for the 1st moment estimates."); + DMLC_DECLARE_FIELD(beta2) + .set_default(0.999f) + .describe("The decay rate for the 2nd moment estimates."); + DMLC_DECLARE_FIELD(epsilon) + .set_default(1e-8f) + .describe("A small constant for numerical stability."); + DMLC_DECLARE_FIELD(wd) + .set_default(0.0f) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(eta) + .describe("Learning rate schedule multiplier"); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + } +}; + +// rescale_grad is a reserved argument at position -1. Example: +// n_in = 2: weight, grad (fp16) +// n_out = 1: weight (fp16) +// total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32) +template +inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; + SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mxnet::TShape()); + // TODO(@reminisce): change "none" behavior in ElemwiseAttr + return ElemwiseAttr( + attrs, in_attrs, out_attrs, mxnet::TShape()); +} + +template +inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; + for (int i = n_in; i < total_in; ++i) { + TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32); + } + return ElemwiseAttr( + attrs, in_attrs, out_attrs, -1); +} + +template +struct MPAdaBeliefKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data, + float* var_data, const DType* weight_data, const DType* grad_data, float* weight32, + const float param_clip_gradient, const float param_beta1, const float param_beta2, + const float param_eta, const float param_lr, const float param_wd, + const float param_rescale_grad, const float param_epsilon) { + float w = weight32[i]; + float scaled_grad = param_rescale_grad*static_cast(grad_data[i]); + scaled_grad += param_wd * w; + if (param_clip_gradient >= 0.f) + scaled_grad = mshadow_op::clip::Map(scaled_grad, param_clip_gradient); + + const float mean = param_beta1 * (mean_data[i] - scaled_grad) + scaled_grad; + const float adj = mshadow_op::square::Map(scaled_grad - mean); + const float var = param_beta2*(var_data[i] - adj) + adj + param_epsilon; + + w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)); + mean_data[i] = mean; + var_data[i] = var; + weight32[i] = w; + KERNEL_ASSIGN(out_data[i], req, w); + } +}; + +template +struct MPAdaBeliefUpdate { + static inline void Forward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs, + const float rescale_grad) { + using namespace mxnet_op; + const auto& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mean = inputs[2].FlatTo2D(s); + Tensor var = inputs[3].FlatTo2D(s); + Tensor weight32 = inputs[4].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, xpu>::Launch( + s, weight.shape_.Size(), out.dptr_, mean.dptr_, var.dptr_, + weight.dptr_, grad.dptr_, weight32.dptr_, + param.clip_gradient, param.beta1, param.beta2, param.eta, + param.lr, param.wd, rescale_grad, param.epsilon); + }); + }); + } +}; + +/* + * \brief adabelief update. + * + */ +template +struct AdaBeliefUpdate { + static inline void Forward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs, + const float rescale_grad) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + const auto ¶m = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const Tensor &weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mean = inputs[2].FlatTo2D(s); + Tensor var = inputs[3].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + + grad = scalar(rescale_grad) * grad + scalar(param.wd) * weight; + if (param.clip_gradient >= 0.0f) + grad = F(grad, DType(param.clip_gradient)); + + mean = scalar(param.beta1) * mean + scalar(1.f-param.beta1) * grad; + var = scalar(param.beta2) * var + + scalar(1.f-param.beta2) * F(grad - mean) + + scalar(param.epsilon); + + Assign(out, req[0], + weight - + scalar(param.eta) * (scalar(param.lr) * + mean / (F(var) + scalar(param.epsilon)))); + }); + } +}; + +//// +// Multiple gradients in single kernel +//// +struct MultiAdaBeliefParam : public dmlc::Parameter { + mxnet::Tuple lrs; + mxnet::Tuple wds; + mxnet::Tuple etas; + float beta1; + float beta2; + float epsilon; + float clip_gradient; + int num_weights; + DMLC_DECLARE_PARAMETER(MultiAdaBeliefParam) { + DMLC_DECLARE_FIELD(lrs) + .describe("Learning rates"); + DMLC_DECLARE_FIELD(beta1) + .set_default(0.9f) + .describe("The decay rate for the 1st moment estimates."); + DMLC_DECLARE_FIELD(beta2) + .set_default(0.999f) + .describe("The decay rate for the 2nd moment estimates."); + DMLC_DECLARE_FIELD(epsilon) + .set_default(1e-8f) + .describe("A small constant for numerical stability."); + DMLC_DECLARE_FIELD(wds) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(etas) + .describe("Learning rates schedule multiplier"); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(num_weights) + .set_default(1) + .describe("Number of updated weights."); + } +}; + + +template +inline bool MP_MultiAdaBelief_InferShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights +1); + CHECK_EQ(out_attrs->size(), param.num_weights); + + bool all_inferred = true; + auto& input_shapes = *in_attrs; + auto& output_shapes = *out_attrs; + + // Learning rates + CHECK_EQ(param.lrs.ndim(), param.num_weights) + << "Number of learning rates is inconsistent with num_weights " + << "parameter passed. Expected number of learning rates: " + << param.num_weights << ", and got " << param.lrs.ndim(); + // Weight decays + CHECK_EQ(param.wds.ndim(), param.num_weights) + << "Number of weight decays is inconsistent with num_weights " + << "parameter passed. Expected number of weight decays: " + << param.num_weights << ", and got " << param.wds.ndim(); + // Learning rates schedule multiplier + CHECK_EQ(param.etas.ndim(), param.num_weights) + << "Number of learning rates schedule multiplier is inconsistent with num_weights " + << "parameter passed. Expected number of learning rates schedule multiplier: " + << param.num_weights << ", and got " << param.lrs.ndim(); + + // Weights, gradients, mean and variance + for (int i = 0; i < param.num_weights; ++i) { + mxnet::ShapeVector input_vec; + mxnet::ShapeVector output_vec({output_shapes[i]}); + for (int j = 0; j < input_stride; ++j) { + input_vec.push_back(input_shapes[i * input_stride + j]); + } + all_inferred = all_inferred && ElemwiseShape(attrs, &input_vec, &output_vec); + } + + SHAPE_ASSIGN_CHECK(*in_attrs, param.num_weights*input_stride, mxnet::TShape()); + return all_inferred; +} + +template +inline bool MP_MultiAdaBelief_InferType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights +1); + CHECK_EQ(out_attrs->size(), param.num_weights); + + bool all_inferred = true; + auto& input_types = *in_attrs; + auto& output_types = *out_attrs; + + // Weights, gradients, + for (int i = 0; i < param.num_weights; ++i) { + std::vector input_vec; + std::vector output_vec({output_types[i]}); + for (int j = 0; j < input_stride - 2 - num_fp32_inputs; ++j) { + input_vec.push_back(input_types[i * input_stride + j]); + } + all_inferred = all_inferred && + ElemwiseType(attrs, &input_vec, &output_vec); + } + // mean, var + for (int i = 0; i < param.num_weights; ++i) { + TYPE_ASSIGN_CHECK(input_types, input_stride * i +2, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, input_stride * i +3, mshadow::kFloat32); + } + + // master copies of weights + for (int i = 0; i < param.num_weights; ++i) { + for (int j = 0; j < num_fp32_inputs; ++j) { + TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32); + } + } + + TYPE_ASSIGN_CHECK(input_types, param.num_weights*input_stride, mshadow::kFloat32); + return all_inferred; +} + + +template +class _type_identity { + public: + using type = T; +}; + + +template +class _single_precision { + public: + using type = float; +}; + +template +struct MultiKernelParam { + static const int N = 50; + int count; + size_t max_size; + size_t sizes[N]; + DType* weights[N]; + DType* grad_data[N]; + MPDType* mean_data[N]; + MPDType* var_data[N]; + MPDType* weights32[N]; + DType* out_data[N]; + MPDType clip_gradient; + MPDType beta1; + MPDType beta2; + MPDType etas[N]; + MPDType lrs[N]; + MPDType wds[N]; + MPDType epsilon; +}; + +template +struct MultiMPAdaBeliefKernel { + template + MSHADOW_XINLINE static void Map(int i, const MultiKernelParam& param, + const OpReqType req, const float rescale_grad) { + for (int index = 0; index < param.count; ++index) { + if ((size_t)i < param.sizes[index]) { + MPDType w = has_mixed_precision ? param.weights32[index][i]: + MPDType(param.weights[index][i]); + MPDType scaled_grad = static_cast(rescale_grad)* + static_cast(param.grad_data[index][i]); + + scaled_grad += param.wds[index] * w; + if (param.clip_gradient >= 0.f) + scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient); + + const auto mean = param.beta1 * (param.mean_data[index][i] - scaled_grad) + scaled_grad; + const auto adj = mshadow_op::square::Map(mean - scaled_grad); + const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj + param.epsilon; + + param.mean_data[index][i] = mean; + param.var_data[index][i] = var; + w = w - param.etas[index] * (param.lrs[index] * + mean / (mshadow_op::square_root::Map(var) + param.epsilon)); + if (has_mixed_precision) + param.weights32[index][i] = w; + + KERNEL_ASSIGN(param.out_data[index][i], req, w); + } + } + } +}; + +template +void FillMultiKernelParam(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs, + MultiKernelParam *pParam) { + const ParamType& p = nnvm::get(attrs.parsed); + mxnet_op::Stream* s = ctx.get_stream(); + pParam->clip_gradient = p.clip_gradient; + pParam->beta1 = p.beta1; + pParam->beta2 = p.beta2; + + pParam->epsilon = p.epsilon; + + pParam->count = p.num_weights; + pParam->max_size = 0; + constexpr bool isSame = std::is_same::value; + for (int i = 0; i < pParam->count; ++i) { + const auto idx = i * input_stride; + pParam->sizes[i] = inputs[idx].shape_.Size(); + if (pParam->max_size < pParam->sizes[i]) + pParam->max_size = pParam->sizes[i]; + + pParam->weights[i] = inputs[idx].FlatTo2D(s).dptr_; + pParam->grad_data[i] = inputs[idx + 1].FlatTo2D(s).dptr_; + pParam->mean_data[i] = inputs[idx + 2].FlatTo2D(s).dptr_; + pParam->var_data[i] = inputs[idx + 3].FlatTo2D(s).dptr_; + // if mixed precision, then the last input in a set + // is 32-bit master copy of the weights + if (!isSame) + pParam->weights32[i] = inputs[idx + input_stride - 1].FlatTo2D(s).dptr_; + + pParam->out_data[i] = outputs[i].FlatTo2D(s).dptr_; + } + memcpy(pParam->etas, p.etas.begin(), pParam->count * sizeof(p.etas[0])); + memcpy(pParam->lrs, p.lrs.begin(), pParam->count * sizeof(p.lrs[0])); + memcpy(pParam->wds, p.wds.begin(), pParam->count * sizeof(p.wds[0])); +} + +template class MPTypeChooser, int input_stride> +static inline void MultiAdaBeliefUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs, + const float rescale_grad) { + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + using MPDType = typename MPTypeChooser::type; + MultiKernelParam param; + FillMultiKernelParam + (attrs, ctx, inputs, outputs, ¶m); + + Kernel::value>, xpu>:: + Launch(s, param.max_size, param, req[0], rescale_grad); + }); +} + +template +void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef); + +template +bool PrepareInputBlobs(const OpContext &ctx, + const std::vector &inputs, + std::vector *inputs_wo_scale, + float *pScalef) { + const size_t num_in = inputs.size() - 1; + adabelief::GetScaleFloat(ctx.get_stream(), inputs[num_in], pScalef); + if (!std::isfinite(*pScalef) || *pScalef == 0) + return false; + + inputs_wo_scale->reserve(num_in); + for (size_t i = 0; i < num_in; i++) + inputs_wo_scale->emplace_back(inputs[i]); + + return true; +} + +template +inline void MPUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + std::vector inputs_wo_scale; + float scalef; + if (!PrepareInputBlobs(ctx, inputs, &inputs_wo_scale, &scalef)) + return; + + F::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef); +} + +template +inline void multiMPUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + std::vector inputs_wo_scale; + float scalef; + if (!PrepareInputBlobs(ctx, inputs, &inputs_wo_scale, &scalef)) + return; + + if (!MP) + MultiAdaBeliefUpdate + (attrs, ctx, inputs_wo_scale, req, outputs, scalef); + else + MultiAdaBeliefUpdate + (attrs, ctx, inputs_wo_scale, req, outputs, scalef); +} + +} // namespace adabelief +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_ADABELIEF_INL_H_ diff --git a/src/operator/contrib/adabelief.cc b/src/operator/contrib/adabelief.cc new file mode 100644 index 000000000000..06be7480f8a7 --- /dev/null +++ b/src/operator/contrib/adabelief.cc @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2021 by Contributors + * \file adabelief.cc + * \brief Optimizer operators + * \author khaotik + */ +#include "./adabelief-inl.h" + +namespace mxnet { +namespace op { +namespace adabelief { + +DMLC_REGISTER_PARAMETER(AdaBeliefParam); +DMLC_REGISTER_PARAMETER(MultiAdaBeliefParam); + +NNVM_REGISTER_OP(_mp_adabelief_update) +.describe(R"code(Update function for multi-precision AdaBelief optimizer. + +AdaBelief is seen as a modification of Adam with a different variance +estimator. + +Adam update consists of the following steps, where g represents gradient and m, s +are 1st and 2nd order moment estimates (mean and variance). + +.. math:: + + g_t = \nabla J(W_{t-1}) + w * wd \\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + s_t = \beta_2 v_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\ + W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon }) + +It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + s = beta2*v + (1-beta2)*(grad**2) + w -= eta * (learning_rate * m / (sqrt(s) + epsilon)) + +Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, +the update is skipped. +)code" ADD_FILELINE) +.set_num_inputs(6) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MPUpdateInferShape<2, 1, 6>) +.set_attr("FInferType", MPUpdateInferType<2, 1, 6>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3, 4}; + }) +.set_attr("FCompute", MPUpdate>) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mean", "NDArray-or-Symbol", "Moving mean") +.add_argument("var", "NDArray-or-Symbol", "Moving variance") +.add_argument("weight32", "NDArray-or-Symbol", "Weight32") +.add_argument("rescale_grad", "NDArray-or-Symbol", + "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, " + "the update is skipped.") +.add_arguments(AdaBeliefParam::__FIELDS__()); + +NNVM_REGISTER_OP(_adabelief_update) +.describe(R"code(Update function for AdaBelief optimizer. + +AdaBelief is seen as a modification of Adam with a different variance +estimator. + +Adam update consists of the following steps, where g represents gradient and m, s +are 1st and 2nd order moment estimates (mean and variance). + +.. math:: + + g_t = \nabla J(W_{t-1}) + w * wd \\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + s_t = \beta_2 v_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\ + W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon }) + +It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + s = beta2*v + (1-beta2)*(grad**2) + w -= eta * (learning_rate * m / (sqrt(s) + epsilon)) + +Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, +the update is skipped. +))code" ADD_FILELINE) +.set_num_inputs(5) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MPUpdateInferShape<4, 1, 5>) +.set_attr("FInferType", MPUpdateInferType<4, 1, 5>) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3}; + }) +.set_attr("FCompute", MPUpdate>) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("grad", "NDArray-or-Symbol", "Gradient") +.add_argument("mean", "NDArray-or-Symbol", "Moving mean") +.add_argument("var", "NDArray-or-Symbol", "Moving variance") +.add_argument("rescale_grad", "NDArray-or-Symbol", + "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, " + "the update is skipped.") +.add_arguments(AdaBeliefParam::__FIELDS__()); + +template<> +void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef) { + MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, + *pScalef = static_cast(*scale_blob.dptr()); + ) +} + +static std::vector +ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) { + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + const auto idx = std::to_string(i); + for (size_t j = 0; j < nParams; ++j) + ret.push_back(std::string(pName[i]) + idx); + } + + return ret; +} + +inline uint32_t num_weights(const nnvm::NodeAttrs& attrs) { + return static_cast(dmlc::get(attrs.parsed).num_weights); +} + +NNVM_REGISTER_OP(_multi_adabelief_update) +.describe(R"code(Update function for AdaBelief optimizer. + +AdaBelief is seen as a modification of Adam with a different variance +estimator. + +Adam update consists of the following steps, where g represents gradient and m, s +are 1st and 2nd order moment estimates (mean and variance). + +.. math:: + + g_t = \nabla J(W_{t-1}) + w * wd \\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + s_t = \beta_2 v_{t-1} + (1 - \beta_2) (g_t - m_t)^2\\ + W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon }) + +It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + s = beta2*v + (1-beta2)*(grad**2) + w -= eta * (learning_rate * m / (sqrt(s) + epsilon)) + +Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, +the update is skipped. +))code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + return num_weights(attrs) * 4 + 1; + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + return num_weights(attrs); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MP_MultiAdaBelief_InferShape) +.set_attr("FInferType", ElemwiseType<-1, -1>) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "rescale_grad_"}; + return ParamToVector(num_weights(attrs), paramName, sizeof(paramName)/sizeof(paramName[0])); + }) +// mutable: mean, var +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const auto iMax = num_weights(attrs); + for (size_t i = 0; i < iMax; ++i) { + ret.push_back(i * 4 + 2); + ret.push_back(i * 4 + 3); + } + return ret; + }) + +.set_attr("FCompute", multiMPUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "data") +.add_arguments(MultiAdaBeliefParam::__FIELDS__()); + + +NNVM_REGISTER_OP(_multi_mp_adabelief_update) +.describe(R"code(Update function for multi-precision AdaBelief optimizer. + +AdaBelief is seen as a modification of Adam with a different variance +estimator. + +Adam update consists of the following steps, where g represents gradient and m, s +are 1st and 2nd order moment estimates (mean and variance). + +.. math:: + + g_t = \nabla J(W_{t-1}) + w * wd \\ + m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ + s_t = \beta_2 v_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon\\ + W_t = W_{t-1} - \eta_t (\alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon }) + +It updates the weights using:: + + m = beta1*m + (1-beta1)*grad + s = beta2*v + (1-beta2)*(grad**2) + w -= eta * (learning_rate * m / (sqrt(s) + epsilon)) + +Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, +the update is skipped. +))code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + return num_weights(attrs) * 5 + 1; + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + return num_weights(attrs); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MP_MultiAdaBelief_InferShape) +.set_attr("FInferType", MP_MultiAdaBelief_InferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "weight32_", "rescale_grad_"}; + return ParamToVector(num_weights(attrs), paramName, sizeof(paramName)/sizeof(paramName[0])); + }) +// mutable: mean, var, weights32 +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const auto iMax = num_weights(attrs); + for (size_t i = 0; i < iMax; ++i) { + ret.push_back(i * 5 + 2); + ret.push_back(i * 5 + 3); + ret.push_back(i * 5 + 4); + } + return ret; + }) + +.set_attr("FCompute", multiMPUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "data") +.add_arguments(MultiAdaBeliefParam::__FIELDS__()); + +} // namespace adabelief +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/adabelief.cu b/src/operator/contrib/adabelief.cu new file mode 100644 index 000000000000..e64dcb4ca006 --- /dev/null +++ b/src/operator/contrib/adabelief.cu @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2021 by Contributors + * \file adabelief.cu + * \brief Optimizer operators + * \author khaotik + */ +#include "./adabelief-inl.h" + +namespace mxnet { +namespace op { +namespace adabelief { +template<> +void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef) { + MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, { + DType scale = 0; + cudaStream_t stream = mshadow::Stream::GetStream(s); + CUDA_CALL(cudaMemcpyAsync(&scale, scale_blob.dptr(), sizeof(DType), + cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaStreamSynchronize(stream)); + *pScalef = static_cast(scale); + }) +} +} // namespace adabelief + +NNVM_REGISTER_OP(_adabelief_update) +.set_attr("FCompute", adabelief::MPUpdate>); + +NNVM_REGISTER_OP(_mp_adabelief_update) +.set_attr("FCompute", adabelief::MPUpdate>); + +NNVM_REGISTER_OP(_multi_adabelief_update) +.set_attr("FCompute", adabelief::multiMPUpdate); + +NNVM_REGISTER_OP(_multi_mp_adabelief_update) +.set_attr("FCompute", adabelief::multiMPUpdate); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h index 6f483333314b..56c5ea227862 100644 --- a/src/operator/contrib/adamw-inl.h +++ b/src/operator/contrib/adamw-inl.h @@ -32,6 +32,7 @@ namespace mxnet { namespace op { +namespace adamw { struct AdamWParam : public dmlc::Parameter { float lr; @@ -114,7 +115,7 @@ struct MPAdamWKernel { float var = var_data[i] = param_beta2 * var_data[i] + (1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad); - w = w - param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon) + w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon) + param_wd * w); weight32[i] = w; KERNEL_ASSIGN(out_data[i], req, w); @@ -349,7 +350,7 @@ template struct MultiMPAdamWKernel { template MSHADOW_XINLINE static void Map(int i, const MultiAdamKernelParam& param, - const OpReqType req, const float rescale_grad){ + const OpReqType req, const float rescale_grad) { for (int index = 0; index < param.count; ++index) { if ((size_t)i < param.sizes[index]) { MPDType w = has_mixed_precision ? param.weights32[index][i]: @@ -442,7 +443,7 @@ static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs, } template -void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef); +static void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef); template bool PrepareInputBlobs(const OpContext &ctx, @@ -450,7 +451,7 @@ bool PrepareInputBlobs(const OpContext &ctx, std::vector *inputs_wo_scale, float *pScalef) { const size_t num_in = inputs.size() - 1; - GetScaleFloat(ctx.get_stream(), inputs[num_in], pScalef); + adamw::GetScaleFloat(ctx.get_stream(), inputs[num_in], pScalef); if (!std::isfinite(*pScalef) || *pScalef == 0) return false; @@ -494,6 +495,7 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs, (attrs, ctx, inputs_wo_scale, req, outputs, scalef); } +} // namespace adamw } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc index effae5c42d19..e66250200e54 100644 --- a/src/operator/contrib/adamw.cc +++ b/src/operator/contrib/adamw.cc @@ -27,6 +27,7 @@ namespace mxnet { namespace op { +namespace adamw { DMLC_REGISTER_PARAMETER(AdamWParam); DMLC_REGISTER_PARAMETER(MultiAdamWParam); @@ -65,7 +66,7 @@ the update is skipped. [](const nnvm::NodeAttrs& attrs) { return std::vector{2, 3, 4}; }) -.set_attr("FCompute", MPUpdate>) +.set_attr("FCompute", adamw::MPUpdate>) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_argument("mean", "NDArray-or-Symbol", "Moving mean") @@ -108,7 +109,7 @@ the update is skipped. [](const nnvm::NodeAttrs& attrs) { return std::vector{2, 3}; }) -.set_attr("FCompute", MPUpdate>) +.set_attr("FCompute", adamw::MPUpdate>) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_argument("mean", "NDArray-or-Symbol", "Moving mean") @@ -125,7 +126,8 @@ void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float ) } -std::vector ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) { +static std::vector +ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) { std::vector ret; for (uint32_t i = 0; i < num_args; ++i) { const auto idx = std::to_string(i); @@ -191,7 +193,7 @@ the update is skipped. return ret; }) -.set_attr("FCompute", multiMPUpdate) +.set_attr("FCompute", adamw::multiMPUpdate) .add_argument("data", "NDArray-or-Symbol[]", "data") .add_arguments(MultiAdamWParam::__FIELDS__()); @@ -248,10 +250,11 @@ the update is skipped. return ret; }) -.set_attr("FCompute", multiMPUpdate) +.set_attr("FCompute", adamw::multiMPUpdate) .add_argument("data", "NDArray-or-Symbol[]", "data") .add_arguments(MultiAdamWParam::__FIELDS__()); +} // namespace adamw } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu index 2b0040e5f6ac..95fcffbd78e4 100644 --- a/src/operator/contrib/adamw.cu +++ b/src/operator/contrib/adamw.cu @@ -27,6 +27,7 @@ namespace mxnet { namespace op { +namespace adamw { template<> void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef) { @@ -41,16 +42,17 @@ void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float } NNVM_REGISTER_OP(_adamw_update) -.set_attr("FCompute", MPUpdate>); +.set_attr("FCompute", adamw::MPUpdate>); NNVM_REGISTER_OP(_mp_adamw_update) -.set_attr("FCompute", MPUpdate>); +.set_attr("FCompute", adamw::MPUpdate>); NNVM_REGISTER_OP(_multi_adamw_update) -.set_attr("FCompute", multiMPUpdate); +.set_attr("FCompute", adamw::multiMPUpdate); NNVM_REGISTER_OP(_multi_mp_adamw_update) -.set_attr("FCompute", multiMPUpdate); +.set_attr("FCompute", adamw::multiMPUpdate); +} // namespace adamw } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_contrib_optimizer.py b/tests/python/unittest/test_contrib_optimizer.py index f0fbb7b7aaec..b6f624f2a2d4 100644 --- a/tests/python/unittest/test_contrib_optimizer.py +++ b/tests/python/unittest/test_contrib_optimizer.py @@ -61,27 +61,27 @@ def test_group_adagrad(): dtype, g_stype='row_sparse') - -@xfail_when_nonstandard_decimal_separator -@pytest.mark.serial -def test_adamw(): - def get_refs(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1): - if clip_grad >= 0: - grad_rescale = mx.nd.clip(grad_rescale, -clip_grad, clip_grad) - - mean_ref = beta1*m + (1-beta1)*grad_rescale - v_ref = beta2*v + (1-beta2)*(grad_rescale**2) - weight_ref = weight - eta * (lr * mean_ref / (v_ref.sqrt() + epsilon) + weight * wd) - return mean_ref, v_ref, weight_ref - - def run_adamw_test(nElem=1, aggregate=False): - aggregate = aggregate or nElem > 1 +def _fn_noimpl(*args, **kwargs): + raise NotImplementedError() + +class _AdamLikeTestHelper: + fn_update = _fn_noimpl + fn_multi_update = _fn_noimpl + fn_mp_update = _fn_noimpl + fn_multi_mp_update = _fn_noimpl + @staticmethod + def ref_impl(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1): + '''Returns (mean_ref, v_ref, weight_ref)''' + raise NotImplementedError() + @classmethod + def run_test(cls, num_elem=1, aggregate=False): + aggregate = aggregate or num_elem > 1 rescale_factor = 10 eta, lr, wd, epsilon = 1, 1, 0.1, 1e-8 beta1, beta2 = 0.9, 0.999 clip_gradient = np.random.uniform(rescale_factor, rescale_factor) weight, grad, m, v, etas, lrs, wds, weight_ref = [], [], [], [], [], [], [], [] - for i in range(nElem): + for i in range(num_elem): shape = (np.random.randint(3, high=10), np.random.randint(3, high=10)) weight.append(mx.nd.random.uniform(shape=shape)) grad.append(mx.nd.random.uniform(-1.0, 1.0, shape=shape)) @@ -107,95 +107,130 @@ def run_adamw_test(nElem=1, aggregate=False): for rescaled_grad in tested_rescaled_grad: if aggregate: - mx.nd.contrib.multi_adamw_update(weight, grad, m, v, - rescaled_grad, out=weight, **kwargs) + cls.fn_multi_update(weight, grad, m, v, + rescaled_grad, out=weight, **kwargs) else: - mx.nd.contrib.adamw_update(weight[0], grad[0], m[0], v[0], - rescaled_grad, out=weight[0], **kwargs) - + cls.fn_update(weight[0], grad[0], m[0], v[0], + rescaled_grad, out=weight[0], **kwargs) # weights should remain unchanged - for j in range(nElem): + for j in range(num_elem): assert_almost_equal(weight_ref[j], weight[j]) - # Test 2: Same as Test 1 for multi-precision update weight_fp16, grad_fp16, weight_fp16_refs = [], [], [] - for i in range(nElem): + for i in range(num_elem): weight_fp16.append(weight[i].astype('float16')) grad_fp16.append(grad[i].astype('float16')) weight_fp16_refs.append(weight_fp16[i].copy()) for rescaled_grad in tested_grad: if aggregate: - mx.nd.contrib.multi_mp_adamw_update(weight_fp16, grad_fp16, m, v, weight, - rescaled_grad, out=weight_fp16, **kwargs) + cls.fn_multi_mp_update(weight_fp16, grad_fp16, m, v, weight, + rescaled_grad, out=weight_fp16, **kwargs) else: - mx.nd.contrib.mp_adamw_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0], - rescaled_grad, out=weight_fp16[0], **kwargs) - + cls.fn_mp_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0], + rescaled_grad, out=weight_fp16[0], **kwargs) # weights should remain unchanged - for i in range(nElem): + for i in range(num_elem): assert_almost_equal(weight_ref[i], weight[i]) assert_almost_equal(weight_fp16_refs[i], weight_fp16[i]) - # Test 3: Reference normal update grad_rescale, weight_test, m_refs, v_refs, weight_refs = [], [], [], [], [] - for i in range(nElem): + for i in range(num_elem): grad_rescale.append(rescale_grad * grad[i]) - m_ref, v_ref, weight_ref = get_refs(m[i], v[i], weight[i], grad_rescale[i], beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient) + m_ref, v_ref, weight_ref = cls.ref_impl( + m[i], v[i], weight[i], grad_rescale[i], + beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient) m_refs.append(m_ref) v_refs.append(v_ref) weight_refs.append(weight_ref) weight_test.append(weight[i].copy()) - # op normal update if aggregate: - mx.nd.contrib.multi_adamw_update(weight_test, grad, m, v, - rescale_grad, out=weight_test, **kwargs) + cls.fn_multi_update(weight_test, grad, m, v, + rescale_grad, out=weight_test, **kwargs) else: - mx.nd.contrib.adamw_update(weight_test[0], grad[0], m[0], v[0], - rescale_grad, out=weight_test[0], **kwargs) - + cls.fn_update(weight_test[0], grad[0], m[0], v[0], + rescale_grad, out=weight_test[0], **kwargs) # Compare results atol = 1e-4 if aggregate else 1e-5 rtol = 1e-4 if aggregate else None - for i in range(nElem): + for i in range(num_elem): assert_almost_equal(weight_refs[i], weight_test[i], rtol=rtol, atol=atol) assert_almost_equal(m_refs[i], m[i], rtol=rtol, atol=atol) assert_almost_equal(v_refs[i], v[i], atol=atol) - # Test 4: Reference normal multi-precision update grad_rescale, m_refs, v_refs, weight_refs, weight_fp16_refs = [], [], [], [], [] - for i in range(nElem): + for i in range(num_elem): grad_rescale.append(rescale_grad * grad_fp16[i].astype('float32')) - m_ref, v_ref, weight_ref = get_refs(m[i], v[i], weight[i], grad_rescale[i], beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient) + m_ref, v_ref, weight_ref = cls.ref_impl( + m[i], v[i], weight[i], grad_rescale[i], + beta1, beta2, lrs[i], etas[i], wds[i], epsilon, clip_gradient) m_refs.append(m_ref) v_refs.append(v_ref) weight_refs.append(weight_ref) weight_fp16_refs.append(weight_ref.astype('float16')) - # op normal multi-precision update if aggregate: - mx.nd.contrib.multi_mp_adamw_update(weight_fp16, grad_fp16, m, v, weight, - rescale_grad, out=weight_fp16, **kwargs) + cls.fn_multi_mp_update(weight_fp16, grad_fp16, m, v, weight, + rescale_grad, out=weight_fp16, **kwargs) else: - mx.nd.contrib.mp_adamw_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0], - rescale_grad, out=weight_fp16[0], **kwargs) - + cls.fn_mp_update(weight_fp16[0], grad_fp16[0], m[0], v[0], weight[0], + rescale_grad, out=weight_fp16[0], **kwargs) # Compare results - for i in range(nElem): + for i in range(num_elem): assert_almost_equal(m_refs[i], m[i], rtol=rtol, atol=atol) assert_almost_equal(v_refs[i], v[i], atol=atol) assert_almost_equal(weight_refs[i], weight[i], rtol=rtol, atol=atol) assert_almost_equal(weight_fp16_refs[i], weight_fp16[i], rtol=1e-3, atol=atol) - # Testing aggregated Adam update for one element - run_adamw_test(1, aggregate=True) + def __call__(self): + # Testing aggregated Adam update for one element + self.run_test(1, aggregate=True) + # Testing Adam update, if num_elem == 0, OR + # aggregated Adam update, if num_elem > 0 + for num_elem in reversed(range(6)): + self.run_test(num_elem+1) + +class _AdamWTestHelper(_AdamLikeTestHelper): + fn_update = mx.nd.contrib.adamw_update + fn_multi_update = mx.nd.contrib.multi_adamw_update + fn_mp_update = mx.nd.contrib.mp_adamw_update + fn_multi_mp_update = mx.nd.contrib.multi_mp_adamw_update + @staticmethod + def ref_impl(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1): + if clip_grad >= 0: + grad_rescale = mx.nd.clip(grad_rescale, -clip_grad, clip_grad) - # Testing Adam update, if nElem = 0, OR - # aggregated Adam update, if nElem > 0 - for nElem in range(6): - run_adamw_test(nElem+1) + mean_ref = beta1*m + (1.-beta1)*grad_rescale + v_ref = beta2*v + (1.-beta2)*(grad_rescale**2) + weight_ref = weight - eta * (lr * mean_ref / (v_ref.sqrt() + epsilon) + weight * wd) + return mean_ref, v_ref, weight_ref + +class _AdaBeliefTestHelper(_AdamLikeTestHelper): + fn_update = mx.nd.contrib.adabelief_update + fn_multi_update = mx.nd.contrib.multi_adabelief_update + fn_mp_update = mx.nd.contrib.mp_adabelief_update + fn_multi_mp_update = mx.nd.contrib.multi_mp_adabelief_update + @staticmethod + def ref_impl(m, v, weight, grad_rescale, beta1, beta2, lr, eta, wd, epsilon, clip_grad=-1): + grad_rescale += wd * weight + if clip_grad >= 0: + grad_rescale = mx.nd.clip(grad_rescale, -clip_grad, clip_grad) + mean_ref = beta1*m + (1.-beta1)*grad_rescale + v_ref = beta2*v + (1.-beta2)*((grad_rescale-mean_ref)**2) + epsilon + weight_ref = weight - eta * (lr * mean_ref / (v_ref.sqrt() + epsilon)) + return mean_ref, v_ref, weight_ref + +@xfail_when_nonstandard_decimal_separator +@pytest.mark.serial +def test_adamw(): + _AdamWTestHelper()() + +@xfail_when_nonstandard_decimal_separator +@pytest.mark.serial +def test_adabelief(): + _AdaBeliefTestHelper()() diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 8927bcd1912e..7ccb8f18242c 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -923,6 +923,31 @@ def test_adamW(): opt2(use_fused_step=True, **kwarg), shapes, dtype, rtol=1e-3, atol=2e-3) +def test_adabelief(): + opt1 = mx.optimizer.AdaBelief + opt2 = mx.optimizer.AdaBelief + shapes = [(3, 4, 5), (10, 4), (7,)] + beta1_options = [{}, {'beta1': 0.5}, {'beta1': 0.7}] + beta2_options = [{}, {'beta2': 0.8}, {'beta2': 0.9}] + cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}] + rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] + wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] + mp_options = [{'multi_precision': False}, {'multi_precision': True}] + agg_options = [{'aggregate_num': 0}, {'aggregate_num': 1}, + {'aggregate_num': 4}, {'aggregate_num': np.inf}] + correct_bias_options = [{'correct_bias': True}, {'correct_bias': False}] + for dtype in [np.float16, np.float32]: + for params in itertools.product(beta1_options, beta2_options, cg_options, + rg_options, wd_options, mp_options, + agg_options, correct_bias_options): + kwarg = {k: v for param in params for k, v in param.items()} + if (dtype == np.float16 and ('multi_precision' not in kwarg or + not kwarg['multi_precision'])): + continue + compare_optimizer(opt1(use_fused_step=False, **kwarg), + opt2(use_fused_step=True, **kwarg), shapes, dtype, + rtol=1e-3, atol=2e-3) + def test_factor_scheduler(): base_lr = 1 step = 100