From 8154d80ef138b21c6f3c51ee24873ec591889eb4 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Thu, 27 Dec 2018 22:04:36 -0800 Subject: [PATCH] [FIX] Update BERTLayerNorm Implementation (#485) * fix layer norm * fix indentation * fix lint --- src/gluonnlp/model/bert.py | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/src/gluonnlp/model/bert.py b/src/gluonnlp/model/bert.py index d722a03a5a..08d4301a04 100644 --- a/src/gluonnlp/model/bert.py +++ b/src/gluonnlp/model/bert.py @@ -22,7 +22,7 @@ 'BERTLayerNorm', 'bert_12_768_12', 'bert_24_1024_16'] import os -from mxnet.gluon import Block, HybridBlock +from mxnet.gluon import Block from mxnet.gluon import nn from mxnet.gluon.model_zoo import model_store import mxnet as mx @@ -34,35 +34,18 @@ # COMPONENTS # ############################################################################### -class BERTLayerNorm(HybridBlock): - """BERT style Layer Normalization. - - Epsilon is added inside the square root. +class BERTLayerNorm(nn.LayerNorm): + """BERT style Layer Normalization, where epsilon is added inside the square + root and set to 1e-12 by default. Inputs: - **data**: input tensor with arbitrary shape. - Outputs: - **out**: output tensor with the same shape as `data`. """ def __init__(self, epsilon=1e-12, in_channels=0, prefix=None, params=None): - super(BERTLayerNorm, self).__init__(prefix=prefix, params=params) - self.gamma = self.params.get('gamma', shape=(in_channels,), - allow_deferred_init=True) - self.beta = self.params.get('beta', shape=(in_channels,), - allow_deferred_init=True) - self._eps = epsilon - - def hybrid_forward(self, F, x, gamma, beta): # pylint: disable=arguments-differ - u = F.mean(x, -1, keepdims=True) - s = F.mean(F.broadcast_sub(x, u) ** 2, -1, keepdims=True) + self._eps - x = F.broadcast_div(F.broadcast_sub(x, u), s.sqrt()) - return F.broadcast_add(F.broadcast_mul(gamma, x), beta) - - def __repr__(self): - s = '{name}(' - in_channels = self.gamma.shape[0] - s += 'in_channels={0}, epsilon={1})'.format(in_channels, self._eps) - return s.format(name=self.__class__.__name__) + super(BERTLayerNorm, self).__init__(epsilon=epsilon, in_channels=in_channels, + prefix=prefix, params=params) + class BERTPositionwiseFFN(BasePositionwiseFFN): """Structure of the Positionwise Feed-Forward Neural Network for