
Commit

Bert srl (#2854)
* initial working dataset reader for bert srl

* tests for everything, all good

* pylint, mypy

* clean up after monkeypatch

* docs, rename model

* import

* get names the right way around

* shift everything over to the srl reader

* docstring, update model to not take bert_dim param

* I love to lint

* sneaky configuration test

* joel's comments
DeNeutoy authored Jun 7, 2019
1 parent 5b2066b commit 5f37783
Showing 10 changed files with 535 additions and 11 deletions.
152 changes: 145 additions & 7 deletions allennlp/data/dataset_readers/semantic_role_labeling.py
@@ -1,7 +1,8 @@
import logging
from typing import Dict, List, Iterable
from typing import Dict, List, Iterable, Tuple, Any

from overrides import overrides
from pytorch_pretrained_bert.tokenization import BertTokenizer

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
@@ -15,6 +16,88 @@
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def _convert_tags_to_wordpiece_tags(tags: List[str], offsets: List[int]) -> List[str]:
"""
Converts a series of BIO tags to account for a wordpiece tokenizer,
extending/modifying BIO tags where appropriate to deal with words which
are split into multiple wordpieces by the tokenizer.
This is only used if you pass a `bert_model_name` to the dataset reader below.
Parameters
----------
tags : `List[str]`
        The BIO formatted tags to convert to wordpiece-level BIO tags.
offsets : `List[int]`
The wordpiece offsets.
Returns
-------
The new BIO tags.
"""
    # Account for the fact that the offsets are with respect to the
    # additional [CLS] token at the start.
offsets = [x - 1 for x in offsets]
new_tags = []
j = 0
for i, offset in enumerate(offsets):
tag = tags[i]
is_o = tag == "O"
is_start = True
while j < offset:
if is_o:
new_tags.append("O")

elif tag.startswith("I"):
new_tags.append(tag)

elif is_start and tag.startswith("B"):
new_tags.append(tag)
is_start = False

elif tag.startswith("B"):
_, label = tag.split("-", 1)
new_tags.append("I-" + label)
j += 1

# Add O tags for cls and sep tokens.
return ["O"] + new_tags + ["O"]


def _convert_verb_indices_to_wordpiece_indices(verb_indices: List[int], offsets: List[int]): # pylint: disable=invalid-name
"""
    Converts binary verb indicators to account for a wordpiece tokenizer,
    extending/modifying the indicators where appropriate to deal with words
    which are split into multiple wordpieces by the tokenizer.
This is only used if you pass a `bert_model_name` to the dataset reader below.
Parameters
----------
verb_indices : `List[int]`
The binary verb indicators, 0 for not a verb, 1 for verb.
offsets : `List[int]`
The wordpiece offsets.
Returns
-------
The new verb indices.
"""
    # Account for the fact that the offsets are with respect to the
    # additional [CLS] token at the start.
offsets = [x - 1 for x in offsets]
j = 0
new_verb_indices = []
for i, offset in enumerate(offsets):
indicator = verb_indices[i]
while j < offset:
new_verb_indices.append(indicator)
j += 1

# Add 0 indicators for cls and sep tokens.
return [0] + new_verb_indices + [0]
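
# Continuing the illustration above (same hypothetical offsets of [2, 4, 5]),
# with "slept" as the predicate:
#
#     _convert_verb_indices_to_wordpiece_indices([0, 0, 1], [2, 4, 5])
#     # -> [0, 0, 0, 0, 1, 0]
#
# Each indicator is repeated once per wordpiece of its token, with extra zeros
# for the [CLS] and [SEP] positions.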


@DatasetReader.register("srl")
class SrlReader(DatasetReader):
"""
@@ -37,6 +120,11 @@ class SrlReader(DatasetReader):
domain_identifier: ``str``, (default = None)
A string denoting a sub-domain of the Ontonotes 5.0 dataset to use. If present, only
conll files under paths containing this domain identifier will be processed.
bert_model_name : ``Optional[str]``, (default = None)
        The BERT model to be wrapped. If you specify a bert_model_name here, then we will
        assume you want to use BERT throughout; we will use the BERT wordpiece tokenizer,
and will expand your tags and verb indicators accordingly. If not,
the tokens will be indexed as normal with the token_indexers.
Returns
-------
@@ -46,11 +134,40 @@ class SrlReader(DatasetReader):
def __init__(self,
token_indexers: Dict[str, TokenIndexer] = None,
domain_identifier: str = None,
lazy: bool = False) -> None:
lazy: bool = False,
bert_model_name: str = None) -> None:
super().__init__(lazy)
self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
self._domain_identifier = domain_identifier

if bert_model_name is not None:
self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
self.lowercase_input = "uncased" in bert_model_name
else:
self.bert_tokenizer = None
self.lowercase_input = False

def _wordpiece_tokenize_input(self, tokens: List[str]) -> Tuple[List[str], List[int]]:
"""
        Converts a list of tokens to wordpiece tokens and offsets, and adds the
        BERT [CLS] and [SEP] tokens to the beginning and end of the sentence.
"""
word_piece_tokens: List[str] = []
offsets = []
cumulative = 0
for token in tokens:
if self.lowercase_input:
token = token.lower()
word_pieces = self.bert_tokenizer.wordpiece_tokenizer.tokenize(token)
cumulative += len(word_pieces)
offsets.append(cumulative)
word_piece_tokens.extend(word_pieces)

wordpieces = ["[CLS]"] + word_piece_tokens + ["[SEP]"]

offsets = [x + 1 for x in offsets]
return wordpieces, offsets
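
    # Sketch of what this helper returns, under the same hypothetical split of
    # "playwright" into "play" and "##wright" (an assumed tokenization, shown
    # only for illustration):
    #
    #     reader = SrlReader(bert_model_name="bert-base-uncased")
    #     reader._wordpiece_tokenize_input(["The", "playwright", "slept"])
    #     # -> (["[CLS]", "the", "play", "##wright", "slept", "[SEP]"], [2, 4, 5])
    #
    # Each offset points one past the last wordpiece of the corresponding token in
    # the [CLS]-prefixed sequence, which is what the two converters above expect.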

@overrides
def _read(self, file_path: str):
# if `file_path` is a URL, redirect to the cache
@@ -95,19 +212,40 @@ def text_to_instance(self, # type: ignore
to find arguments for.
"""
# pylint: disable=arguments-differ
metadata_dict: Dict[str, Any] = {}
if self.bert_tokenizer is not None:
wordpieces, offsets = self._wordpiece_tokenize_input([t.text for t in tokens])
new_verbs = _convert_verb_indices_to_wordpiece_indices(verb_label, offsets)
metadata_dict["offsets"] = offsets
# In order to override the indexing mechanism, we need to set the `text_id`
# attribute directly. This causes the indexing to use this id.
text_field = TextField([Token(t, text_id=self.bert_tokenizer.vocab[t]) for t in wordpieces],
token_indexers=self._token_indexers)
verb_indicator = SequenceLabelField(new_verbs, text_field)

else:
text_field = TextField(tokens, token_indexers=self._token_indexers)
verb_indicator = SequenceLabelField(verb_label, text_field)

fields: Dict[str, Field] = {}
text_field = TextField(tokens, token_indexers=self._token_indexers)
fields['tokens'] = text_field
fields['verb_indicator'] = SequenceLabelField(verb_label, text_field)
fields['verb_indicator'] = verb_indicator

if all([x == 0 for x in verb_label]):
verb = None
else:
verb = tokens[verb_label.index(1)].text
metadata_dict = {"words": [x.text for x in tokens],
"verb": verb}

metadata_dict["words"] = [x.text for x in tokens]
metadata_dict["verb"] = verb

if tags:
fields['tags'] = SequenceLabelField(tags, text_field)
if self.bert_tokenizer is not None:
new_tags = _convert_tags_to_wordpiece_tags(tags, offsets)
fields['tags'] = SequenceLabelField(new_tags, text_field)
else:
fields['tags'] = SequenceLabelField(tags, text_field)
metadata_dict["gold_tags"] = tags

fields["metadata"] = MetadataField(metadata_dict)
return Instance(fields)
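
A minimal sketch of how the new BERT path fits together, assuming the reader's existing text_to_instance(tokens, verb_label, tags=None) signature, that SrlReader is importable from allennlp.data.dataset_readers as in previous releases, and that a bert-base-uncased vocabulary is available (token texts and wordpiece splits are illustrative):

from allennlp.data.dataset_readers import SrlReader
from allennlp.data.tokenizers import Token

reader = SrlReader(bert_model_name="bert-base-uncased")
tokens = [Token(t) for t in ["The", "playwright", "slept"]]
verb_label = [0, 0, 1]              # "slept" is the predicate
tags = ["B-ARG0", "I-ARG0", "B-V"]
instance = reader.text_to_instance(tokens, verb_label, tags)
# The "tokens" field now holds wordpieces indexed directly through BERT's vocabulary,
# "verb_indicator" and "tags" are expanded to wordpiece length, and the "metadata"
# field keeps the original words, the verb, and the gold tags.
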
1 change: 1 addition & 0 deletions allennlp/models/__init__.py
@@ -32,3 +32,4 @@
from allennlp.models.bidirectional_lm import BidirectionalLanguageModel
from allennlp.models.language_model import LanguageModel
from allennlp.models.basic_classifier import BasicClassifier
from allennlp.models.srl_bert import SrlBert
