nn_components.py
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core NN components used in models.
"""
from typing import Any, Callable, Optional, Tuple, Union
from absl import logging
from flax import linen as nn
import gin
import jax
from jax import lax
from jax.nn import initializers
import jax.numpy as jnp
PRNGKey = Any
Array = jnp.ndarray
Shape = Tuple[int, ...]
Dtype = Union[jnp.dtype, str]
def scalar_initializer(x):
"""Like linen.zeros, but initializes a parameter to a scalar value."""
def init_fun(key, shape, dtype):
del key
return jnp.broadcast_to(jnp.array(x, dtype=dtype), shape)
return init_fun
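# Illustrative usage sketch (not part of the original module): inside a Flax
# module, scalar_initializer(v) can be passed anywhere an initializer is
# expected, e.g. to self.param.  The parameter name and shape below are
# hypothetical.
#
#   init_gate = self.param("init_gate", scalar_initializer(1.0),
#                          (num_features,), jnp.float32)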
def swish(x: Array) -> Array:
"""Swish function, which is very similar to gelu."""
return x * nn.sigmoid(x)
def soft_abs(x: Array) -> Array:
"""Soft version of absolute value, that is smoothly differentiable."""
return jnp.sqrt(jnp.square(x) + 1) - 1
def get_activation_function(fname: Optional[str]) -> Callable[[Array], Array]:
"""Get activation function from the specified string."""
if fname is None:
return lambda x: x
elif fname == "relu":
return nn.relu
elif fname == "swish":
return swish
elif fname == "sigmoid":
return nn.sigmoid
elif fname == "tanh":
return nn.tanh
else:
raise ValueError("Unknown activation function %s" % fname)
# Adapted from flax.linen.softmax.
def safe_softmax(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
min_x: Optional[Array] = None) -> Array:
r"""Softmax function.
Computes the function which rescales elements to the range :math:`[0, 1]`
such that the elements along :code:`axis` sum to :math:`1`.
This version of softmax is intended for use with causal attention masks, and
safely covers the situation where all elements are masked out. If min_x is
  not None, then probability will be distributed between the values in x and
  min_x. If x >> min_x, then the probability allocated to min_x will be zero,
  and this function will be the same as the usual softmax. However, if
  x << min_x (because all the values in x are masked out), then probability
will be allocated to min_x instead, and the probability allocated to x will
be 0. I.e., attention will attend to nothing if everything is masked out.
.. math ::
    \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
Args:
x: input array
axis: the axis or axes along which the softmax should be computed. The
softmax output summed across these dimensions should sum to :math:`1`.
Either an integer or a tuple of integers.
min_x: the value of a minimum element which will be included in the
softmax sum. The value of min_x should be small when compared to the
expected values in x. If all of the values in x are smaller than
min_x, then probability will be allocated to the minimum element
instead, and the result of softmax will sum to less than 1.
Returns:
An array of the same shape as x.
"""
# Subtract maximum value in x for numerical stability, so that the exponent
# never exceeds numerical precision.
x_max = lax.stop_gradient(jnp.max(x, axis, initial=min_x, keepdims=True))
if min_x is not None:
min_x = jnp.asarray(min_x, dtype=x.dtype)
x_max = jnp.maximum(x_max, min_x)
unnormalized = jnp.exp(x - x_max)
x_sum = jnp.sum(unnormalized, axis=axis, keepdims=True)
if min_x is not None:
x_sum = x_sum + jnp.exp(min_x - x_max)
return unnormalized / x_sum
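# Illustrative sketch (not part of the original module): safe_softmax applied
# to masked attention scores.  The mask values below are assumptions; the point
# is that a fully masked row yields near-zero probabilities instead of NaNs,
# because the remaining probability mass is allocated to min_x.
#
#   scores = jnp.array([[2.0, 1.0, -1e6],     # partially masked row
#                       [-1e6, -1e6, -1e6]])  # fully masked row
#   probs = safe_softmax(scores, axis=-1, min_x=-1e5)
#   # probs[0] sums to ~1.0; probs[1] sums to ~0.0 (attends to nothing).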
def dropout_multiplier_mask(rng, dropout_rate: float, shape: Shape,
dtype: Dtype):
"""Returns an array which can be multiplied by an input to perform dropout.
Args:
rng: A random number generator.
dropout_rate: The rate at which to drop.
shape: The shape of the output array.
dtype: The type of the output array.
Returns:
    An array of the given shape, where values are {0.0, 1.0/keep_probability}.
"""
if dropout_rate <= 0.0:
return jnp.ones(shape, dtype=dtype)
logging.info("dropout mask: %s", shape)
keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(rng, keep_prob, shape)
dropout_multiplier = (keep.astype(dtype) / jnp.asarray(keep_prob, dtype))
return dropout_multiplier
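# Illustrative sketch (not part of the original module): applying a dropout
# multiplier mask to an activation tensor.  The rng key and shapes below are
# hypothetical.
#
#   rng = jax.random.PRNGKey(0)
#   y = jnp.ones((4, 8), dtype=jnp.float32)
#   mask = dropout_multiplier_mask(rng, dropout_rate=0.1, shape=y.shape,
#                                  dtype=y.dtype)
#   y = y * mask  # surviving entries are scaled by 1 / 0.9, the rest are 0.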
def tiled_dropout(x: Array, shape: Shape, dropout_rate: float,
rng_function: Callable[[], jax.random.KeyArray],
deterministic: bool) -> Array:
"""Tiles a dropout mask over a larger array.
This will generate a smaller dropout mask of the given shape, and tile it
over a larger array, which reduces the computational cost and memory
associated with generating a large dropout mask.
Args:
x: The input array.
shape: The shape of the dropout mask to tile.
dropout_rate: The rate at which to drop.
rng_function: A function which returns a random number generator, e.g.
                  lambda: self.make_rng("dropout"). The function will not
be called if dropout is not enabled.
deterministic: If True, don't do dropout.
Returns:
An array of the same shape as x, with some values dropped out.
"""
if deterministic or dropout_rate <= 0.0:
return x
if x.ndim != len(shape):
raise ValueError("Shapes must have same number of dimensions %r, %r." %
(x.shape, shape))
for (xd, sd) in zip(x.shape, shape):
if (xd % sd) != 0:
raise ValueError("Incompatible shapes %r, %r" % (x.shape, shape))
# Get random number generator for dropout.
rng = rng_function()
repeats = [(1 if sd == 1 else xd // sd) for (xd, sd) in zip(x.shape, shape)]
logging.info("tiled dropout %r, tile: %r", x.shape, shape)
dtype = x.dtype
keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(rng, keep_prob, shape)
keep = jnp.tile(keep, repeats)
keep = jnp.broadcast_to(keep, x.shape)
x_scaled = x / jnp.asarray(keep_prob, dtype=dtype)
return lax.select(keep, x_scaled, jnp.zeros_like(x, dtype=dtype))
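# Illustrative sketch (not part of the original module): one (1, 128, 128)
# dropout mask is shared across the batch and tiled along the sequence of a
# (batch, 1024, 128) activation.  The shapes, the `train` flag, and the
# "dropout" rng collection name are assumptions; this is intended to be called
# from inside a Flax module method.
#
#   y = tiled_dropout(x,                    # x.shape == (batch, 1024, 128)
#                     shape=(1, 128, 128),  # mask tile; must divide x.shape
#                     dropout_rate=0.1,
#                     rng_function=lambda: self.make_rng("dropout"),
#                     deterministic=not train)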
@gin.configurable
class MLP(nn.Module):
"""Implements a multi-layer perceptron, with optional resnet or gate."""
# Arguments to module.
num_output_features: int # Length of output vectors.
# Gin configurable parameters.
num_layers: int = gin.REQUIRED # Number of layers in the MLP.
num_hidden_units: int = gin.REQUIRED # Length of hidden unit vectors.
hidden_activation: Optional[str] = "relu" # Hidden layer activation fn.
final_activation: Optional[str] = None # Final layer activation fn.
use_bias: bool = True # Use a bias in each dense layer.
  gate_type: Optional[str] = None      # { "residual", "bias", "full", "lstm" }
initializer_scale: float = 1.0 # Scale of initial values.
dtype: Any = jnp.float32
def setup(self):
kernel_init = jax.nn.initializers.variance_scaling(
scale=self.initializer_scale, mode="fan_in",
distribution="truncated_normal")
assert self.num_layers > 0
hlayers = []
for i in range(0, self.num_layers - 1):
assert self.num_hidden_units > 0
hlayer = nn.Dense(self.num_hidden_units,
use_bias=self.use_bias,
kernel_init=kernel_init,
dtype=self.dtype,
name=f"hidden{i}")
hlayers.append(hlayer)
self.hidden_layers = hlayers
self.output_layer = nn.Dense(self.num_output_features,
use_bias=self.use_bias,
kernel_init=kernel_init,
dtype=self.dtype)
if self.gate_type is None or self.gate_type == "residual":
return
    # We use a low but non-zero bias so that Adafactor knows how to scale it.
    gate_bias_init = jax.nn.initializers.normal(stddev=0.1)
    # Also use a lower-than-normal scale for the kernel initializer.
gate_kernel_init = jax.nn.initializers.variance_scaling(
scale=0.1, mode="fan_in", distribution="truncated_normal")
if self.gate_type == "bias":
self.gate_bias = self.param("gate_bias", gate_bias_init,
(self.num_output_features,), jnp.float32)
elif self.gate_type == "full":
self.gate_layer = nn.Dense(self.num_output_features,
use_bias=True,
bias_init=gate_bias_init,
kernel_init=gate_kernel_init,
dtype=self.dtype)
elif self.gate_type == "lstm":
self.input_gate = nn.Dense(self.num_output_features,
use_bias=True,
bias_init=gate_bias_init,
kernel_init=gate_kernel_init,
dtype=self.dtype)
self.forget_gate = nn.Dense(self.num_output_features,
use_bias=True,
bias_init=gate_bias_init,
kernel_init=gate_kernel_init,
dtype=self.dtype)
else:
raise ValueError("Unsupported gate_type: %s" % self.gate_type)
def _gate(self, y_hidden: Array, state: Array, y_out: Array) -> Array:
"""Compute the value to use for the gate."""
if self.gate_type == "residual":
# Residual connection: just add y_out to the state.
logging.info("mlp: residual")
return state + y_out
elif self.gate_type == "bias":
      # Simple gate: use a GRU-style gate with a learned bias (no kernel).
bias = jnp.asarray(self.gate_bias, dtype=self.dtype)
bias = jnp.reshape(bias, (1,) * (y_out.ndim - 1) + (-1,)) # batch dims.
g = jax.nn.sigmoid(bias)
logging.info("mlp: gate bias = %r", g)
return (state * g) + (y_out * (1 - g))
elif self.gate_type == "full":
# Normal GRU style gate -- compute g using both a kernel and bias.
g = jax.nn.sigmoid(self.gate_layer(y_hidden) + 1) # biased to remember
logging.info("mlp: gate full = %r", g)
return (state * g) + (y_out * (1 - g))
elif self.gate_type == "lstm":
# LSTM style gate with input and forget gates.
fg = jax.nn.sigmoid(self.forget_gate(y_hidden) + 1) # biased to remember
ig = jax.nn.sigmoid(self.input_gate(y_hidden) - 1)
logging.info("mlp: gate lstm = %r, %r", ig, fg)
return (state * fg) + (y_out * ig)
else:
raise ValueError("Unsupported gate type %s" % self.gate_type)
def __call__(self, x: Array, state: Optional[Array],
apply_dropout: bool = False,
dropout_rate: float = 0.0,
drop_tile_shape: Optional[Shape] = None,
rng_function: Optional[Callable[[], Any]] = None) -> Array:
"""Apply the multi-layer perceptron to the input x.
For simple MLPs, returns f(x), where f is the MLP function.
For resnets and gated architectures, it returns
state + f(x) -- for resnet.
g*state + (1-g)*f(x) -- for gated architecture, where g is the gate.
Args:
x: The input to the MLP.
state: The prior value, if this MLP is used as part of a resnet or gated
architecture.
apply_dropout: If true, applies dropout to the result.
dropout_rate: The dropout rate to use.
drop_tile_shape: The dropout tile shape.
      rng_function: A function which returns an rng key for dropout.
Returns:
The combination of f(x) and the (optional) prior state.
"""
x = jnp.asarray(x, self.dtype)
hidden_act_fun = get_activation_function(self.hidden_activation)
final_act_fun = get_activation_function(self.final_activation)
if self.hidden_layers:
# Apply some number of hidden layers.
y = x
for layer in self.hidden_layers:
logging.info("mlp: hidden %d, %s", self.num_hidden_units,
self.hidden_activation)
y = hidden_act_fun(layer(y))
else:
# Apply the hidden activation function to the input.
logging.info("mlp: activation = %s", self.hidden_activation)
y = hidden_act_fun(x)
y_hidden = y # The hidden layer right before the output.
logging.info("mlp: final activation = %s", self.final_activation)
y_out = self.output_layer(y_hidden) # The MLP final output.
y_out = final_act_fun(y_out) # Apply final activation function.
logging.info("mlp: final = %r", y_out)
# Optionally apply dropout to the output.
if apply_dropout:
if drop_tile_shape is None:
raise ValueError("drop_tile_shape must be specified for dropout.")
if rng_function is None:
raise ValueError("rng_function must be specified for dropout.")
logging.info("mlp: dropout rate = %s", dropout_rate)
y_out = tiled_dropout(
y_out, shape=drop_tile_shape, dropout_rate=dropout_rate,
rng_function=rng_function, deterministic=False)
if state is None:
# Simple MLP. No gate to combine y_out with the state.
assert self.gate_type is None
logging.info("mlp: gate type = None.")
return y_out
# When using state, gate_type must be specified.
assert self.gate_type is not None
return self._gate(y_hidden, state, y_out)
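# Illustrative sketch (not part of the original module): building the MLP as a
# plain feed-forward block and as a gated update of a recurrent state.  The
# field values are assumptions, and the calls are assumed to run inside a
# parent Flax module (or under Module.init/apply).
#
#   plain_mlp = MLP(num_output_features=128, num_layers=2,
#                   num_hidden_units=512)
#   y = plain_mlp(x, None)            # returns f(x)
#
#   gated_mlp = MLP(num_output_features=128, num_layers=2,
#                   num_hidden_units=512, gate_type="residual")
#   new_state = gated_mlp(x, state)   # returns state + f(x)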
# Modified slightly from the flax implementation.
@gin.configurable
class LayerNorm(nn.Module):
"""Layer normalization (https://arxiv.org/abs/1607.06450).
Operates on the last axis of the input data.
It normalizes the activations of the layer for each given example in a
batch independently, rather than across a batch like Batch Normalization.
  I.e., it applies a transformation that maintains the mean activation within
  each example close to 0 and the activation standard deviation close to 1.
Attributes:
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
use_bias: If True, bias (beta) is added.
use_scale: If True, multiply by scale (gamma).
use_mean: If True, compute and adjust for the mean.
      Note that the T5X layernorm does not use the mean.
      Empirically, ignoring the mean can stabilize learning in transformers.
    use_scalar_scale_bias: If True, use a single scalar for scale & bias.
enable_layernorm: If False, does not perform layernorm.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
"""
epsilon: float = 1e-6
dtype: Any = jnp.float32
use_scale: bool = True # Apply a learned scale.
use_bias: bool = False # Apply a learned bias.
use_mean: bool = False # Calculate and adjust for the mean.
use_scalar_scale_bias: bool = False # Learn a single scalar scale & bias.
enable_layernorm: bool = True # Turn off layernorm if false.
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
@nn.compact
def __call__(self, x):
"""Applies layer normalization on the input.
Args:
x: the inputs
Returns:
Normalized inputs (the same shape as inputs).
"""
if not self.enable_layernorm:
return x
x = jnp.asarray(x)
# Calculate mean and variance at higher precision.
xf = jnp.asarray(x, jnp.float32)
if self.use_mean:
mean = jnp.mean(xf, axis=-1, keepdims=True)
xf = xf - mean
var = jnp.mean(lax.square(xf), axis=-1, keepdims=True)
mul = lax.rsqrt(var + self.epsilon)
    # Rescale x.
    # If use_mean is False, rescale around zero instead (a simplification).
if self.use_mean:
y = (x - mean) * mul
else:
y = x * mul
if self.use_scalar_scale_bias:
# Learn a single scalar value for bias and scale.
# (Which mirrors the single value for mean and stddev above.)
num_scale_bias_features = 1
else:
# Learn a different value per neuron/feature for bias and scale.
num_scale_bias_features = x.shape[-1]
# Apply learned scale and bias.
if self.use_scale:
y = y * jnp.asarray(
self.param("scale", self.scale_init, (num_scale_bias_features,)),
dtype=self.dtype)
if self.use_bias:
y = y + jnp.asarray(
self.param("bias", self.bias_init, (num_scale_bias_features,)),
dtype=self.dtype)
return y.astype(self.dtype)
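# Illustrative sketch (not part of the original module): initializing and
# applying this LayerNorm as a standalone module.  The input shape is an
# assumption.
#
#   layer = LayerNorm(use_bias=True)
#   x = jnp.ones((2, 16, 128))
#   variables = layer.init(jax.random.PRNGKey(0), x)
#   y = layer.apply(variables, x)     # normalized over the last axis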