diff --git a/allennlp/modules/feedforward.py b/allennlp/modules/feedforward.py
index 2f652cbede6..352e1d060ed 100644
--- a/allennlp/modules/feedforward.py
+++ b/allennlp/modules/feedforward.py
@@ -25,13 +25,32 @@ class FeedForward(torch.nn.Module, FromParams):
         The output dimension of each of the `Linear` layers. If this is a single `int`,
         we use it for all `Linear` layers. If it is a `List[int]`,
         `len(hidden_dims)` must be `num_layers`.
-    activations : `Union[Callable, List[Callable]]`, required
+    activations : `Union[Activation, List[Activation]]`, required
         The activation function to use after each `Linear` layer. If this is a single function,
-        we use it after all `Linear` layers. If it is a `List[Callable]`,
-        `len(activations)` must be `num_layers`.
+        we use it after all `Linear` layers. If it is a `List[Activation]`,
+        `len(activations)` must be `num_layers`. Each `Activation` must be a `torch.nn.Module`.
     dropout : `Union[float, List[float]]`, optional (default = 0.0)
         If given, we will apply this amount of dropout after each layer. Semantics of `float`
         versus `List[float]` is the same as with other parameters.
+
+    Example:
+    ```
+    >>> FeedForward(124, 2, [64, 32], torch.nn.ReLU(), 0.2)
+    FeedForward(
+      (_activations): ModuleList(
+        (0): ReLU()
+        (1): ReLU()
+      )
+      (_linear_layers): ModuleList(
+        (0): Linear(in_features=124, out_features=64, bias=True)
+        (1): Linear(in_features=64, out_features=32, bias=True)
+      )
+      (_dropout): ModuleList(
+        (0): Dropout(p=0.2, inplace=False)
+        (1): Dropout(p=0.2, inplace=False)
+      )
+    )
+    ```
     """

     def __init__(
@@ -62,7 +81,7 @@ def __init__(
             raise ConfigurationError(
                 "len(dropout) (%d) != num_layers (%d)" % (len(dropout), num_layers)
             )
-        self._activations = activations
+        self._activations = torch.nn.ModuleList(activations)
         input_dims = [input_dim] + hidden_dims[:-1]
         linear_layers = []
         for layer_input_dim, layer_output_dim in zip(input_dims, hidden_dims):
diff --git a/allennlp/nn/activations.py b/allennlp/nn/activations.py
index 7cac44cb69e..e00ed2e19a6 100644
--- a/allennlp/nn/activations.py
+++ b/allennlp/nn/activations.py
@@ -26,13 +26,15 @@
 * ["tanhshrink"](https://pytorch.org/docs/master/nn.html#torch.nn.Tanhshrink)
 * ["selu"](https://pytorch.org/docs/master/nn.html#torch.nn.SELU)
 """
+from typing import Callable

 import torch
+from overrides import overrides

 from allennlp.common import Registrable


-class Activation(Registrable):
+class Activation(torch.nn.Module, Registrable):
     """
     Pytorch has a number of built-in activation functions.  We group those here under a common
     type, just to make it easier to configure and instantiate them `from_params` using
@@ -53,13 +55,34 @@ def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError


+class _ActivationLambda(torch.nn.Module):
+    """Wrapper around non-PyTorch, lambda-based activations so they display as modules when printing a model."""
+
+    def __init__(self, func: Callable[[torch.Tensor], torch.Tensor], name: str):
+        super().__init__()
+        self._name = name
+        self._func = func
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self._func(x)
+
+    @overrides
+    def _get_name(self):
+        return self._name
+
+
 # There are no classes to decorate, so we hack these into Registrable._registry.
 # If you want to instantiate it, you can do like this:
 # Activation.by_name('relu')()
 Registrable._registry[Activation] = {
-    "linear": (lambda: lambda x: x, None),  # type: ignore
-    "mish": (lambda: lambda x: x * torch.tanh(torch.nn.functional.softplus(x)), None),  # type: ignore
-    "swish": (lambda: lambda x: x * torch.sigmoid(x), None),  # type: ignore
+    "linear": (lambda: _ActivationLambda(lambda x: x, "Linear"), None),  # type: ignore
+    "mish": (  # type: ignore
+        lambda: _ActivationLambda(
+            lambda x: x * torch.tanh(torch.nn.functional.softplus(x)), "Mish"
+        ),
+        None,
+    ),
+    "swish": (lambda: _ActivationLambda(lambda x: x * torch.sigmoid(x), "Swish"), None),  # type: ignore
     "relu": (torch.nn.ReLU, None),
     "relu6": (torch.nn.ReLU6, None),
     "elu": (torch.nn.ELU, None),
diff --git a/allennlp/tests/modules/feedforward_test.py b/allennlp/tests/modules/feedforward_test.py
index e1a990ecb20..235da4f3cf2 100644
--- a/allennlp/tests/modules/feedforward_test.py
+++ b/allennlp/tests/modules/feedforward_test.py
@@ -1,4 +1,5 @@
 from numpy.testing import assert_almost_equal
+import inspect
 import pytest
 import torch

@@ -65,3 +66,38 @@ def test_forward_gives_correct_output(self):
         # This output was checked by hand - ReLU makes output after first hidden layer [0, 0, 0],
         # which then gets a bias added in the second layer to be [1, 1, 1].
         assert_almost_equal(output, [[1, 1, 1]])
+
+    def test_textual_representation_contains_activations(self):
+        params = Params(
+            {
+                "input_dim": 2,
+                "hidden_dims": 3,
+                "activations": ["linear", "relu", "swish"],
+                "num_layers": 3,
+            }
+        )
+        feedforward = FeedForward.from_params(params)
+        expected_text_representation = inspect.cleandoc(
+            """
+            FeedForward(
+              (_activations): ModuleList(
+                (0): Linear()
+                (1): ReLU()
+                (2): Swish()
+              )
+              (_linear_layers): ModuleList(
+                (0): Linear(in_features=2, out_features=3, bias=True)
+                (1): Linear(in_features=3, out_features=3, bias=True)
+                (2): Linear(in_features=3, out_features=3, bias=True)
+              )
+              (_dropout): ModuleList(
+                (0): Dropout(p=0.0, inplace=False)
+                (1): Dropout(p=0.0, inplace=False)
+                (2): Dropout(p=0.0, inplace=False)
+              )
+            )
+            """
+        )
+        actual_text_representation = str(feedforward)
+
+        self.assertEqual(actual_text_representation, expected_text_representation)
diff --git a/allennlp/tests/modules/seq2seq_encoders/compose_encoder_test.py b/allennlp/tests/modules/seq2seq_encoders/compose_encoder_test.py
index ca380217a53..764b6c661b8 100644
--- a/allennlp/tests/modules/seq2seq_encoders/compose_encoder_test.py
+++ b/allennlp/tests/modules/seq2seq_encoders/compose_encoder_test.py
@@ -34,7 +34,7 @@ def is_bidirectional(self) -> bool:
 def _make_feedforward(input_dim, output_dim):
     return FeedForwardEncoder(
         FeedForward(
-            input_dim=input_dim, num_layers=1, activations=torch.relu, hidden_dims=output_dim
+            input_dim=input_dim, num_layers=1, activations=torch.nn.ReLU(), hidden_dims=output_dim
         )
     )

diff --git a/allennlp/tests/modules/seq2seq_encoders/feedforward_encoder_test.py b/allennlp/tests/modules/seq2seq_encoders/feedforward_encoder_test.py
index 996fc270970..d8b54f2f14f 100644
--- a/allennlp/tests/modules/seq2seq_encoders/feedforward_encoder_test.py
+++ b/allennlp/tests/modules/seq2seq_encoders/feedforward_encoder_test.py
@@ -9,7 +9,9 @@

 class TestFeedforwardEncoder(AllenNlpTestCase):
     def test_get_dimension_is_correct(self):
-        feedforward = FeedForward(input_dim=10, num_layers=1, hidden_dims=10, activations="linear")
+        feedforward = FeedForward(
+            input_dim=10, num_layers=1, hidden_dims=10, activations=Activation.by_name("linear")()
+        )
         encoder = FeedForwardEncoder(feedforward)
         assert encoder.get_input_dim() == feedforward.get_input_dim()
         assert encoder.get_output_dim() == feedforward.get_output_dim()
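
A minimal sketch of how the change can be exercised locally (assumes this branch of `allennlp` is installed; it only uses `FeedForward`, `Activation.by_name`, and the constructor arguments shown in the diff above, and mirrors the new `feedforward_test.py` case rather than adding any new API):

```python
from allennlp.modules.feedforward import FeedForward
from allennlp.nn.activations import Activation

# Resolve activations from the registry, much as FeedForward.from_params does in the new test.
# "linear" and "swish" come back as _ActivationLambda instances, "relu" as torch.nn.ReLU.
activations = [Activation.by_name(name)() for name in ["linear", "relu", "swish"]]

feedforward = FeedForward(input_dim=2, num_layers=3, hidden_dims=3, activations=activations)

# Because _activations is now a torch.nn.ModuleList, the printed module tree lists
# Linear(), ReLU() and Swish() alongside _linear_layers and _dropout.
print(feedforward)
```

The printout should match the `expected_text_representation` in `test_textual_representation_contains_activations`; with the old `Callable`-based activations, the `_activations` entry would not appear at all.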