Dataset readers for masked language modeling and next-token-language-modeling (#3147)

* Adding language modeling readers
* Added test
* Cleanup
* Test passes
* NextTokenLm test
* doc
* Revert accidental changes
* Pylint, mypy
* doc
* mypy
* Change todos to runtime errors
* pylint...
1 parent 1eaa1ff, commit 370d512
Showing 8 changed files with 282 additions and 0 deletions.
allennlp/data/dataset_readers/masked_language_modeling.py (101 additions, 0 deletions)
@@ -0,0 +1,101 @@
from typing import Dict, List
import logging

from overrides import overrides

from allennlp.data.instance import Instance
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data.tokenizers import Token, WordTokenizer
from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.token_indexers.token_indexer import TokenIndexer
from allennlp.data.fields import IndexField, Field, ListField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@DatasetReader.register("masked_language_modeling")
class MaskedLanguageModelingReader(DatasetReader):
    """
    Reads a text file and converts it into a ``Dataset`` suitable for training a masked language
    model.

    The :class:`Field` s that we create are the following: an input ``TextField``, a mask position
    ``ListField[IndexField]``, and a target token ``TextField`` (the target tokens aren't a single
    string of text, but we use a ``TextField`` so we can index the target tokens the same way as
    our input, typically with a single ``PretrainedTransformerIndexer``). The mask position and
    target token lists are the same length.

    NOTE: This is not fully functional! It was written to put together a demo for interpreting and
    attacking masked language modeling, not for actually training anything. ``text_to_instance``
    is functional, but ``_read`` is not. To make this fully functional, you would want some
    sampling strategies for picking the locations for [MASK] tokens, and probably a bunch of
    efficiency / multi-processing stuff.

    Parameters
    ----------
    tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
        We use this ``Tokenizer`` for the text. See :class:`Tokenizer`.
    token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
        We use this to define the input representation for the text, and to get ids for the mask
        targets. See :class:`TokenIndexer`.
    """
    def __init__(self,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer(word_splitter=JustSpacesWordSplitter())
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    @overrides
    def _read(self, file_path: str):
        import sys
        # You can call pytest with either `pytest` or `py.test`.
        if 'test' not in sys.argv[0]:
            raise RuntimeError('_read is only implemented for unit tests at the moment')
        with open(file_path, "r") as text_file:
            for sentence in text_file:
                tokens = self._tokenizer.tokenize(sentence)
                target = tokens[0].text
                tokens[0] = Token('[MASK]')
                yield self.text_to_instance(sentence, tokens, [target])

    @overrides
    def text_to_instance(self,  # type: ignore
                         sentence: str = None,
                         tokens: List[Token] = None,
                         targets: List[str] = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentence : ``str``, optional
            A sentence containing [MASK] tokens that should be filled in by the model. This input
            is superseded and ignored if ``tokens`` is given.
        tokens : ``List[Token]``, optional
            An already-tokenized sentence containing some number of [MASK] tokens to be predicted.
        targets : ``List[str]``, optional
            Contains the target tokens to be predicted. The length of this list should be the same
            as the number of [MASK] tokens in the input.
        """
        if not tokens:
            tokens = self._tokenizer.tokenize(sentence)
        input_field = TextField(tokens, self._token_indexers)
        mask_positions = []
        for i, token in enumerate(tokens):
            if token.text == '[MASK]':
                mask_positions.append(i)
        if not mask_positions:
            raise ValueError("No [MASK] tokens found!")
        if targets and len(targets) != len(mask_positions):
            raise ValueError(f"Found {len(mask_positions)} mask tokens and {len(targets)} targets")
        mask_position_field = ListField([IndexField(i, input_field) for i in mask_positions])
        # TODO(mattg): there's a problem if the targets get split into multiple word pieces...
        fields: Dict[str, Field] = {'tokens': input_field, 'mask_positions': mask_position_field}
        if targets is not None:
            target_field = TextField([Token(target) for target in targets], self._token_indexers)
            fields['target_ids'] = target_field
        return Instance(fields)
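As a point of reference (not part of the diff), a minimal usage sketch of this reader follows. The sentence and target below are made up for illustration, and the import assumes the reader is exported from allennlp.data.dataset_readers, as the tests below do.

# Illustrative sketch only, not part of this commit.
from allennlp.data.dataset_readers import MaskedLanguageModelingReader

reader = MaskedLanguageModelingReader()
instance = reader.text_to_instance(sentence='The quick [MASK] fox .', targets=['brown'])

# 'tokens' holds the input, 'mask_positions' the [MASK] indices, 'target_ids' the gold tokens.
print([t.text for t in instance['tokens']])                    # ['The', 'quick', '[MASK]', 'fox', '.']
print([i.sequence_index for i in instance['mask_positions']])  # [2]
print([t.text for t in instance['target_ids']])                # ['brown']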
allennlp/data/dataset_readers/next_token_lm.py (74 additions, 0 deletions)
@@ -0,0 +1,74 @@
from typing import Dict, List
import logging

from overrides import overrides

from allennlp.data.instance import Instance
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data.tokenizers import Token, WordTokenizer
from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.token_indexers.token_indexer import TokenIndexer
from allennlp.data.fields import Field, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@DatasetReader.register("next_token_lm")
class NextTokenLmReader(DatasetReader):
    """
    Creates ``Instances`` suitable for use in predicting a single next token using a language
    model. The :class:`Field` s that we create are the following: an input ``TextField`` and a
    target token ``TextField`` (we only ever have a single token, but we use a ``TextField`` so we
    can index it the same way as our input, typically with a single
    ``PretrainedTransformerIndexer``).

    NOTE: This is not fully functional! It was written to put together a demo for interpreting and
    attacking language models, not for actually training anything. It would be a really bad idea
    to use this setup for training language models, as it would be incredibly inefficient. The
    only purpose of this class is for a demo.

    Parameters
    ----------
    tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
        We use this ``Tokenizer`` for the text. See :class:`Tokenizer`.
    token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
        We use this to define the input representation for the text, and to get ids for the mask
        targets. See :class:`TokenIndexer`.
    """
    def __init__(self,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer(word_splitter=JustSpacesWordSplitter())
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    @overrides
    def _read(self, file_path: str):
        import sys
        # You can call pytest with either `pytest` or `py.test`.
        if 'test' not in sys.argv[0]:
            raise RuntimeError('_read is only implemented for unit tests. You should not actually '
                               'try to train or evaluate a language model with this code.')
        with open(file_path, "r") as text_file:
            for sentence in text_file:
                tokens = self._tokenizer.tokenize(sentence)
                target = 'the'
                yield self.text_to_instance(sentence, tokens, target)

    @overrides
    def text_to_instance(self,  # type: ignore
                         sentence: str = None,
                         tokens: List[Token] = None,
                         target: str = None) -> Instance:
        # pylint: disable=arguments-differ
        if not tokens:
            tokens = self._tokenizer.tokenize(sentence)
        input_field = TextField(tokens, self._token_indexers)
        fields: Dict[str, Field] = {'tokens': input_field}
        if target:
            fields['target_ids'] = TextField([Token(target)], self._token_indexers)
        return Instance(fields)
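Again as a point of reference (not part of the diff), a minimal sketch of the next-token reader with a made-up sentence and target; it only produces the 'tokens' and 'target_ids' fields described above.

# Illustrative sketch only, not part of this commit.
from allennlp.data.dataset_readers import NextTokenLmReader

reader = NextTokenLmReader()
instance = reader.text_to_instance(sentence='The quick brown', target='fox')

print([t.text for t in instance['tokens']])      # ['The', 'quick', 'brown']
print([t.text for t in instance['target_ids']])  # ['fox']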
allennlp/tests/data/dataset_readers/masked_language_modeling_test.py (46 additions, 0 deletions)
@@ -0,0 +1,46 @@
# pylint: disable=no-self-use,invalid-name
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.dataset_readers import MaskedLanguageModelingReader
from allennlp.data import Vocabulary
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.token_indexers import PretrainedTransformerIndexer

class TestMaskedLanguageModelingDatasetReader(AllenNlpTestCase):
    def test_text_to_instance_with_basic_tokenizer_and_indexer(self):
        reader = MaskedLanguageModelingReader()

        vocab = Vocabulary()
        vocab.add_tokens_to_namespace(['This', 'is', 'a', '[MASK]', 'token', '.'], 'tokens')

        instance = reader.text_to_instance(sentence='This is a [MASK] token .', targets=['This'])
        assert [t.text for t in instance['tokens']] == ['This', 'is', 'a', '[MASK]', 'token', '.']
        assert [i.sequence_index for i in instance['mask_positions']] == [3]
        assert [t.text for t in instance['target_ids']] == ['This']

        instance.index_fields(vocab)
        tensor_dict = instance.as_tensor_dict(instance.get_padding_lengths())
        assert tensor_dict.keys() == {'tokens', 'mask_positions', 'target_ids'}
        assert tensor_dict['tokens']['tokens'].numpy().tolist() == [2, 3, 4, 5, 6, 7]
        assert tensor_dict['target_ids']['tokens'].numpy().tolist() == [2]
        assert tensor_dict['mask_positions'].numpy().tolist() == [[3]]

    def test_text_to_instance_with_bert_tokenizer_and_indexer(self):
        tokenizer = PretrainedTransformerTokenizer('bert-base-cased', do_lowercase=False)
        token_indexer = PretrainedTransformerIndexer('bert-base-cased', do_lowercase=False)
        reader = MaskedLanguageModelingReader(tokenizer, {'bert': token_indexer})
        instance = reader.text_to_instance(sentence='This is AllenNLP [MASK] token .',
                                           targets=['This'])
        assert [t.text for t in instance['tokens']] == ['[CLS]', 'This', 'is', 'Allen', '##NL',
                                                        '##P', '[MASK]', 'token', '.', '[SEP]']
        assert [i.sequence_index for i in instance['mask_positions']] == [6]
        assert [t.text for t in instance['target_ids']] == ['This']

        vocab = Vocabulary()
        instance.index_fields(vocab)
        tensor_dict = instance.as_tensor_dict(instance.get_padding_lengths())
        assert tensor_dict.keys() == {'tokens', 'mask_positions', 'target_ids'}
        bert_token_ids = tensor_dict['tokens']['bert'].numpy().tolist()
        target_ids = tensor_dict['target_ids']['bert'].numpy().tolist()
        # I don't know what wordpiece id BERT is going to assign to 'This', but it at least should
        # be the same between the input and the target.
        assert target_ids[0] == bert_token_ids[1]
Next-token LM reader test (43 additions, 0 deletions)
@@ -0,0 +1,43 @@
# pylint: disable=no-self-use,invalid-name
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.dataset_readers import NextTokenLmReader
from allennlp.data import Vocabulary
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.token_indexers import PretrainedTransformerIndexer

class TestNextTokenLmReader(AllenNlpTestCase):
    def test_text_to_instance_with_basic_tokenizer_and_indexer(self):
        reader = NextTokenLmReader()

        vocab = Vocabulary()
        vocab.add_tokens_to_namespace(['This', 'is', 'a'], 'tokens')

        instance = reader.text_to_instance(sentence='This is a', target='This')
        assert [t.text for t in instance['tokens']] == ['This', 'is', 'a']
        assert [t.text for t in instance['target_ids']] == ['This']

        instance.index_fields(vocab)
        tensor_dict = instance.as_tensor_dict(instance.get_padding_lengths())
        assert tensor_dict.keys() == {'tokens', 'target_ids'}
        assert tensor_dict['tokens']['tokens'].numpy().tolist() == [2, 3, 4]
        assert tensor_dict['target_ids']['tokens'].numpy().tolist() == [2]

    def test_text_to_instance_with_bert_tokenizer_and_indexer(self):
        tokenizer = PretrainedTransformerTokenizer('bert-base-cased', do_lowercase=False)
        token_indexer = PretrainedTransformerIndexer('bert-base-cased', do_lowercase=False)
        reader = NextTokenLmReader(tokenizer, {'bert': token_indexer})
        instance = reader.text_to_instance(sentence='AllenNLP is very',
                                           target='very')
        assert [t.text for t in instance['tokens']] == ['[CLS]', 'Allen', '##NL', '##P', 'is',
                                                        'very', '[SEP]']
        assert [t.text for t in instance['target_ids']] == ['very']

        vocab = Vocabulary()
        instance.index_fields(vocab)
        tensor_dict = instance.as_tensor_dict(instance.get_padding_lengths())
        assert tensor_dict.keys() == {'tokens', 'target_ids'}
        bert_token_ids = tensor_dict['tokens']['bert'].numpy().tolist()
        target_ids = tensor_dict['target_ids']['bert'].numpy().tolist()
        # I don't know what wordpiece id BERT is going to assign to 'very', but it at least should
        # be the same between the input and the target.
        assert target_ids[0] == bert_token_ids[5]
doc/api/allennlp.data.dataset_readers.masked_language_modeling.rst (7 additions, 0 deletions)
@@ -0,0 +1,7 @@
allennlp.data.dataset_readers.masked_language_modeling
=======================================================

.. automodule:: allennlp.data.dataset_readers.masked_language_modeling
    :members:
    :undoc-members:
    :show-inheritance:
doc/api/allennlp.data.dataset_readers.next_token_lm.rst (7 additions, 0 deletions)
@@ -0,0 +1,7 @@
allennlp.data.dataset_readers.next_token_lm
============================================

.. automodule:: allennlp.data.dataset_readers.next_token_lm
    :members:
    :undoc-members:
    :show-inheritance: