This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add additive attention & unittest (#3238)
* Add additive attention & unittest

* Fix test

* Fix pylint, typo & Add docs
hawkeoni authored and DeNeutoy committed Sep 14, 2019
1 parent 07364c6 commit c732cbf
Showing 5 changed files with 89 additions and 1 deletion.
3 changes: 2 additions & 1 deletion allennlp/common/params.py
@@ -5,7 +5,8 @@
 """
 
 from typing import Any, Dict, List
-from collections import MutableMapping, OrderedDict
+from collections.abc import MutableMapping
+from collections import OrderedDict
 import copy
 import json
 import logging
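
The import change above is presumably the "Fix pylint" part of the commit message: abstract base classes such as MutableMapping live in collections.abc, and importing them through collections has been deprecated since Python 3.3 (and was removed entirely in Python 3.10). A minimal sketch of the distinction, not part of the commit:

# Sketch only: MutableMapping is an abstract base class and must come from
# collections.abc on modern Python; OrderedDict is a concrete container and
# stays in collections.
from collections.abc import MutableMapping
from collections import OrderedDict

d = OrderedDict(a=1)
print(isinstance(d, MutableMapping))  # True: OrderedDict implements the ABC
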
1 change: 1 addition & 0 deletions allennlp/modules/attention/__init__.py
@@ -1,5 +1,6 @@
 from allennlp.modules.attention.attention import Attention
 from allennlp.modules.attention.bilinear_attention import BilinearAttention
+from allennlp.modules.attention.additive_attention import AdditiveAttention
 from allennlp.modules.attention.cosine_attention import CosineAttention
 from allennlp.modules.attention.dot_product_attention import DotProductAttention
 from allennlp.modules.attention.legacy_attention import LegacyAttention
51 changes: 51 additions & 0 deletions allennlp/modules/attention/additive_attention.py
@@ -0,0 +1,51 @@
from overrides import overrides
import torch
from torch.nn.parameter import Parameter

from allennlp.modules.attention.attention import Attention


@Attention.register("additive")
class AdditiveAttention(Attention):
"""
Computes attention between a vector and a matrix using an additive attention function. This
function has two matrices ``W``, ``U`` and a vector ``V``. The similarity between the vector
``x`` and the matrix ``y`` is computed as ``V tanh(Wx + Uy)``.
This attention is often referred as concat or additive attention. It was introduced in
<https://arxiv.org/abs/1409.0473> by Bahdanau et al.
Parameters
----------
vector_dim : ``int``
The dimension of the vector, ``x``, described above. This is ``x.size()[-1]`` - the length
of the vector that will go into the similarity computation. We need this so we can build
the weight matrix correctly.
matrix_dim : ``int``
The dimension of the matrix, ``y``, described above. This is ``y.size()[-1]`` - the length
of the vector that will go into the similarity computation. We need this so we can build
the weight matrix correctly.
normalize : ``bool``, optional (default: ``True``)
If true, we normalize the computed similarities with a softmax, to return a probability
distribution for your attention. If false, this is just computing a similarity score.
"""
    def __init__(self,
                 vector_dim: int,
                 matrix_dim: int,
                 normalize: bool = True) -> None:
        super().__init__(normalize)
        self._w_matrix = Parameter(torch.Tensor(vector_dim, vector_dim))
        self._u_matrix = Parameter(torch.Tensor(matrix_dim, vector_dim))
        self._v_vector = Parameter(torch.Tensor(vector_dim, 1))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self._w_matrix)
        torch.nn.init.xavier_uniform_(self._u_matrix)
        torch.nn.init.xavier_uniform_(self._v_vector)

    @overrides
    def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
        intermediate = vector.matmul(self._w_matrix).unsqueeze(1) + matrix.matmul(self._u_matrix)
        intermediate = torch.tanh(intermediate)
        return intermediate.matmul(self._v_vector).squeeze(2)
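
A hedged usage sketch, not part of the commit, of calling the module directly (the test below builds it through from_params instead). The shapes follow the docstring: a (batch_size, vector_dim) query vector is scored against every row of a (batch_size, num_rows, matrix_dim) matrix; the dimensions here are chosen purely for illustration.

import torch
from allennlp.modules.attention import AdditiveAttention

attention = AdditiveAttention(vector_dim=4, matrix_dim=6)  # normalize=True by default
vector = torch.rand(2, 4)      # e.g. one decoder state per batch element
matrix = torch.rand(2, 5, 6)   # e.g. encoder outputs over 5 timesteps
weights = attention(vector, matrix)
print(weights.shape)           # torch.Size([2, 5]); rows sum to 1 when normalized
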
30 changes: 30 additions & 0 deletions allennlp/tests/modules/attention/additive_attention_test.py
@@ -0,0 +1,30 @@
# pylint: disable=no-self-use,invalid-name,protected-access
from numpy.testing import assert_almost_equal
import torch
from torch.nn.parameter import Parameter

from allennlp.common import Params
from allennlp.modules.attention import AdditiveAttention
from allennlp.common.testing import AllenNlpTestCase


class TestAdditiveAttention(AllenNlpTestCase):
    def test_forward_does_an_additive_product(self):
        params = Params({
                'vector_dim': 2,
                'matrix_dim': 3,
                'normalize': False,
                })
        additive = AdditiveAttention.from_params(params)
        additive._w_matrix = Parameter(torch.Tensor([[-0.2, 0.3], [-0.5, 0.5]]))
        additive._u_matrix = Parameter(torch.Tensor([[0., 1.], [1., 1.], [1., -1.]]))
        additive._v_vector = Parameter(torch.Tensor([[1.], [-1.]]))
        vectors = torch.FloatTensor([[0.7, -0.8], [0.4, 0.9]])
        matrices = torch.FloatTensor([
                [[1., -1., 3.], [0.5, -0.3, 0.], [0.2, -1., 1.], [0.7, 0.8, -1.]],
                [[-2., 3., -3.], [0.6, 0.2, 2.], [0.5, -0.4, -1.], [0.2, 0.2, 0.]]])
        result = additive(vectors, matrices).detach().numpy()
        assert result.shape == (2, 4)
        assert_almost_equal(result, [
                [1.975072, -0.04997836, 1.2176098, -0.9205586],
                [-1.4851665, 1.489604, -1.890285, -1.0672251]])
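
Because the test sets normalize=False, the expected values are the raw similarity scores from the docstring formula and can be checked by hand. A hedged hand-check, not part of the commit, of the first entry of the expected output:

import math

# First query vector, first row of the first matrix, and the hand-set weights
# from the test above.
x, W = [0.7, -0.8], [[-0.2, 0.3], [-0.5, 0.5]]
y, U = [1., -1., 3.], [[0., 1.], [1., 1.], [1., -1.]]
v = [1., -1.]

xW = [sum(x[i] * W[i][j] for i in range(2)) for j in range(2)]  # [0.26, -0.19]
yU = [sum(y[i] * U[i][j] for i in range(3)) for j in range(2)]  # [2.0, -3.0]
score = sum(v[j] * math.tanh(xW[j] + yU[j]) for j in range(2))  # v . tanh(xW + yU)
print(score)  # roughly 1.97507, matching result[0][0]
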
5 changes: 5 additions & 0 deletions doc/api/allennlp.modules.attention.rst
@@ -16,6 +16,11 @@ allennlp.modules.attention
    :undoc-members:
    :show-inheritance:
 
+.. automodule:: allennlp.modules.attention.additive_attention
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 .. automodule:: allennlp.modules.attention.cosine_attention
    :members:
    :undoc-members:
