From f39148ccedfe2836f1c381dfd3f9663deb287c61 Mon Sep 17 00:00:00 2001 From: Anirudh Date: Thu, 14 Feb 2019 17:07:43 -0800 Subject: [PATCH] In-place updates for Nadam, Adadelta, Adamax and SGLD (#13960) --- python/mxnet/optimizer/optimizer.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index a986f271c4b4..def2c958ede4 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -1091,8 +1091,9 @@ def update(self, index, weight, grad, state): grad = grad * self.rescale_grad if self.clip_gradient is not None: grad = clip(grad, -self.clip_gradient, self.clip_gradient) - weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr), shape=weight.shape, - dtype=weight.dtype, ctx=weight.context) + weight[:] += - lr/2 * (grad + wd * weight) + weight[:] += normal(0, math.sqrt(lr), shape=weight.shape, + dtype=weight.dtype, ctx=weight.context) @@ -1372,9 +1373,11 @@ def update(self, index, weight, grad, state): acc_g, acc_delta = state # update g, delta - acc_g[:] = self.rho * acc_g + (1. - self.rho) * grad * grad + acc_g[:] *= self.rho + acc_g[:] += (1. - self.rho) * grad * grad current_delta = sqrt(acc_delta + self.epsilon) / sqrt(acc_g + self.epsilon) * grad - acc_delta[:] = self.rho * acc_delta + (1. - self.rho) * current_delta * current_delta + acc_delta[:] *= self.rho + acc_delta[:] += (1. - self.rho) * current_delta * current_delta # update weight weight[:] -= current_delta + wd * weight @@ -1507,7 +1510,8 @@ def update(self, index, weight, grad, state): # update m_t and u_t m_t, u_t = state - m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad + m_t[:] *= self.beta1 + m_t[:] += (1. - self.beta1) * grad u_t[:] = maximum(self.beta2 * u_t, NDabs(grad)) # update weight @@ -1570,8 +1574,10 @@ def update(self, index, weight, grad, state): # update m_t and v_t m_t, v_t = state - m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad - v_t[:] = self.beta2 * v_t + (1. - self.beta2) * grad * grad + m_t[:] *= self.beta1 + m_t[:] += (1. - self.beta1) * grad + v_t[:] *= self.beta2 + v_t[:] += (1. - self.beta2) * grad * grad grad_prime = grad / (1. - self.m_schedule) m_t_prime = m_t / (1. - m_schedule_next)