Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Adding a PretrainedTransformerTokenizer (#3145)
Browse files Browse the repository at this point in the history
* Adding a PretrainedTransformerTokenizer

* pylint

* doc
  • Loading branch information
matt-gardner authored Aug 13, 2019
1 parent f9e2029 commit 217022f
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 6 deletions.
1 change: 1 addition & 0 deletions allennlp/data/tokenizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@

from allennlp.data.tokenizers.tokenizer import Token, Tokenizer
from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
from allennlp.data.tokenizers.sentence_splitter import SentenceSplitter
4 changes: 0 additions & 4 deletions allennlp/data/tokenizers/character_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ def __init__(self,
self._start_tokens.reverse()
self._end_tokens = end_tokens or []

@overrides
def batch_tokenize(self, texts: List[str]) -> List[List[Token]]:
return [self.tokenize(text) for text in texts]

@overrides
def tokenize(self, text: str) -> List[Token]:
if self._lowercase_characters:
Expand Down
64 changes: 64 additions & 0 deletions allennlp/data/tokenizers/pretrained_transformer_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import logging
from typing import List, Tuple

from overrides import overrides
from pytorch_transformers.tokenization_auto import AutoTokenizer

from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer

logger = logging.getLogger(__name__)


@Tokenizer.register("pretrained_transformer")
class PretrainedTransformerTokenizer(Tokenizer):
"""
A ``PretrainedTransformerTokenizer`` uses a model from HuggingFace's
``pytorch_transformers`` library to tokenize some input text. This often means wordpieces
(where ``'AllenNLP is awesome'`` might get split into ``['Allen', '##NL', '##P', 'is',
'awesome']``), but it could also use byte-pair encoding, or some other tokenization, depending
on the pretrained model that you're using.
We take a model name as an input parameter, which we will pass to
``AutoTokenizer.from_pretrained``.
Parameters
----------
model_name : ``str``
The name of the pretrained wordpiece tokenizer to use.
start_tokens : ``List[str]``, optional
If given, these tokens will be added to the beginning of every string we tokenize. We try
to be a little bit smart about defaults here - e.g., if your model name contains ``bert``,
we by default add ``[CLS]`` at the beginning and ``[SEP]`` at the end.
end_tokens : ``List[str]``, optional
If given, these tokens will be added to the end of every string we tokenize.
"""
def __init__(self,
model_name: str,
do_lowercase: bool,
start_tokens: List[str] = None,
end_tokens: List[str] = None) -> None:
if model_name.endswith("-cased") and do_lowercase:
logger.warning("Your pretrained model appears to be cased, "
"but your tokenizer is lowercasing tokens.")
elif model_name.endswith("-uncased") and not do_lowercase:
logger.warning("Your pretrained model appears to be uncased, "
"but your tokenizer is not lowercasing tokens.")
self._tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=do_lowercase)
default_start_tokens, default_end_tokens = _guess_start_and_end_token_defaults(model_name)
self._start_tokens = start_tokens if start_tokens is not None else default_start_tokens
self._end_tokens = end_tokens if end_tokens is not None else default_end_tokens

@overrides
def tokenize(self, text: str) -> List[Token]:
# TODO(mattg): track character offsets. Might be too challenging to do it here, given that
# pytorch-transformers is dealing with the whitespace...
token_strings = self._start_tokens + self._tokenizer.tokenize(text) + self._end_tokens
return [Token(t) for t in token_strings]


def _guess_start_and_end_token_defaults(model_name: str) -> Tuple[List[str], List[str]]:
if 'bert' in model_name:
return (['[CLS]'], ['[SEP]'])
else:
return ([], [])
5 changes: 4 additions & 1 deletion allennlp/data/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ def batch_tokenize(self, texts: List[str]) -> List[List[Token]]:
"""
Batches together tokenization of several texts, in case that is faster for particular
tokenizers.
By default we just do this without batching. Override this in your tokenizer if you have a
good way of doing batched computation.
"""
raise NotImplementedError
return [self.tokenize(text) for text in texts]

def tokenize(self, text: str) -> List[Token]:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# pylint: disable=no-self-use,invalid-name

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

class TestPretrainedTransformerTokenizer(AllenNlpTestCase):
def test_splits_into_wordpieces(self):
tokenizer = PretrainedTransformerTokenizer('bert-base-cased', do_lowercase=False)
sentence = "A, [MASK] AllenNLP sentence."
tokens = [t.text for t in tokenizer.tokenize(sentence)]
expected_tokens = ["[CLS]", "A", ",", "[MASK]", "Allen", "##NL", "##P", "sentence", ".", "[SEP]"]
assert tokens == expected_tokens
9 changes: 8 additions & 1 deletion doc/api/allennlp.data.tokenizers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ allennlp.data.tokenizers
* :ref:`Tokenizer<tokenizer>`
* :ref:`WordTokenizer<word-tokenizer>`
* :ref:`CharacterTokenizer<character-tokenizer>`
* :ref:`PretrainedTransformerTokenizer<pretrained-transformer-tokenizer>`
* :ref:`WordFilter<word-filter>`
* :ref:`WordSplitter<word-splitter>`
* :ref:`WordStemmer<word-stemmer>`
Expand All @@ -36,6 +37,12 @@ allennlp.data.tokenizers
:undoc-members:
:show-inheritance:

.. _pretrained-transformer-tokenizer:
.. automodule:: allennlp.data.tokenizers.pretrained_transformer_tokenizer
:members:
:undoc-members:
:show-inheritance:

.. _word-filter:
.. automodule:: allennlp.data.tokenizers.word_filter
:members:
Expand All @@ -58,4 +65,4 @@ allennlp.data.tokenizers
.. automodule:: allennlp.data.tokenizers.sentence_splitter
:members:
:undoc-members:
:show-inheritance:
:show-inheritance:
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ word2number>=1.1

# To use the BERT model
pytorch-pretrained-bert>=0.6.0
git+git://github.com/huggingface/pytorch-transformers.git@a7b4cfe9194bf93c7044a42c9f1281260ce6279e

# For caching processed data
jsonpickle
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
'sqlparse>=0.2.4',
'word2number>=1.1',
'pytorch-pretrained-bert>=0.6.0',
'pytorch-transformers @ https://api.github.com/repos/huggingface/pytorch-transformers/tarball/a7b4cfe9194bf93c7044a42c9f1281260ce6279e',
'jsonpickle',
],
entry_points={
Expand Down

0 comments on commit 217022f

Please sign in to comment.