This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Use the new tokenizers #3868

Merged: 40 commits, May 12, 2020
Commits (40)
592362b
Use the new tokenizers
dirkgr Feb 28, 2020
bc43649
Merge branch 'master' into Tokenizers
dirkgr Feb 28, 2020
2742a6e
Fix various problems with the tests
dirkgr Feb 28, 2020
02d9a80
Merge branch 'Tokenizers' of /~https://github.com/dirkgr/allennlp into …
dirkgr Feb 28, 2020
5d5c01d
Merge branch 'master' into Tokenizers
dirkgr Feb 28, 2020
d1b21d5
Matching implementation of intra_word_tokenize
dirkgr Mar 2, 2020
163568d
Update some of the tests
dirkgr Mar 2, 2020
a5be9d6
Merge branch 'Tokenizers' of /~https://github.com/dirkgr/allennlp into …
dirkgr Mar 2, 2020
45aa194
Merge remote-tracking branch 'origin/master' into Tokenizers
dirkgr Apr 14, 2020
aecff02
Fix changed API
dirkgr Apr 14, 2020
b7cd79f
Workaround for an API not implemented
dirkgr Apr 14, 2020
3209466
Fix some more tests, remove the tokenize_sentence_pair function
dirkgr Apr 14, 2020
125f44c
Merge branch 'master' into Tokenizers
dirkgr Apr 14, 2020
bca4b70
Merge branch 'master' into Tokenizers
dirkgr Apr 15, 2020
51775be
Merge branch 'master' into Tokenizers
dirkgr May 5, 2020
782f027
Makes all the tests succeed when run against the latest HF master
dirkgr May 5, 2020
b2cf84c
Depend on latest huggingface master
dirkgr May 5, 2020
39281bd
Merge remote-tracking branch 'origin/master' into Tokenizers
dirkgr May 5, 2020
f626c84
Formatting
dirkgr May 6, 2020
8ebd7e2
Make the tokenizers work even when we don't get a fast one
dirkgr May 6, 2020
cd046c6
Fall back to the old method of calculating offsets
dirkgr May 6, 2020
de59f3c
Incredibly, this passes tests
dirkgr May 8, 2020
4d2b578
Merge remote-tracking branch 'origin/master' into Tokenizers
dirkgr May 8, 2020
64ff7c5
Depend on a released version of transformers
dirkgr May 8, 2020
38117bc
Formatting
dirkgr May 8, 2020
450adf1
Update allennlp/data/tokenizers/pretrained_transformer_tokenizer.py
dirkgr May 8, 2020
cb7f41f
Fix copy and paste error
dirkgr May 8, 2020
9cb142a
Be more flexible with the transformers version
dirkgr May 8, 2020
3fe2587
Merge branch 'Tokenizers' of /~https://github.com/dirkgr/allennlp into …
dirkgr May 8, 2020
f43f929
Productivity through formatting
dirkgr May 8, 2020
196925b
Merge branch 'master' into Tokenizers
dirkgr May 8, 2020
3ff38c5
Merge remote-tracking branch 'origin/master' into Tokenizers
dirkgr May 9, 2020
2b14550
`token.text` is now the word piece again
dirkgr May 9, 2020
3f89d99
New way of doing pairs
dirkgr May 11, 2020
617f5c7
Formatting
dirkgr May 12, 2020
450f0e2
Slight refactoring
dirkgr May 12, 2020
df6f7ec
Refactoring
dirkgr May 12, 2020
b6db7ca
Merge remote-tracking branch 'origin/master' into Tokenizers
dirkgr May 12, 2020
1163b66
Formatting and mypy
dirkgr May 12, 2020
fced896
Merge branch 'master' into Tokenizers
dirkgr May 12, 2020
4 changes: 2 additions & 2 deletions allennlp/common/util.py
@@ -519,9 +519,9 @@ def sanitize_wordpiece(wordpiece: str) -> str:
     if wordpiece.startswith("##"):
         return wordpiece[2:]
     elif wordpiece.startswith("Ġ"):
-        return wordpiece[1:]
+        return wordpiece.replace("Ġ", " ")
Contributor:
Changing this method might actually interact poorly with the demo. You'd be adding spaces in places where the demo isn't expecting them. Not sure how much of a difference this makes, though.

Member Author:
Huggingface has also changed its tack on this. At some point the argument was that 'Ġ' is just the encoding for a space, and the space is part of the token, so I made it this way. But in the latest huggingface, the space is no longer part of the token. I'm hoping for clarification from @mfuntowicz on the matter, but if in doubt, I'll make this match the current behavior.

I don't know if the demo is using sanitize_wordpiece() though. It's not a commonly used call as far as I know.

Contributor:
Oh, sorry, you're right. I thought sanitize_wordpiece was called by sanitize (which is used in predictors), but it's not.

Member Author:
I've changed this back.

     elif wordpiece.startswith("▁"):
-        return wordpiece[1:]
+        return wordpiece.replace("▁", " ")
     else:
         return wordpiece

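For context, here is a minimal sketch of how sanitize_wordpiece behaves in the version the author says he changed back to, i.e. stripping the word-piece markers rather than replacing them with spaces. It mirrors allennlp/common/util.py as discussed above but is an illustration, not the exact merged code.

def sanitize_wordpiece(wordpiece: str) -> str:
    # Strip the word-piece markers used by BERT ("##"), GPT-2/RoBERTa ("Ġ"),
    # and SentencePiece ("▁") tokenizers; anything else passes through.
    if wordpiece.startswith("##"):
        return wordpiece[2:]
    elif wordpiece.startswith("Ġ"):
        return wordpiece[1:]
    elif wordpiece.startswith("▁"):
        return wordpiece[1:]
    else:
        return wordpiece

assert sanitize_wordpiece("##ing") == "ing"
assert sanitize_wordpiece("Ġgreat") == "great"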
@@ -49,13 +49,14 @@ def __init__(
         self._tokenizer = self._allennlp_tokenizer.tokenizer
         self._added_to_vocabulary = False

-        self._num_added_start_tokens = self._allennlp_tokenizer.num_added_start_tokens
-        self._num_added_end_tokens = self._allennlp_tokenizer.num_added_end_tokens
+        self._num_added_start_tokens = len(self._allennlp_tokenizer.single_sequence_start_tokens)
+        self._num_added_end_tokens = len(self._allennlp_tokenizer.single_sequence_end_tokens)

         self._max_length = max_length
         if self._max_length is not None:
+            num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1
             self._effective_max_length = (  # we need to take into account special tokens
-                self._max_length - self._tokenizer.num_added_tokens()
+                self._max_length - num_added_tokens
             )
             if self._effective_max_length <= 0:
                 raise ValueError(
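The len(self._allennlp_tokenizer.tokenize("a")) - 1 line above counts how many special tokens the model adds around a single sequence: tokenize a one-piece input and subtract that one piece. A minimal sketch of the same trick against the HuggingFace transformers API directly; the model name and the 512-token budget are arbitrary examples, not taken from the diff.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
# Encoding a single word piece with special tokens and subtracting one
# leaves only the added markers ([CLS] and [SEP] for BERT, so 2 here).
num_added_tokens = len(tokenizer.encode("a", add_special_tokens=True)) - 1
effective_max_length = 512 - num_added_tokens  # 510 for a 512-token budget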
425 changes: 269 additions & 156 deletions allennlp/data/tokenizers/pretrained_transformer_tokenizer.py

Large diffs are not rendered by default.
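The bulk of the PR lives in this file, which is not rendered here. Going by the commit messages ("Make the tokenizers work even when we don't get a fast one", "Fall back to the old method of calculating offsets"), the new code prefers the fast Rust-backed HuggingFace tokenizers, which can report character offsets directly, and falls back to recomputing offsets when only a slow tokenizer is available. A rough sketch of that split, assuming the transformers API, not the actual implementation:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)
text = "AllenNLP is great!"

if tokenizer.is_fast:
    # Fast tokenizers can return (start, end) character offsets per word piece.
    encoded = tokenizer.encode_plus(text, return_offsets_mapping=True)
    offsets = encoded["offset_mapping"]
else:
    # Slow tokenizers cannot, so offsets have to be recovered afterwards,
    # e.g. by matching the sanitized word pieces back against the text.
    encoded = tokenizer.encode_plus(text)
    offsets = None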

36 changes: 35 additions & 1 deletion allennlp/data/tokenizers/tokenizer.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 import logging

 from allennlp.common import Registrable
@@ -45,3 +45,37 @@ def tokenize(self, text: str) -> List[Token]:
        tokens : `List[Token]`
        """
        raise NotImplementedError

    def add_special_tokens(
        self, tokens1: List[Token], tokens2: Optional[List[Token]] = None
    ) -> List[Token]:
        """
        Adds special tokens to tokenized text. These are tokens like [CLS] or [SEP].

        Not all tokenizers do this. The default is to just return the tokens unchanged.

        # Parameters

        tokens1 : `List[Token]`
            The list of tokens to add special tokens to.
        tokens2 : `Optional[List[Token]]`
            An optional second list of tokens. This will be concatenated with `tokens1`.
            Special tokens will be added as appropriate.

        # Returns
        tokens : `List[Token]`
            The combined list of tokens, with special tokens added.
        """
        return tokens1 + (tokens2 or [])

    def special_tokens_for_sequence(self) -> int:
Contributor:
This name makes me expect a return type of List[str], instead of int. Same with below.

Member Author:
I changed the name.

"""
Returns the number of special tokens added for a single sequence.
"""
return 0

    def special_tokens_for_pair(self) -> int:
        """
        Returns the number of special tokens added for a pair of sequences.
        """
        return 0
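A short usage sketch of the new add_special_tokens hook with the pretrained transformer tokenizer. The add_special_tokens=False constructor argument and the exact special tokens are assumptions for illustration, not taken verbatim from this PR.

from allennlp.data.tokenizers import PretrainedTransformerTokenizer

# Tokenize the two pieces without special tokens, then let the tokenizer
# insert them in one place. For BERT the result should look like
# [CLS] AllenNLP is great ! [SEP] Really it is ! [SEP].
tokenizer = PretrainedTransformerTokenizer("bert-base-cased", add_special_tokens=False)
tokens_a = tokenizer.tokenize("AllenNLP is great!")
tokens_b = tokenizer.tokenize("Really it is!")
pair = tokenizer.add_special_tokens(tokens_a, tokens_b)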
@@ -41,8 +41,8 @@ def __init__(self, model_name: str, max_length: int = None) -> None:
         self.output_dim = self.transformer_model.config.hidden_size

         tokenizer = PretrainedTransformerTokenizer(model_name)
-        self._num_added_start_tokens = tokenizer.num_added_start_tokens
-        self._num_added_end_tokens = tokenizer.num_added_end_tokens
+        self._num_added_start_tokens = len(tokenizer.single_sequence_start_tokens)
+        self._num_added_end_tokens = len(tokenizer.single_sequence_end_tokens)
         self._num_added_tokens = self._num_added_start_tokens + self._num_added_end_tokens

     @overrides
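For reference, a small illustration of what the renamed single_sequence_start_tokens and single_sequence_end_tokens attributes are expected to hold. The BERT-specific [CLS]/[SEP] assumption is mine, not something shown in this diff.

from allennlp.data.tokenizers import PretrainedTransformerTokenizer

tokenizer = PretrainedTransformerTokenizer("bert-base-cased")
# For BERT these lists should hold [CLS] and [SEP] respectively,
# so both of the counts computed above come out to 1.
print(len(tokenizer.single_sequence_start_tokens))  # 1
print(len(tokenizer.single_sequence_end_tokens))    # 1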
@@ -35,20 +35,6 @@ def test_as_array_produces_token_sequence_bert_cased(self):
         indexed = indexer.tokens_to_indices(allennlp_tokens, vocab)
         assert indexed["token_ids"] == expected_ids

-    def test_as_array_produces_token_sequence_bert_cased_sentence_pair(self):
Contributor:
Why don't we need these tests anymore?

Member Author:
They are testing tokenize_sentence_pair(), which I removed. I resurrected them to test add_special_tokens().

-        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
-        allennlp_tokenizer = PretrainedTransformerTokenizer("bert-base-cased")
-        indexer = PretrainedTransformerIndexer(model_name="bert-base-cased")
-        default_format = "[CLS] AllenNLP is great! [SEP] Really it is! [SEP]"
-        tokens = tokenizer.tokenize(default_format)
-        expected_ids = tokenizer.convert_tokens_to_ids(tokens)
-        allennlp_tokens = allennlp_tokenizer.tokenize_sentence_pair(
-            "AllenNLP is great!", "Really it is!"
-        )
-        vocab = Vocabulary()
-        indexed = indexer.tokens_to_indices(allennlp_tokens, vocab)
-        assert indexed["token_ids"] == expected_ids

     def test_as_array_produces_token_sequence_roberta(self):
         tokenizer = AutoTokenizer.from_pretrained("roberta-base")
         allennlp_tokenizer = PretrainedTransformerTokenizer("roberta-base")
@@ -63,20 +49,6 @@ def test_as_array_produces_token_sequence_roberta(self):
         indexed = indexer.tokens_to_indices(allennlp_tokens, vocab)
         assert indexed["token_ids"] == expected_ids

-    def test_as_array_produces_token_sequence_roberta_sentence_pair(self):
-        tokenizer = AutoTokenizer.from_pretrained("roberta-base")
-        allennlp_tokenizer = PretrainedTransformerTokenizer("roberta-base")
-        indexer = PretrainedTransformerIndexer(model_name="roberta-base")
-        default_format = "<s> AllenNLP is great! </s> </s> Really it is! </s>"
-        tokens = tokenizer.tokenize(default_format)
-        expected_ids = tokenizer.convert_tokens_to_ids(tokens)
-        allennlp_tokens = allennlp_tokenizer.tokenize_sentence_pair(
-            "AllenNLP is great!", "Really it is!"
-        )
-        vocab = Vocabulary()
-        indexed = indexer.tokens_to_indices(allennlp_tokens, vocab)
-        assert indexed["token_ids"] == expected_ids

     def test_transformers_vocab_sizes(self):
         def check_vocab_size(model_name: str):
             namespace = "tags"
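As noted in the conversation above, the sentence-pair tests were resurrected against add_special_tokens() rather than the removed tokenize_sentence_pair(). A hedged sketch of what such a test could look like follows; the test name and the add_special_tokens=False argument are assumptions, and this is not the exact test added in the PR.

from transformers import AutoTokenizer
from allennlp.data import Vocabulary
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer


def test_token_ids_match_for_sentence_pairs():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    allennlp_tokenizer = PretrainedTransformerTokenizer(
        "bert-base-cased", add_special_tokens=False
    )
    indexer = PretrainedTransformerIndexer(model_name="bert-base-cased")
    default_format = "[CLS] AllenNLP is great! [SEP] Really it is! [SEP]"
    expected_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(default_format))
    # Build the pair with the new add_special_tokens() hook instead of the
    # removed tokenize_sentence_pair().
    allennlp_tokens = allennlp_tokenizer.add_special_tokens(
        allennlp_tokenizer.tokenize("AllenNLP is great!"),
        allennlp_tokenizer.tokenize("Really it is!"),
    )
    indexed = indexer.tokens_to_indices(allennlp_tokens, Vocabulary())
    assert indexed["token_ids"] == expected_ids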