Commit

use starting offsets in the srl model so output is well-formed (allenai#2972)

* use starting offsets in the srl model so output is well-formed

* fix bug in viterbi_decode for constrained start and end sequences

* add failing tests for srl models without viterbi constraint

* fix srl models to use start transitions for bio tagging

* lint

* fix random bug surfaced in openie predictor

* fix more openie tests

* clarify comments, PR feedback
DeNeutoy authored and reiyw committed Nov 12, 2019
1 parent a2ae857 commit 1071a0d
Showing 10 changed files with 208 additions and 28 deletions.
51 changes: 37 additions & 14 deletions allennlp/data/dataset_readers/semantic_role_labeling.py
@@ -35,9 +35,6 @@ def _convert_tags_to_wordpiece_tags(tags: List[str], offsets: List[int]) -> List
-------
The new BIO tags.
"""
# account for the fact the offsets are with respect to
# additional cls token at the start.
offsets = [x - 1 for x in offsets]
new_tags = []
j = 0
for i, offset in enumerate(offsets):
@@ -83,9 +80,6 @@ def _convert_verb_indices_to_wordpiece_indices(verb_indices: List[int], offsets:
-------
The new verb indices.
"""
# account for the fact the offsets are with respect to
# additional cls token at the start.
offsets = [x - 1 for x in offsets]
j = 0
new_verb_indices = []
for i, offset in enumerate(offsets):
@@ -129,7 +123,6 @@ class SrlReader(DatasetReader):
Returns
-------
A ``Dataset`` of ``Instances`` for Semantic Role Labelling.
"""
def __init__(self,
token_indexers: Dict[str, TokenIndexer] = None,
@@ -147,26 +140,56 @@ def __init__(self,
self.bert_tokenizer = None
self.lowercase_input = False

def _wordpiece_tokenize_input(self, tokens: List[str]) -> Tuple[List[str], List[int]]:
def _wordpiece_tokenize_input(self, tokens: List[str]) -> Tuple[List[str], List[int], List[int]]:
"""
Convert a list of tokens to wordpiece tokens and offsets, as well as adding
BERT CLS and SEP tokens to the beginning and end of the sentence.
A slight oddity with this function is that it also returns the wordpiece offsets
corresponding to the _start_ of words as well as the end.
We need both of these offsets (or at least, it's easiest to use both), because we need
to convert the labels to tags using the end_offsets. However, when we are decoding a
BIO sequence inside the SRL model itself, it's important that we use the start_offsets,
because otherwise we might select an ill-formed BIO sequence from the BIO sequence on top of
wordpieces (this happens in the case that a word is split into multiple word pieces,
and then we take the last tag of the word, which might correspond to, e.g., I-V, which
would not be allowed as it is not preceded by a B tag).
For example:
`annotate` will be bert tokenized as ["anno", "##tate"].
If this is tagged as [B-V, I-V] as it should be, we need to select the
_first_ wordpiece label to be the label for the token, because otherwise
we may end up with invalid tag sequences (we cannot start a new tag with an I).
Returns
-------
wordpieces : List[str]
The BERT wordpieces from the words in the sentence.
end_offsets : List[int]
Indices into wordpieces such that `[wordpieces[i] for i in end_offsets]`
results in the end wordpiece of each word being chosen.
start_offsets : List[int]
Indices into wordpieces such that `[wordpieces[i] for i in start_offsets]`
results in the start wordpiece of each word being chosen.
"""
word_piece_tokens: List[str] = []
offsets = []
end_offsets = []
start_offsets = []
cumulative = 0
for token in tokens:
if self.lowercase_input:
token = token.lower()
word_pieces = self.bert_tokenizer.wordpiece_tokenizer.tokenize(token)
start_offsets.append(cumulative + 1)
cumulative += len(word_pieces)
offsets.append(cumulative)
end_offsets.append(cumulative)
word_piece_tokens.extend(word_pieces)

wordpieces = ["[CLS]"] + word_piece_tokens + ["[SEP]"]

offsets = [x + 1 for x in offsets]
return wordpieces, offsets
return wordpieces, end_offsets, start_offsets

@overrides
def _read(self, file_path: str):
@@ -214,9 +237,9 @@ def text_to_instance(self, # type: ignore
# pylint: disable=arguments-differ
metadata_dict: Dict[str, Any] = {}
if self.bert_tokenizer is not None:
wordpieces, offsets = self._wordpiece_tokenize_input([t.text for t in tokens])
wordpieces, offsets, start_offsets = self._wordpiece_tokenize_input([t.text for t in tokens])
new_verbs = _convert_verb_indices_to_wordpiece_indices(verb_label, offsets)
metadata_dict["offsets"] = offsets
metadata_dict["offsets"] = start_offsets
# In order to override the indexing mechanism, we need to set the `text_id`
# attribute directly. This causes the indexing to use this id.
text_field = TextField([Token(t, text_id=self.bert_tokenizer.vocab[t]) for t in wordpieces],
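As a brief illustration of the new return values (not part of the commit; the wordpiece split below is assumed), the start offsets index the first wordpiece of each word while the end offsets index the last:

# Hypothetical tokenization: "annotate" -> ["anno", "##tate"]; real output depends on the BERT vocab.
pieces_per_word = [["the"], ["anno", "##tate"], ["cat"]]
word_piece_tokens = []
start_offsets, end_offsets = [], []
cumulative = 0
for pieces in pieces_per_word:
    start_offsets.append(cumulative + 1)   # + 1 accounts for the leading [CLS]
    cumulative += len(pieces)
    end_offsets.append(cumulative)
    word_piece_tokens.extend(pieces)
wordpieces = ["[CLS]"] + word_piece_tokens + ["[SEP]"]
assert [wordpieces[i] for i in start_offsets] == ["the", "anno", "cat"]
assert [wordpieces[i] for i in end_offsets] == ["the", "##tate", "cat"]
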
25 changes: 24 additions & 1 deletion allennlp/models/semantic_role_labeler.py
@@ -180,8 +180,10 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
predictions_list = [all_predictions]
all_tags = []
transition_matrix = self.get_viterbi_pairwise_potentials()
start_transitions = self.get_start_transitions()
for predictions, length in zip(predictions_list, sequence_lengths):
max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix)
max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix,
allowed_start_transitions=start_transitions)
tags = [self.vocab.get_token_from_index(x, namespace="labels")
for x in max_likelihood_sequence]
all_tags.append(tags)
@@ -226,6 +228,27 @@ def get_viterbi_pairwise_potentials(self):
transition_matrix[i, j] = float("-inf")
return transition_matrix

def get_start_transitions(self):
"""
In the BIO sequence, we cannot start the sequence with an I-XXX tag.
This transition sequence is passed to viterbi_decode to specify this constraint.
Returns
-------
start_transitions : torch.Tensor
The pairwise potentials between a START token and
the first token of the sequence.
"""
all_labels = self.vocab.get_index_to_token_vocabulary("labels")
num_labels = len(all_labels)

start_transitions = torch.zeros(num_labels)

for i, label in all_labels.items():
if label[0] == "I":
start_transitions[i] = float("-inf")

return start_transitions

def write_to_conll_eval_file(prediction_file: TextIO,
gold_file: TextIO,
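A rough sketch of what get_start_transitions produces for a toy label vocabulary (the label set and indices here are made up, not taken from the model):

import torch

toy_labels = {0: "B-ARG0", 1: "I-ARG0", 2: "B-V", 3: "I-V", 4: "O"}   # hypothetical index -> label map
start_transitions = torch.zeros(len(toy_labels))
for i, label in toy_labels.items():
    if label[0] == "I":
        start_transitions[i] = float("-inf")
# start_transitions is now tensor([0., -inf, 0., -inf, 0.]):
# a sequence may begin with B-* or O, but never with an I-* tag.
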
48 changes: 43 additions & 5 deletions allennlp/models/srl_bert.py
@@ -79,8 +79,9 @@ def forward(self, # type: ignore
A torch tensor representing the sequence of integer gold class labels
of shape ``(batch_size, num_tokens)``
metadata : ``List[Dict[str, Any]]``, optional, (default = None)
metadata containg the original words in the sentence and the verb to compute the
frame for, under 'words' and 'verb' keys, respectively.
metadata containing the original words in the sentence, the verb to compute the
frame for, and start offsets for converting wordpieces back to a sequence of words,
under 'words', 'verb' and 'offsets' keys, respectively.
Returns
-------
@@ -136,6 +137,18 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
Does constrained viterbi decoding on class probabilities output in :func:`forward`. The
constraint simply specifies that the output tags must be a valid BIO sequence. We add a
``"tags"`` key to the dictionary with the result.
NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important; viterbi
decoding produces low quality output if you decode on top of word representations directly,
because the model gets confused by the 'missing' positions (which is sensible as it is trained
to perform tagging on wordpieces, not words).
Secondly, it's important that the indices we use to recover words from the wordpieces are the
start_offsets (i.e., offsets which correspond to using the first wordpiece of words which are
tokenized into multiple wordpieces) as otherwise, we might get an ill-formed BIO sequence
when we select out the word tags from the wordpiece tags. This happens in the case that a word
is split into multiple word pieces, and then we take the last tag of the word, which might
correspond to, e.g., I-V, which would not be allowed as it is not preceded by a B tag.
"""
all_predictions = output_dict['class_probabilities']
sequence_lengths = get_lengths_from_binary_sequence_mask(output_dict["mask"]).data.tolist()
@@ -147,17 +160,19 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
wordpiece_tags = []
word_tags = []
transition_matrix = self.get_viterbi_pairwise_potentials()
start_transitions = self.get_start_transitions()
# **************** Different ********************
# We add in the offsets here so we can compute the un-wordpieced tags.
for predictions, length, offsets in zip(predictions_list,
sequence_lengths,
output_dict["wordpiece_offsets"]):
max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix)
max_likelihood_sequence, _ = viterbi_decode(predictions[:length], transition_matrix,
allowed_start_transitions=start_transitions)
tags = [self.vocab.get_token_from_index(x, namespace="labels")
for x in max_likelihood_sequence]

wordpiece_tags.append(tags)
# Offset due to exclusive end indices.
word_tags.append([tags[i - 1] for i in offsets])
word_tags.append([tags[i] for i in offsets])
output_dict['wordpiece_tags'] = wordpiece_tags
output_dict['tags'] = word_tags
return output_dict
Expand Down Expand Up @@ -199,3 +214,26 @@ def get_viterbi_pairwise_potentials(self):
if i != j and label[0] == 'I' and not previous_label == 'B' + label[1:]:
transition_matrix[i, j] = float("-inf")
return transition_matrix


def get_start_transitions(self):
"""
In the BIO sequence, we cannot start the sequence with an I-XXX tag.
This transition sequence is passed to viterbi_decode to specify this constraint.
Returns
-------
start_transitions : torch.Tensor
The pairwise potentials between a START token and
the first token of the sequence.
"""
all_labels = self.vocab.get_index_to_token_vocabulary("labels")
num_labels = len(all_labels)

start_transitions = torch.zeros(num_labels)

for i, label in all_labels.items():
if label[0] == "I":
start_transitions[i] = float("-inf")

return start_transitions
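
A small worked example (with made-up tags and offsets) of the point in the decode docstring above: reading word tags off the first wordpiece of each word keeps the BIO sequence well-formed, whereas reading them off the last wordpiece can break it:

wordpieces = ["[CLS]", "the", "anno", "##tate", "cat", "[SEP]"]
wordpiece_tags = ["O", "B-ARG0", "B-V", "I-V", "B-ARG1", "O"]
start_offsets = [1, 2, 4]   # first wordpiece of each word
end_offsets = [1, 3, 4]     # last wordpiece of each word
print([wordpiece_tags[i] for i in start_offsets])   # ['B-ARG0', 'B-V', 'B-ARG1'] -- well-formed
print([wordpiece_tags[i] for i in end_offsets])     # ['B-ARG0', 'I-V', 'B-ARG1'] -- I-V with no preceding B-V
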
55 changes: 54 additions & 1 deletion allennlp/nn/util.py
@@ -398,7 +398,9 @@ def masked_flip(padded_sequence: torch.Tensor,

def viterbi_decode(tag_sequence: torch.Tensor,
transition_matrix: torch.Tensor,
tag_observations: Optional[List[int]] = None):
tag_observations: Optional[List[int]] = None,
allowed_start_transitions: torch.Tensor = None,
allowed_end_transitions: torch.Tensor = None):
"""
Perform Viterbi decoding in log space over a sequence given a transition matrix
specifying pairwise (transition) potentials between tags and a matrix of shape
@@ -421,6 +423,14 @@ def viterbi_decode(tag_sequence: torch.Tensor,
other, or those transitions are extremely unlikely. In this situation we log a
warning, but the responsibility for providing self-consistent evidence ultimately
lies with the user.
allowed_start_transitions : torch.Tensor, optional, (default = None)
An optional tensor of shape (num_tags,) describing which tags the START token
may transition *to*. If provided, additional transition constraints will be used for
determining the start element of the sequence.
allowed_end_transitions : torch.Tensor, optional, (default = None)
An optional tensor of shape (num_tags,) describing which tags may transition *to* the
end tag. If provided, additional transition constraints will be used for determining
the end element of the sequence.
Returns
-------
@@ -430,6 +440,37 @@ def viterbi_decode(tag_sequence: torch.Tensor,
The score of the viterbi path.
"""
sequence_length, num_tags = list(tag_sequence.size())

has_start_end_restrictions = allowed_end_transitions is not None or allowed_start_transitions is not None

if has_start_end_restrictions:

if allowed_end_transitions is None:
allowed_end_transitions = torch.zeros(num_tags)
if allowed_start_transitions is None:
allowed_start_transitions = torch.zeros(num_tags)

num_tags = num_tags + 2
new_transition_matrix = torch.zeros(num_tags, num_tags)
new_transition_matrix[:-2, :-2] = transition_matrix

# Start and end transitions are fully defined, but cannot transition between each other.
# pylint: disable=not-callable
allowed_start_transitions = torch.cat([allowed_start_transitions, torch.tensor([-math.inf, -math.inf])])
allowed_end_transitions = torch.cat([allowed_end_transitions, torch.tensor([-math.inf, -math.inf])])
# pylint: enable=not-callable

# First define how we may transition FROM the start and end tags.
new_transition_matrix[-2, :] = allowed_start_transitions
# We cannot transition from the end tag to any tag.
new_transition_matrix[-1, :] = -math.inf

new_transition_matrix[:, -1] = allowed_end_transitions
# We cannot transition to the start tag from any tag.
new_transition_matrix[:, -2] = -math.inf

transition_matrix = new_transition_matrix

if tag_observations:
if len(tag_observations) != sequence_length:
raise ConfigurationError("Observations were provided, but they were not the same length "
@@ -438,6 +479,15 @@ def viterbi_decode(tag_sequence: torch.Tensor,
else:
tag_observations = [-1 for _ in range(sequence_length)]


if has_start_end_restrictions:
tag_observations = [num_tags - 2] + tag_observations + [num_tags - 1]
zero_sentinel = torch.zeros(1, num_tags)
extra_tags_sentinel = torch.ones(sequence_length, 2) * -math.inf
tag_sequence = torch.cat([tag_sequence, extra_tags_sentinel], -1)
tag_sequence = torch.cat([zero_sentinel, tag_sequence, zero_sentinel], 0)
sequence_length = tag_sequence.size(0)

path_scores = []
path_indices = []

Expand Down Expand Up @@ -479,6 +529,9 @@ def viterbi_decode(tag_sequence: torch.Tensor,
viterbi_path.append(int(backward_timestep[viterbi_path[-1]]))
# Reverse the backward path.
viterbi_path.reverse()

if has_start_end_restrictions:
viterbi_path = viterbi_path[1:-1]
return viterbi_path, viterbi_score


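For reference, a hedged usage sketch of the new keyword arguments (emission and transition scores below are placeholders, not real model output):

import torch
from allennlp.nn.util import viterbi_decode

num_tags = 3                                      # e.g. O, B-ARG, I-ARG
emissions = torch.randn(5, num_tags)              # (sequence_length, num_tags) emission scores
transitions = torch.zeros(num_tags, num_tags)     # allow every tag-to-tag transition
start = torch.tensor([0.0, 0.0, float("-inf")])   # forbid starting the sequence on I-ARG
path, score = viterbi_decode(emissions, transitions, allowed_start_transitions=start)
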
8 changes: 4 additions & 4 deletions allennlp/predictors/open_information_extraction.py
@@ -20,7 +20,7 @@ def join_mwp(tags: List[str]) -> List[str]:
for tag in tags:
if "V" in tag:
# Create a continuous 'V' BIO span
prefix, _ = tag.split("-")
prefix, _ = tag.split("-", 1)
if verb_flag:
# Continue a verb label across the different predicate parts
prefix = 'I'
@@ -110,8 +110,8 @@ def merge_overlapping_predictions(tags1: List[str], tags2: List[str]) -> List[st
# spans which predicates' overlap

for tag1, tag2 in zip(tags1, tags2):
label1 = tag1.split("-")[-1]
label2 = tag2.split("-")[-1]
label1 = tag1.split("-", 1)[-1]
label2 = tag2.split("-", 1)[-1]
if (label1 == "V") or (label2 == "V"):
# Construct maximal predicate length -
# add predicate tag if any of the sequence predict it
Expand Down Expand Up @@ -164,7 +164,7 @@ def sanitize_label(label: str) -> str:
labels sometimes having some noise, as parentheses.
"""
if "-" in label:
prefix, suffix = label.split("-")
prefix, suffix = label.split("-", 1)
suffix = suffix.split("(")[-1]
return f"{prefix}-{suffix}"
else:
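A minimal sketch of the failure the maxsplit=1 change guards against, for a label whose suffix itself contains a hyphen (the label is chosen for illustration, not taken from the OpenIE label set):

label = "B-ARGM-LOC"                   # hypothetical tag with a hyphen inside the suffix
# label.split("-") == ['B', 'ARGM', 'LOC'], so two-way unpacking raises ValueError,
# and label.split("-")[-1] keeps only 'LOC' instead of the full suffix.
prefix, suffix = label.split("-", 1)   # ('B', 'ARGM-LOC')
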
17 changes: 15 additions & 2 deletions allennlp/tests/data/dataset_readers/srl_dataset_reader_test.py
@@ -85,6 +85,7 @@ def setUp(self):
def test_convert_tags_to_wordpiece_tags(self):
# pylint: disable=protected-access
offsets = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
offsets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
original = ['B-ARG0', 'I-ARG0', 'I-ARG0', 'B-V', 'B-ARG1', 'I-ARG1',
'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'O']
wordpiece_tags = ['O', 'B-ARG0', 'I-ARG0', 'I-ARG0', 'B-V', 'B-ARG1',
@@ -93,21 +94,33 @@ def test_convert_tags_to_wordpiece_tags(self):
assert converted == wordpiece_tags

offsets = [2, 3, 4, 5, 6, 7, 8, 9, 10, 12]
offsets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11]
converted = _convert_tags_to_wordpiece_tags(original, offsets)
assert converted == ['O', 'B-ARG0', 'I-ARG0', 'I-ARG0', 'B-V', 'B-ARG1',
'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'O']

offsets = [2, 4, 6]
offsets = [1, 3, 5]
original = ["B-ARG", "B-V", "O"]
converted = _convert_tags_to_wordpiece_tags(original, offsets)
assert converted == ['O', 'B-ARG', 'B-V', 'I-V', 'O', 'O', 'O']

offsets = [3, 4, 6]
offsets = [2, 3, 5]
original = ["B-ARG", "I-ARG", "O"]
converted = _convert_tags_to_wordpiece_tags(original, offsets)
assert converted == ['O', 'B-ARG', 'I-ARG', 'I-ARG', 'O', 'O', 'O']
# pylint: enable=protected-access


def test_wordpiece_tokenize_input(self):
wordpieces, offsets, start_offsets = self.reader._wordpiece_tokenize_input( # pylint: disable=protected-access
"This is a sentenceandsomepieces with a reallylongword".split(" "))

assert wordpieces == ['[CLS]', 'this', 'is', 'a', 'sentence', '##ands', '##ome',
'##piece', '##s', 'with', 'a', 'really', '##long', '##word', '[SEP]']
assert [wordpieces[i] for i in offsets] == ['this', 'is', 'a', '##s', 'with', 'a', '##word']
assert [wordpieces[i] for i in start_offsets] == ['this', 'is', 'a', 'sentence', 'with', 'a', 'really']


def test_read_from_file(self):
conll_reader = self.reader
instances = conll_reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'conll_2012' / 'subdomain')
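If helpful, the new reader test can be run on its own through pytest's Python entry point (this invocation is assumed, not part of the commit):

import pytest
pytest.main(["allennlp/tests/data/dataset_readers/srl_dataset_reader_test.py",
             "-k", "wordpiece_tokenize_input"])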