Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Display activation functions as modules. (#4045)
Browse files Browse the repository at this point in the history
* Display activations as modules.

* Fix tests and make changes in doc.

* Fix parameter type.

* Fix lambda based activation name displaying.

* Fix formatting.

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com>
  • Loading branch information
3 people authored Apr 30, 2020
1 parent be53f07 commit 7cbeb6c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 10 deletions.
27 changes: 23 additions & 4 deletions allennlp/modules/feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,32 @@ class FeedForward(torch.nn.Module, FromParams):
The output dimension of each of the `Linear` layers. If this is a single `int`, we use
it for all `Linear` layers. If it is a `List[int]`, `len(hidden_dims)` must be
`num_layers`.
activations : `Union[Callable, List[Callable]]`, required
activations : `Union[Activation, List[Activation]]`, required
The activation function to use after each `Linear` layer. If this is a single function,
we use it after all `Linear` layers. If it is a `List[Callable]`,
`len(activations)` must be `num_layers`.
we use it after all `Linear` layers. If it is a `List[Activation]`,
`len(activations)` must be `num_layers`. Activation must have torch.nn.Module type.
dropout : `Union[float, List[float]]`, optional (default = 0.0)
If given, we will apply this amount of dropout after each layer. Semantics of `float`
versus `List[float]` is the same as with other parameters.
Example:
```
>>> FeedForward(124, 2, [64, 32], torch.nn.ReLU(), 0.2)
FeedForward(
(_activations): ModuleList(
(0): ReLU()
(1): ReLU()
)
(_linear_layers): ModuleList(
(0): Linear(in_features=124, out_features=64, bias=True)
(1): Linear(in_features=64, out_features=32, bias=True)
)
(_dropout): ModuleList(
(0): Dropout(p=0.2, inplace=False)
(1): Dropout(p=0.2, inplace=False)
)
)
```
"""

def __init__(
Expand Down Expand Up @@ -62,7 +81,7 @@ def __init__(
raise ConfigurationError(
"len(dropout) (%d) != num_layers (%d)" % (len(dropout), num_layers)
)
self._activations = activations
self._activations = torch.nn.ModuleList(activations)
input_dims = [input_dim] + hidden_dims[:-1]
linear_layers = []
for layer_input_dim, layer_output_dim in zip(input_dims, hidden_dims):
Expand Down
31 changes: 27 additions & 4 deletions allennlp/nn/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
* ["tanhshrink"](https://pytorch.org/docs/master/nn.html#torch.nn.Tanhshrink)
* ["selu"](https://pytorch.org/docs/master/nn.html#torch.nn.SELU)
"""
from typing import Callable

import torch
from overrides import overrides

from allennlp.common import Registrable


class Activation(Registrable):
class Activation(torch.nn.Module, Registrable):
"""
Pytorch has a number of built-in activation functions. We group those here under a common
type, just to make it easier to configure and instantiate them `from_params` using
Expand All @@ -53,13 +55,34 @@ def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
raise NotImplementedError


class _ActivationLambda(torch.nn.Module):
"""Wrapper around non PyTorch, lambda based activations to display them as modules whenever printing model."""

def __init__(self, func: Callable[[torch.Tensor], torch.Tensor], name: str):
super().__init__()
self._name = name
self._func = func

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self._func(x)

@overrides
def _get_name(self):
return self._name


# There are no classes to decorate, so we hack these into Registrable._registry.
# If you want to instantiate it, you can do like this:
# Activation.by_name('relu')()
Registrable._registry[Activation] = {
"linear": (lambda: lambda x: x, None), # type: ignore
"mish": (lambda: lambda x: x * torch.tanh(torch.nn.functional.softplus(x)), None), # type: ignore
"swish": (lambda: lambda x: x * torch.sigmoid(x), None), # type: ignore
"linear": (lambda: _ActivationLambda(lambda x: x, "Linear"), None), # type: ignore
"mish": ( # type: ignore
lambda: _ActivationLambda(
lambda x: x * torch.tanh(torch.nn.functional.softplus(x)), "Mish"
),
None,
),
"swish": (lambda: _ActivationLambda(lambda x: x * torch.sigmoid(x), "Swish"), None), # type: ignore
"relu": (torch.nn.ReLU, None),
"relu6": (torch.nn.ReLU6, None),
"elu": (torch.nn.ELU, None),
Expand Down
36 changes: 36 additions & 0 deletions allennlp/tests/modules/feedforward_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from numpy.testing import assert_almost_equal
import inspect
import pytest
import torch

Expand Down Expand Up @@ -65,3 +66,38 @@ def test_forward_gives_correct_output(self):
# This output was checked by hand - ReLU makes output after first hidden layer [0, 0, 0],
# which then gets a bias added in the second layer to be [1, 1, 1].
assert_almost_equal(output, [[1, 1, 1]])

def test_textual_representation_contains_activations(self):
params = Params(
{
"input_dim": 2,
"hidden_dims": 3,
"activations": ["linear", "relu", "swish"],
"num_layers": 3,
}
)
feedforward = FeedForward.from_params(params)
expected_text_representation = inspect.cleandoc(
"""
FeedForward(
(_activations): ModuleList(
(0): Linear()
(1): ReLU()
(2): Swish()
)
(_linear_layers): ModuleList(
(0): Linear(in_features=2, out_features=3, bias=True)
(1): Linear(in_features=3, out_features=3, bias=True)
(2): Linear(in_features=3, out_features=3, bias=True)
)
(_dropout): ModuleList(
(0): Dropout(p=0.0, inplace=False)
(1): Dropout(p=0.0, inplace=False)
(2): Dropout(p=0.0, inplace=False)
)
)
"""
)
actual_text_representation = str(feedforward)

self.assertEqual(actual_text_representation, expected_text_representation)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def is_bidirectional(self) -> bool:
def _make_feedforward(input_dim, output_dim):
return FeedForwardEncoder(
FeedForward(
input_dim=input_dim, num_layers=1, activations=torch.relu, hidden_dims=output_dim
input_dim=input_dim, num_layers=1, activations=torch.nn.ReLU(), hidden_dims=output_dim
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

class TestFeedforwardEncoder(AllenNlpTestCase):
def test_get_dimension_is_correct(self):
feedforward = FeedForward(input_dim=10, num_layers=1, hidden_dims=10, activations="linear")
feedforward = FeedForward(
input_dim=10, num_layers=1, hidden_dims=10, activations=Activation.by_name("linear")()
)
encoder = FeedForwardEncoder(feedforward)
assert encoder.get_input_dim() == feedforward.get_input_dim()
assert encoder.get_output_dim() == feedforward.get_output_dim()
Expand Down

0 comments on commit 7cbeb6c

Please sign in to comment.