
Max out span extractor #5520

Merged · 9 commits · Jan 5, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added a way to resize the vocabulary in the T5 module
- Added an argument `reinit_modules` to `cached_transformers.get()` that allows you to re-initialize the pretrained weights of a transformer model, using layer indices or regex strings.
- Added a `MaxPoolingSpanExtractor`. This `SpanExtractor` represents each span by a component-wise max-pooling operation.

### Fixed

1 change: 1 addition & 0 deletions allennlp/modules/span_extractors/__init__.py
@@ -6,3 +6,4 @@
from allennlp.modules.span_extractors.bidirectional_endpoint_span_extractor import (
BidirectionalEndpointSpanExtractor,
)
from allennlp.modules.span_extractors.max_pooling_span_extractor import MaxPoolingSpanExtractor
131 changes: 131 additions & 0 deletions allennlp/modules/span_extractors/max_pooling_span_extractor.py
@@ -0,0 +1,131 @@
import torch

from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.nn import util
from allennlp.nn.util import masked_max


@SpanExtractor.register("max_pooling")
class MaxPoolingSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
"""
Represents spans through the application of a dimension-wise max-pooling operation.
Given a span x_i, ..., x_j, where i and j are span_start and span_end, each dimension d
of the resulting span representation s is computed as s_d = max(x_id, ..., x_jd).

Elements masked out by `sequence_mask` are ignored when the max-pooling is computed.
Span representations of spans masked out by `span_indices_mask` are set to 0.
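
For example, a span covering two tokens with embeddings [1.0, 4.0] and
[3.0, 2.0] is represented as [max(1.0, 3.0), max(4.0, 2.0)] = [3.0, 4.0].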

Registered as a `SpanExtractor` with name "max_pooling".

# Parameters

input_dim : `int`, required.
The final dimension of the `sequence_tensor`.
num_width_embeddings : `int`, optional (default = `None`).
Specifies the number of buckets to use when representing
span width features.
span_width_embedding_dim : `int`, optional (default = `None`).
The embedding size for the span_width features.
bucket_widths : `bool`, optional (default = `False`).
Whether to bucket the span widths into log-space buckets. If `False`,
the raw span widths are used.

# Returns

max_pooling_text_embeddings : `torch.FloatTensor`.
A tensor of shape `(batch_size, num_spans, input_dim)`, where each span representation
is the result of a max-pooling operation.

"""

def __init__(
self,
input_dim: int,
num_width_embeddings: int = None,
span_width_embedding_dim: int = None,
bucket_widths: bool = False,
) -> None:
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)

def get_output_dim(self) -> int:
if self._span_width_embedding is not None:
return self._input_dim + self._span_width_embedding.get_output_dim()
return self._input_dim

def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
sequence_mask: torch.BoolTensor = None,
span_indices_mask: torch.BoolTensor = None,
) -> torch.FloatTensor:

if sequence_tensor.size(-1) != self._input_dim:
raise ValueError(
f"Dimension mismatch expected ({sequence_tensor.size(-1)}) "
f"received ({self._input_dim})."
)

if sequence_tensor.shape[1] <= span_indices.max() or span_indices.min() < 0:
raise IndexError(
f"Span index out of range, max index ({span_indices.max()}) "
f"or min index ({span_indices.min()}) "
f"not valid for sequence of length ({sequence_tensor.shape[1]})."
)

if (span_indices[:, :, 0] > span_indices[:, :, 1]).any():
raise IndexError("Span start index is above the span end index.")

# Calculate the maximum sequence length for each element in the batch.
# If span_end indices are above these lengths, we adjust them in adopted_span_indices.
if sequence_mask is not None:
# shape (batch_size)
sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
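# e.g. a mask row of [True, True, False] yields a sequence length of 2.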
else:
# shape (batch_size), filled with the full sequence length of the sequence_tensor.
sequence_lengths = torch.ones_like(
sequence_tensor[:, 0, 0], dtype=torch.long
) * sequence_tensor.size(1)

# Clone the indices so the adjustment below does not modify the caller's tensor.
adopted_span_indices = span_indices.clone()

for b in range(sequence_lengths.shape[0]):
adopted_span_indices[b, :, 1][adopted_span_indices[b, :, 1] >= sequence_lengths[b]] = (
sequence_lengths[b] - 1
)

# Raise an error if any span was completely masked out by the sequence mask.
# We only adjust span_end to the last valid index, so if span_end is now below
# span_start, both indices were above the maximum valid index:

if (adopted_span_indices[:, :, 0] > adopted_span_indices[:, :, 1]).any():
raise IndexError("Span indices were masked out entirely by the sequence mask.")

# span_vals <- (batch x num_spans x max_span_length x dim)
span_vals, span_mask = util.batched_span_select(sequence_tensor, adopted_span_indices)

# The application of masked_max requires a mask of the same shape as span_vals,
# so we repeat the mask along the last dimension (the embedding dimension).
repeat_dim = len(span_vals.shape) - 1
repeat_idx = [1] * repeat_dim + [span_vals.shape[-1]]
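# e.g. for span_vals of shape (2, 3, 4, 30), repeat_dim is 3 and repeat_idx is [1, 1, 1, 30].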

# ext_span_mask <- (batch x num_spans x max_span_length x dim)
# ext_span_mask True for values in span, False for masked out values
ext_span_mask = span_mask.unsqueeze(repeat_dim).repeat(repeat_idx)

# max_out <- (batch x num_spans x dim)
max_out = masked_max(span_vals, ext_span_mask, dim=-2)

return max_out
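
For reference, here is a minimal usage sketch of the new extractor, based only on the constructor and `forward` arguments exercised in this diff (the tensors and span indices below are made up for illustration):

```python
import torch

from allennlp.modules.span_extractors.max_pooling_span_extractor import MaxPoolingSpanExtractor

# A batch of 2 sequences, each with 6 tokens embedded in 10 dimensions.
sequence_tensor = torch.randn(2, 6, 10)

# Two spans per sequence, given as inclusive (start, end) token indices.
span_indices = torch.LongTensor([[[0, 2], [3, 5]], [[1, 1], [2, 4]]])

# The last token of the first sequence and the last two tokens of the second are padding.
sequence_mask = torch.BoolTensor([[True] * 5 + [False], [True] * 4 + [False] * 2])

extractor = MaxPoolingSpanExtractor(input_dim=10)
span_representations = extractor(sequence_tensor, span_indices, sequence_mask=sequence_mask)

# One max-pooled vector per span: (batch_size, num_spans, input_dim).
assert span_representations.shape == (2, 2, 10)
```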
184 changes: 184 additions & 0 deletions tests/modules/span_extractors/max_pooling_span_extractor_test.py
@@ -0,0 +1,184 @@
import pytest
import torch

from allennlp.common.params import Params
from allennlp.modules.span_extractors import SpanExtractor
from allennlp.modules.span_extractors.max_pooling_span_extractor import MaxPoolingSpanExtractor


class TestMaxPoolingSpanExtractor:
def test_span_extractor_can_build_from_params(self):
params = Params(
{
"type": "max_pooling",
"input_dim": 3,
"num_width_embeddings": 5,
"span_width_embedding_dim": 3,
}
)
extractor = SpanExtractor.from_params(params)
assert isinstance(extractor, MaxPoolingSpanExtractor)
assert extractor.get_output_dim() == 6

def test_max_values_extracted(self):
# Test that max-pooling is correctly applied.
# We use a high-dimensional random tensor and assume that a coincidentally correct result is too unlikely.
sequence_tensor = torch.randn([2, 10, 30])
extractor = MaxPoolingSpanExtractor(30)

indices = torch.LongTensor([[[1, 1], [2, 4], [9, 9]], [[0, 1], [4, 4], [0, 9]]])
span_representations = extractor(sequence_tensor, indices)

assert list(span_representations.size()) == [2, 3, 30]
assert extractor.get_output_dim() == 30
assert extractor.get_input_dim() == 30

# We iterate over the tensor to compare the span extractor's results
# with the result of Python's max operation over each dimension, for each span and each batch.
# For each batch
for batch, X in enumerate(indices):
# For each defined span index
for indices_ind, span_def in enumerate(X):

# original features of current tested span
# span_width x embedding dim (30)
span_features_complete = sequence_tensor[batch][span_def[0] : span_def[1] + 1]

# comparison for each dimension
for i in range(extractor.get_output_dim()):
# get the features for dimension i of current span
features_from_span = span_features_complete[:, i]
real_max_value = max(features_from_span)

extracted_max_value = span_representations[batch, indices_ind, i]

assert real_max_value == extracted_max_value, (
f"Error extracting max value for "
f"batch {batch}, span {indices_ind} on dimension {i}: "
f"expected {real_max_value} "
f"but got {extracted_max_value}, which is "
f"not the maximum element."
)

def test_sequence_mask_correct_excluded(self):
# Check that span indices masked out by the sequence mask are ignored when computing
# the span representations. For this test, span_start is valid but span_end is masked out.

sequence_tensor = torch.randn([2, 6, 30])

extractor = MaxPoolingSpanExtractor(30)
indices = torch.LongTensor([[[1, 1], [3, 5], [2, 5]], [[0, 0], [0, 3], [4, 5]]])
# define the sequence mask
seq_mask = torch.BoolTensor([[True] * 4 + [False] * 2, [True] * 5 + [False] * 1])

span_representations = extractor(sequence_tensor, indices, sequence_mask=seq_mask)

# After computing the representations, we set masked-out values to -inf
# to compute the "real" max-pooling with Python's max function.
sequence_tensor[torch.logical_not(seq_mask)] = float("-inf")

# Comparison is similar to test_max_values_extracted
for batch, X in enumerate(indices):
for indices_ind, span_def in enumerate(X):

span_features_complete = sequence_tensor[batch][span_def[0] : span_def[1] + 1]

for i in range(extractor.get_output_dim()):
features_from_span = span_features_complete[:, i]
real_max_value = max(features_from_span)
extracted_max_value = span_representations[batch, indices_ind, i]

assert real_max_value == extracted_max_value, (
f"Error extracting max value for "
f"batch {batch}, span {indices_ind} on dimension {i}: "
f"expected {real_max_value} "
f"but got {extracted_max_value}, which is "
f"not the maximum element."
)

def test_span_mask_correct_excluded(self):
# Span representations for spans masked out by span_mask should be all zeros.

sequence_tensor = torch.randn([2, 6, 10])

extractor = MaxPoolingSpanExtractor(10)
indices = torch.LongTensor([[[1, 1], [3, 5], [2, 5]], [[0, 0], [0, 3], [4, 5]]])

span_mask = torch.BoolTensor([[True] * 3, [False] * 3])

span_representations = extractor(
sequence_tensor,
indices,
span_indices_mask=span_mask,
)

# The span mask masks out all spans in the last batch element.
# We check whether all span representations for this batch element are 0.
X = indices[-1]
batch = -1
for indices_ind in range(len(X)):

for i in range(extractor.get_output_dim()):
real_max_value = torch.FloatTensor([0.0])
extracted_max_value = span_representations[batch, indices_ind, i]

assert real_max_value == extracted_max_value, (
f"Error extracting max value for "
f"batch {batch}, span {indices_ind} on dimension {i}: "
f"expected {real_max_value} "
f"but got {extracted_max_value}, which is "
f"not the expected value for a masked-out span."
)

def test_inconsistent_extractor_dimension_throws_exception(self):

sequence_tensor = torch.randn([2, 6, 10])
indices = torch.LongTensor([[[1, 1], [2, 4], [9, 9]], [[0, 1], [4, 4], [0, 9]]])

with pytest.raises(ValueError):
extractor = MaxPoolingSpanExtractor(9)
extractor(sequence_tensor, indices)

with pytest.raises(ValueError):
extractor = MaxPoolingSpanExtractor(11)
extractor(sequence_tensor, indices)

def test_span_indices_outside_sequence(self):

sequence_tensor = torch.randn([2, 6, 10])
indices = torch.LongTensor([[[6, 6], [2, 4]], [[0, 1], [4, 4]]])

with pytest.raises(IndexError):
extractor = MaxPoolingSpanExtractor(10)
extractor(sequence_tensor, indices)

indices = torch.LongTensor([[[5, 6], [2, 4]], [[0, 1], [4, 4]]])

with pytest.raises(IndexError):
extractor = MaxPoolingSpanExtractor(10)
extractor(sequence_tensor, indices)

indices = torch.LongTensor([[[-1, 0], [2, 4]], [[0, 1], [4, 4]]])

with pytest.raises(IndexError):
extractor = MaxPoolingSpanExtractor(10)
extractor(sequence_tensor, indices)

def test_span_start_above_span_end(self):

sequence_tensor = torch.randn([2, 6, 10])
indices = torch.LongTensor([[[4, 2], [2, 4], [1, 1]], [[0, 1], [4, 4], [1, 1]]])
with pytest.raises(IndexError):
extractor = MaxPoolingSpanExtractor(10)
extractor(sequence_tensor, indices)

def test_span_sequence_complete_masked(self):

sequence_tensor = torch.randn([2, 6, 10])
seq_mask = torch.BoolTensor([[True] * 2 + [False] * 4, [True] * 3 + [False] * 3])
indices = torch.LongTensor([[[5, 5]], [[4, 5]]])
with pytest.raises(IndexError):
extractor = MaxPoolingSpanExtractor(10)
extractor(sequence_tensor, indices, sequence_mask=seq_mask)