Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Refactor span extractors and unify forward. (#5160)
Browse files Browse the repository at this point in the history
* Refactor span extractors

* add SpanExtractorWithSpanWidthEmbedding

* update changelog

* fix blank lines

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
  • Loading branch information
izhx and dirkgr authored May 3, 2021
1 parent 01b232f commit b533733
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 97 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
You can do this by setting the parameter `load_weights` to `False`.
See [PR #5172](/~https://github.com/allenai/allennlp/pull/5172) for more details.

- Added `SpanExtractorWithSpanWidthEmbedding`, putting specific span embedding computations into the `_embed_spans` method and leaving the common code in `SpanExtractorWithSpanWidthEmbedding` to unify the arguments, and modified `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor` and `SelfAttentiveSpanExtractor` accordingly. Now, `SelfAttentiveSpanExtractor` can also embed span widths.

## Unreleased

### Fixed

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import Optional

import torch
from overrides import overrides
from torch.nn.parameter import Parameter

from allennlp.common.checks import ConfigurationError
from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.nn import util


@SpanExtractor.register("bidirectional_endpoint")
class BidirectionalEndpointSpanExtractor(SpanExtractor):
class BidirectionalEndpointSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
"""
Represents spans from a bidirectional encoder as a concatenation of two different
representations of the span endpoints, one for the forward direction of the encoder
Expand Down Expand Up @@ -79,12 +78,14 @@ def __init__(
bucket_widths: bool = False,
use_sentinels: bool = True,
) -> None:
super().__init__()
self._input_dim = input_dim
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)
self._forward_combination = forward_combination
self._backward_combination = backward_combination
self._num_width_embeddings = num_width_embeddings
self._bucket_widths = bucket_widths

if self._input_dim % 2 != 0:
raise ConfigurationError(
Expand All @@ -93,25 +94,11 @@ def __init__(
"is bidirectional (and hence divisible by 2)."
)

self._span_width_embedding: Optional[Embedding] = None
if num_width_embeddings is not None and span_width_embedding_dim is not None:
self._span_width_embedding = Embedding(
num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
)
elif num_width_embeddings is not None or span_width_embedding_dim is not None:
raise ConfigurationError(
"To use a span width embedding representation, you must"
"specify both num_width_buckets and span_width_embedding_dim."
)

self._use_sentinels = use_sentinels
if use_sentinels:
self._start_sentinel = Parameter(torch.randn([1, 1, int(input_dim / 2)]))
self._end_sentinel = Parameter(torch.randn([1, 1, int(input_dim / 2)]))

def get_input_dim(self) -> int:
return self._input_dim

def get_output_dim(self) -> int:
unidirectional_dim = int(self._input_dim / 2)
forward_combined_dim = util.get_combined_dim(
Expand All @@ -128,8 +115,7 @@ def get_output_dim(self) -> int:
)
return forward_combined_dim + backward_combined_dim

@overrides
def forward(
def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
Expand Down Expand Up @@ -238,18 +224,4 @@ def forward(
# Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
span_embeddings = torch.cat([forward_spans, backward_spans], -1)

if self._span_width_embedding is not None:
# Embed the span widths and concatenate to the rest of the representations.
if self._bucket_widths:
span_widths = util.bucket_values(
span_ends - span_starts, num_total_buckets=self._num_width_embeddings # type: ignore
)
else:
span_widths = span_ends - span_starts

span_width_embeddings = self._span_width_embedding(span_widths)
return torch.cat([span_embeddings, span_width_embeddings], -1)

if span_indices_mask is not None:
return span_embeddings * span_indices_mask.unsqueeze(-1)
return span_embeddings
51 changes: 11 additions & 40 deletions allennlp/modules/span_extractors/endpoint_span_extractor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from typing import Optional

import torch
from torch.nn.parameter import Parameter
from overrides import overrides

from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.nn import util
from allennlp.common.checks import ConfigurationError


@SpanExtractor.register("endpoint")
class EndpointSpanExtractor(SpanExtractor):
class EndpointSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
"""
Represents spans as a combination of the embeddings of their endpoints. Additionally,
the width of the spans can be embedded and concatenated on to the final combination.
Expand Down Expand Up @@ -61,38 +59,25 @@ def __init__(
bucket_widths: bool = False,
use_exclusive_start_indices: bool = False,
) -> None:
super().__init__()
self._input_dim = input_dim
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)
self._combination = combination
self._num_width_embeddings = num_width_embeddings
self._bucket_widths = bucket_widths

self._use_exclusive_start_indices = use_exclusive_start_indices
if use_exclusive_start_indices:
self._start_sentinel = Parameter(torch.randn([1, 1, int(input_dim)]))

self._span_width_embedding: Optional[Embedding] = None
if num_width_embeddings is not None and span_width_embedding_dim is not None:
self._span_width_embedding = Embedding(
num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
)
elif num_width_embeddings is not None or span_width_embedding_dim is not None:
raise ConfigurationError(
"To use a span width embedding representation, you must"
"specify both num_width_buckets and span_width_embedding_dim."
)

def get_input_dim(self) -> int:
return self._input_dim

def get_output_dim(self) -> int:
combined_dim = util.get_combined_dim(self._combination, [self._input_dim, self._input_dim])
if self._span_width_embedding is not None:
return combined_dim + self._span_width_embedding.get_output_dim()
return combined_dim

@overrides
def forward(
def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
Expand Down Expand Up @@ -148,19 +133,5 @@ def forward(
combined_tensors = util.combine_tensors(
self._combination, [start_embeddings, end_embeddings]
)
if self._span_width_embedding is not None:
# Embed the span widths and concatenate to the rest of the representations.
if self._bucket_widths:
span_widths = util.bucket_values(
span_ends - span_starts, num_total_buckets=self._num_width_embeddings # type: ignore
)
else:
span_widths = span_ends - span_starts

span_width_embeddings = self._span_width_embedding(span_widths)
combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

if span_indices_mask is not None:
return combined_tensors * span_indices_mask.unsqueeze(-1)

return combined_tensors
45 changes: 29 additions & 16 deletions allennlp/modules/span_extractors/self_attentive_span_extractor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import torch
from overrides import overrides

from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.modules.time_distributed import TimeDistributed
from allennlp.nn import util


@SpanExtractor.register("self_attentive")
class SelfAttentiveSpanExtractor(SpanExtractor):
class SelfAttentiveSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
"""
Computes span representations by generating an unnormalized attention score for each
word in the document. Spans representations are computed with respect to these
Expand All @@ -23,6 +25,14 @@ class SelfAttentiveSpanExtractor(SpanExtractor):
input_dim : `int`, required.
The final dimension of the `sequence_tensor`.
num_width_embeddings : `int`, optional (default = `None`).
Specifies the number of buckets to use when representing
span width features.
span_width_embedding_dim : `int`, optional (default = `None`).
The embedding size for the span_width features.
bucket_widths : `bool`, optional (default = `False`).
Whether to bucket the span widths into log-space buckets. If `False`,
the raw span widths are used.
# Returns
Expand All @@ -33,22 +43,31 @@ class SelfAttentiveSpanExtractor(SpanExtractor):
over which they are normalized.
"""

def __init__(self, input_dim: int) -> None:
super().__init__()
self._input_dim = input_dim
def __init__(
self,
input_dim: int,
num_width_embeddings: int = None,
span_width_embedding_dim: int = None,
bucket_widths: bool = False,
) -> None:
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)
self._global_attention = TimeDistributed(torch.nn.Linear(input_dim, 1))

def get_input_dim(self) -> int:
return self._input_dim

def get_output_dim(self) -> int:
if self._span_width_embedding is not None:
return self._input_dim + self._span_width_embedding.get_output_dim()
return self._input_dim

@overrides
def forward(
def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
sequence_mask: torch.BoolTensor = None,
span_indices_mask: torch.BoolTensor = None,
) -> torch.FloatTensor:
# shape (batch_size, sequence_length, 1)
Expand All @@ -72,10 +91,4 @@ def forward(
# Shape: (batch_size, num_spans, embedding_dim)
attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

if span_indices_mask is not None:
# Above we were masking the widths of spans with respect to the max
# span width in the batch. Here we are masking the spans which were
# originally passed in as padding.
return attended_text_embeddings * span_indices_mask.unsqueeze(-1)

return attended_text_embeddings
Loading

0 comments on commit b533733

Please sign in to comment.