This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
Add additive attention & unittest (#3238)
* Add additive attention & unittest * Fix test * Fix pylint, typo & Add docs
Showing 5 changed files with 89 additions and 1 deletion.
@@ -0,0 +1,51 @@
from overrides import overrides
import torch
from torch.nn.parameter import Parameter

from allennlp.modules.attention.attention import Attention


@Attention.register("additive")
class AdditiveAttention(Attention):
    """
    Computes attention between a vector and a matrix using an additive attention function.  This
    function has two matrices ``W``, ``U`` and a vector ``V``.  The similarity between the vector
    ``x`` and the matrix ``y`` is computed as ``V tanh(Wx + Uy)``.

    This attention is often referred to as concat or additive attention.  It was introduced by
    Bahdanau et al. in <https://arxiv.org/abs/1409.0473>.

    Parameters
    ----------
    vector_dim : ``int``
        The dimension of the vector, ``x``, described above.  This is ``x.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we
        can build the weight matrix correctly.
    matrix_dim : ``int``
        The dimension of the matrix, ``y``, described above.  This is ``y.size()[-1]`` - the
        length of the vector that will go into the similarity computation.  We need this so we
        can build the weight matrix correctly.
    normalize : ``bool``, optional (default: ``True``)
        If true, we normalize the computed similarities with a softmax, to return a probability
        distribution for your attention.  If false, this is just computing a similarity score.
    """
    def __init__(self,
                 vector_dim: int,
                 matrix_dim: int,
                 normalize: bool = True) -> None:
        super().__init__(normalize)
        self._w_matrix = Parameter(torch.Tensor(vector_dim, vector_dim))
        self._u_matrix = Parameter(torch.Tensor(matrix_dim, vector_dim))
        self._v_vector = Parameter(torch.Tensor(vector_dim, 1))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self._w_matrix)
        torch.nn.init.xavier_uniform_(self._u_matrix)
        torch.nn.init.xavier_uniform_(self._v_vector)

    @overrides
    def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
        # vector: (batch_size, vector_dim); matrix: (batch_size, num_rows, matrix_dim).
        intermediate = vector.matmul(self._w_matrix).unsqueeze(1) + matrix.matmul(self._u_matrix)
        intermediate = torch.tanh(intermediate)
        # Project each row to a single score: (batch_size, num_rows).
        return intermediate.matmul(self._v_vector).squeeze(2)
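For context (not part of the commit diff), here is a minimal usage sketch of the registered ``"additive"`` attention; the dimensions below are made-up values consistent with the docstring and ``_forward_internal`` above:

import torch
from allennlp.modules.attention import AdditiveAttention

# Hypothetical sizes: a batch of 2 query vectors of size 4 attending over
# matrices with 5 rows of size 6.
attention = AdditiveAttention(vector_dim=4, matrix_dim=6, normalize=True)
vector = torch.randn(2, 4)             # (batch_size, vector_dim)
matrix = torch.randn(2, 5, 6)          # (batch_size, num_rows, matrix_dim)
weights = attention(vector, matrix)    # (batch_size, num_rows); rows sum to 1 since normalize=True
print(weights.shape, weights.sum(dim=-1))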
30 changes: 30 additions & 0 deletions in allennlp/tests/modules/attention/additive_attention_test.py
@@ -0,0 +1,30 @@
# pylint: disable=no-self-use,invalid-name,protected-access
from numpy.testing import assert_almost_equal
import torch
from torch.nn.parameter import Parameter

from allennlp.common import Params
from allennlp.modules.attention import AdditiveAttention
from allennlp.common.testing import AllenNlpTestCase


class TestAdditiveAttention(AllenNlpTestCase):
    def test_forward_does_an_additive_product(self):
        params = Params({
                'vector_dim': 2,
                'matrix_dim': 3,
                'normalize': False,
                })
        additive = AdditiveAttention.from_params(params)
        # Overwrite the randomly initialized parameters with fixed values so the
        # expected output is deterministic.
        additive._w_matrix = Parameter(torch.Tensor([[-0.2, 0.3], [-0.5, 0.5]]))
        additive._u_matrix = Parameter(torch.Tensor([[0., 1.], [1., 1.], [1., -1.]]))
        additive._v_vector = Parameter(torch.Tensor([[1.], [-1.]]))
        vectors = torch.FloatTensor([[0.7, -0.8], [0.4, 0.9]])
        matrices = torch.FloatTensor([
                [[1., -1., 3.], [0.5, -0.3, 0.], [0.2, -1., 1.], [0.7, 0.8, -1.]],
                [[-2., 3., -3.], [0.6, 0.2, 2.], [0.5, -0.4, -1.], [0.2, 0.2, 0.]]])
        result = additive(vectors, matrices).detach().numpy()
        assert result.shape == (2, 4)
        assert_almost_equal(result, [
                [1.975072, -0.04997836, 1.2176098, -0.9205586],
                [-1.4851665, 1.489604, -1.890285, -1.0672251]])
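As a sanity check (a sketch, not part of the commit), the first expected value can be reproduced by hand from the ``V tanh(Wx + Uy)`` formula using the parameters fixed in the test:

import torch

w = torch.tensor([[-0.2, 0.3], [-0.5, 0.5]])        # _w_matrix from the test
u = torch.tensor([[0., 1.], [1., 1.], [1., -1.]])    # _u_matrix
v = torch.tensor([[1.], [-1.]])                      # _v_vector
x = torch.tensor([0.7, -0.8])                        # first query vector
y = torch.tensor([1., -1., 3.])                      # first row of the first matrix
score = torch.tanh(x @ w + y @ u) @ v                # ~1.975072, i.e. result[0, 0]
print(score.item())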