[WIP] Language Modeling of Contiguous Text #2414
@@ -1,97 +1,160 @@
-from typing import Dict
+from typing import Dict, List
 import logging

 from overrides import overrides
+import numpy

+from allennlp.common.checks import ConfigurationError
 from allennlp.common.file_utils import cached_path
-from allennlp.common.tqdm import Tqdm
 from allennlp.data.instance import Instance
 from allennlp.data.tokenizers.tokenizer import Tokenizer
-from allennlp.data.tokenizers import WordTokenizer
+from allennlp.data.tokenizers import Token, WordTokenizer
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
 from allennlp.data.token_indexers.token_indexer import TokenIndexer
-from allennlp.data.fields import TextField
+from allennlp.data.fields import ListField, TextField
 from allennlp.data.token_indexers import SingleIdTokenIndexer


 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


 @DatasetReader.register("language_modeling")
 class LanguageModelingReader(DatasetReader):
     """
-    Reads a text file and converts it into a ``Dataset`` suitable for training a language model.
-
-    Note that there's one issue that needs to be fixed before this is actually usable for language
-    modeling - the way start and end tokens for sentences are handled is not correct; we need to
-    add a sentence splitter before this will be done right.
+    Reads a text file and converts it into a ``Dataset`` suitable for training
+    a language model.

     Parameters
     ----------
-    tokens_per_instance : ``int``, optional (default=``None``)
-        If this is ``None``, we will have each training instance be a single sentence. If this is
-        not ``None``, we will instead take all sentences, including their start and stop tokens,
-        line them up, and split the tokens into groups of this number, for more efficient training
-        of the language model.
+    batch_size : ``int``, optional (default=``20``)
+        Batch size to use in language modeling.
+    truncated_bptt_size : ``int``, optional (default=``35``)
+        The sequence length to use for truncated backpropagation through time.
+    fuzz_truncated_bptt_size : ``bool``, optional (default=``True``)
+        If True, randomly perturb the truncated_bptt_size between batches.
+    bidirectional : ``bool``, optional (default=``False``)
+        If True, generate instances for bidirectional language modeling.
     tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
         We use this ``Tokenizer`` for the text. See :class:`Tokenizer`.
     token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
-        We use this to define the input representation for the text. See :class:`TokenIndexer`.
-        Note that the `output` representation will always be single token IDs - if you've specified
-        a ``SingleIdTokenIndexer`` here, we use the first one you specify. Otherwise, we create
-        one with default parameters.
+        We use this to define the input representation for the text.
+        See :class:`TokenIndexer`.
+    start_tokens : ``List[str]``, optional (default=``None``)
+        These are prepended to each line read by the dataset reader.
+    end_tokens : ``List[str]``, optional (default=``["</S>"]``)
+        These are appended to each line read by the dataset reader.
     """
     def __init__(self,
-                 tokens_per_instance: int = None,
+                 batch_size: int = 20,
+                 truncated_bptt_size: int = 35,
+                 fuzz_truncated_bptt_size: bool = True,
+                 bidirectional: bool = False,
                  tokenizer: Tokenizer = None,
                  token_indexers: Dict[str, TokenIndexer] = None,
-                 lazy: bool = False) -> None:
-        super().__init__(lazy)
+                 start_tokens: List[str] = None,
+                 end_tokens: List[str] = ["</S>"]) -> None:
+        super().__init__(lazy=False)
+        self._batch_size = batch_size
+        if truncated_bptt_size < 2:
+            raise ConfigurationError("truncated_bptt_size cannot be less than 2.")
+        self._truncated_bptt_size = truncated_bptt_size
+        self._fuzz_truncated_bptt_size = fuzz_truncated_bptt_size
+        self._bidirectional = bidirectional
         self._tokenizer = tokenizer or WordTokenizer()
         self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
-        self._tokens_per_instance = tokens_per_instance
-
-        # No matter how you want to represent the input, we'll always represent the output as a
-        # single token id. This code lets you learn a language model that concatenates word
-        # embeddings with character-level encoders, in order to predict the word token that comes
-        # next.
-        self._output_indexer: Dict[str, TokenIndexer] = None
-        for name, indexer in self._token_indexers.items():
-            if isinstance(indexer, SingleIdTokenIndexer):
-                self._output_indexer = {name: indexer}
-                break
-        else:
-            self._output_indexer = {"tokens": SingleIdTokenIndexer()}
+        self._start_tokens = [Token(st) for st in (start_tokens or [])]
+        self._end_tokens = [Token(et) for et in (end_tokens or [])]
+        # Cache the batched tokens we read from data files.
+        self._all_batched_file_tokens = {}

     @overrides
     def _read(self, file_path: str):
-        # if `file_path` is a URL, redirect to the cache
-        file_path = cached_path(file_path)
+        if file_path not in self._all_batched_file_tokens:
+            logger.info('Loading data from %s', file_path)
+            # if `file_path` is a URL, redirect to the cache
+            file_path = cached_path(file_path)

-        with open(file_path, "r") as text_file:
-            instance_strings = text_file.readlines()
+            # Read the contents of the file into one long list of tokens,
+            # adding start and/or end tokens as necessary.
+            file_tokens = []
+            file_tokens.extend(self._start_tokens)
Review comment: Shouldn't this be in the for loop?

Reply: Yup, you're correct (sorry I haven't looked at this PR in awhile). In general, though, language modeling of contiguous text doesn't use any start tokens and only uses eos tokens (e.g., see /~https://github.com/salesforce/awd-lstm-lm/blob/master/data.py#L34-L54 ).

Reply: yeah, although AFAIK ELMo actually used start tokens (/~https://github.com/allenai/allennlp/blob/master/tutorials/how_to/elmo.md#notes-on-statefulness-and-non-determinism).

Reply: Right, but this dataset reader is not for training a LM like ELMo. ELMo is trained on shuffled sentences, hence why it needs a start token for each sentence. This dataset reader is for contiguous text (e.g., books), like the corpora that the OpenAI GPT was trained on / the PTB LM benchmark most folks use. In this setting, people don't typically use start tokens.

Reply: Ah, got it, thanks -- I assumed that stateful (RNN) language models were always trained in contiguous fashion.
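For reference, a minimal standalone sketch of the per-line handling the first comment points at (hypothetical: the function name and the whitespace tokenizer are stand-ins, and the thread above concludes that contiguous-text LMs usually only need end tokens anyway):

from typing import List

def read_file_tokens(file_path: str,
                     start_tokens: List[str],
                     end_tokens: List[str]) -> List[str]:
    # Start/end tokens are added once per line, inside the loop,
    # rather than the start tokens being added once per file.
    file_tokens: List[str] = []
    with open(file_path, "r") as text_file:
        for line in text_file:
            file_tokens.extend(start_tokens)
            file_tokens.extend(line.split())  # stand-in for self._tokenizer.tokenize(line)
            file_tokens.extend(end_tokens)
    return file_tokens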
with open(file_path, "r") as text_file: | ||
for line in text_file: | ||
tokenized_line = self._tokenizer.tokenize(line) | ||
file_tokens.extend(tokenized_line) | ||
file_tokens.extend(self._end_tokens) | ||
|
||
if self._tokens_per_instance is not None: | ||
all_text = " ".join([x.replace("\n", " ").strip() for x in instance_strings]) | ||
tokenized_text = self._tokenizer.tokenize(all_text) | ||
num_tokens = self._tokens_per_instance + 1 | ||
tokenized_strings = [] | ||
logger.info("Creating dataset from all text in file: %s", file_path) | ||
for index in Tqdm.tqdm(range(0, len(tokenized_text) - num_tokens, num_tokens - 1)): | ||
tokenized_strings.append(tokenized_text[index:(index + num_tokens)]) | ||
# Divide file_tokens into batch_size lists | ||
# Work out how we can evenly split the dataset into batch_size parts | ||
total_num_tokens_per_batch = len(file_tokens) // self._batch_size | ||
if total_num_tokens_per_batch == 0: | ||
# TODO (nfliu): figure out if this is the desired behavior | ||
raise ValueError(f"There are {len(file_tokens)} tokens in the file, " | ||
f"but batch size is {self._batch_size}. " | ||
"batch size must be less than or equal to number of " | ||
"tokens in the file.") | ||
|
||
# Trim off the remainder from file_tokens, so we can evenly divide it | ||
# into batch_size lists. | ||
file_tokens_for_even_split = file_tokens[:total_num_tokens_per_batch * | ||
self._batch_size] | ||
# Evenly divide the data into batch_size lists. | ||
batched_file_tokens = [ | ||
file_tokens_for_even_split[i:i + total_num_tokens_per_batch] for i in | ||
range(0, len(file_tokens_for_even_split), total_num_tokens_per_batch)] | ||
# Cache the tokens of the dataset we've just read. | ||
self._all_batched_file_tokens[file_path] = batched_file_tokens | ||
else: | ||
tokenized_strings = [self._tokenizer.tokenize(s) for s in instance_strings] | ||
batched_file_tokens = self._all_batched_file_tokens[file_path] | ||
|
||
for tokenized_string in tokenized_strings: | ||
input_field = TextField(tokenized_string[:-1], self._token_indexers) | ||
output_field = TextField(tokenized_string[1:], self._output_indexer) | ||
yield Instance({'input_tokens': input_field, | ||
'output_tokens': output_field}) | ||
# Iterate over the batched_file_tokens, yielding batches | ||
# If bidirectional, we start at index 1 (so the first instance) has | ||
# backward targets. Else, we start at index 0. | ||
batch_start_index = 1 if self._bidirectional else 0 | ||
# The max value of batch_start_index is len(batched_file_tokens[0]) - 2, | ||
# leaving room for the target even when the final batch is size 1. | ||
while batch_start_index < len(batched_file_tokens[0]) - 1: | ||
if self._fuzz_truncated_bptt_size: | ||
# This randomization is taken from the code for training the AWD-LSTM. | ||
# (matrices of size (batch_size, truncated_bptt_size)) | ||
fuzzy_truncated_bptt_size = ( | ||
self._truncated_bptt_size if numpy.random.random() < 0.95 else | ||
self._truncated_bptt_size / 2.) | ||
# Prevent excessively small or negative sequence length | ||
sequence_length = max(5, | ||
int(numpy.random.normal(fuzzy_truncated_bptt_size, 5))) | ||
# There's a very small chance that it could select a very long sequence | ||
# length, resulting in OOM. So we cap it at no more than | ||
# self._truncated_bptt_size + 10 | ||
sequence_length = min(sequence_length, self._truncated_bptt_size + 10) | ||
else: | ||
sequence_length = self._truncated_bptt_size | ||
|
||
@overrides | ||
def text_to_instance(self, sentence: str) -> Instance: # type: ignore | ||
# pylint: disable=arguments-differ | ||
tokenized_string = self._tokenizer.tokenize(sentence) | ||
input_field = TextField(tokenized_string[:-1], self._token_indexers) | ||
output_field = TextField(tokenized_string[1:], self._output_indexer) | ||
return Instance({'input_tokens': input_field, 'output_tokens': output_field}) | ||
# We need to constrain the sequence_length to ensure that | ||
# the forward targets don't reach beyond the length of our dataset | ||
sequence_length = min(sequence_length, | ||
len(batched_file_tokens[0]) - batch_start_index - 1) | ||
batch_inputs = [single_batch[batch_start_index:batch_start_index + sequence_length] | ||
for single_batch in batched_file_tokens] | ||
batch_forward_targets = [single_batch[batch_start_index + 1:batch_start_index + 1 + sequence_length] | ||
for single_batch in batched_file_tokens] | ||
input_field = ListField([TextField(single_batch, self._token_indexers) for | ||
single_batch in batch_inputs]) | ||
forward_targets_field = ListField([TextField(single_batch, self._token_indexers) for | ||
single_batch in batch_forward_targets]) | ||
if not self._bidirectional: | ||
yield Instance({ | ||
"inputs": input_field, | ||
"forward_targets": forward_targets_field | ||
}) | ||
else: | ||
batch_backward_targets = [single_batch[batch_start_index - 1:batch_start_index - 1 + sequence_length] | ||
for single_batch in batched_file_tokens] | ||
backward_targets_field = ListField([TextField(single_batch, self._token_indexers) for | ||
single_batch in batch_backward_targets]) | ||
yield Instance({ | ||
"inputs": input_field, | ||
"forward_targets": forward_targets_field, | ||
"backward_targets": backward_targets_field | ||
}) | ||
batch_start_index += sequence_length |
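To make the batching scheme above easier to follow, here is a simplified, self-contained sketch in plain Python (no AllenNLP types, no fuzzing, forward targets only; the function name make_bptt_batches is illustrative, not from the PR): the file's tokens are split into batch_size contiguous streams, and each yielded pair holds windows of roughly bptt_size tokens with the targets shifted one position ahead.

from typing import Iterator, List, Tuple

def make_bptt_batches(tokens: List[str],
                      batch_size: int = 4,
                      bptt_size: int = 5) -> Iterator[Tuple[List[List[str]], List[List[str]]]]:
    # Split the token stream into batch_size contiguous pieces, dropping the remainder.
    tokens_per_stream = len(tokens) // batch_size
    streams = [tokens[i * tokens_per_stream:(i + 1) * tokens_per_stream]
               for i in range(batch_size)]
    # Step through the streams in windows of bptt_size tokens; the forward
    # targets are the same window shifted one token to the right.
    start = 0
    while start < tokens_per_stream - 1:
        length = min(bptt_size, tokens_per_stream - start - 1)
        inputs = [stream[start:start + length] for stream in streams]
        targets = [stream[start + 1:start + 1 + length] for stream in streams]
        yield inputs, targets
        start += length

# Example: 26 letters -> 4 streams of 6 tokens each (2 tokens dropped).
letters = list("abcdefghijklmnopqrstuvwxyz")
for inputs, targets in make_bptt_batches(letters):
    print(inputs[0], targets[0])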
@@ -0,0 +1,79 @@
+from typing import Iterable, Iterator
+import logging
+
+from allennlp.data.instance import Instance
+from allennlp.data.iterators.basic_iterator import BasicIterator
+from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
+from allennlp.data.dataset import Batch
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+@DataIterator.register("language_modeling")
+class LanguageModelingIterator(BasicIterator):
Review comment: I think it makes sense to give this object a more general name - it is useful for any problem where batching is done in the …
""" | ||
An iterator used for language modeling of contiguous text. | ||
This is essentially the same as the BasicIterator, but shuffling | ||
is turned off, the batch size is set to 1, and maximum_samples_per_batch | ||
is not set. | ||
|
||
Parameters | ||
---------- | ||
instances_per_epoch : ``int``, optional, (default = None) | ||
If specified, each epoch will consist of precisely this many instances. | ||
If not specified, each epoch will consist of a single pass through the dataset. | ||
max_instances_in_memory : ``int``, optional, (default = None) | ||
If specified, the iterator will load this many instances at a time into an | ||
in-memory list and then produce batches from one such list at a time. This | ||
could be useful if your instances are read lazily from disk. | ||
cache_instances : ``bool``, optional, (default = False) | ||
If true, the iterator will cache the tensorized instances in memory. | ||
If false, it will do the tensorization anew each iteration. | ||
track_epoch : ``bool``, optional, (default = False) | ||
If true, each instance will get a ``MetadataField`` containing the epoch number. | ||
""" | ||
|
||
def __init__(self, | ||
instances_per_epoch: int = None, | ||
max_instances_in_memory: int = None, | ||
cache_instances: bool = False, | ||
track_epoch: bool = False) -> None: | ||
super().__init__(batch_size=1, | ||
instances_per_epoch=instances_per_epoch, | ||
max_instances_in_memory=max_instances_in_memory, | ||
cache_instances=cache_instances, | ||
track_epoch=track_epoch, | ||
maximum_samples_per_batch=None) | ||
|
||
def __call__(self, | ||
instances: Iterable[Instance], | ||
num_epochs: int = None, | ||
shuffle: bool = True) -> Iterator[TensorDict]: | ||
# Set shuffle to False it is True | ||
# TODO (nfliu): verify this is the right thing to do here. | ||
if shuffle: | ||
logger.info("LanguageModelingIterator does not shuffle instances.") | ||
shuffle = False | ||
for tensor_dict in super().__call__(instances=instances, | ||
num_epochs=num_epochs, | ||
shuffle=shuffle): | ||
# Remove singleton dimensions from tensor dict produced | ||
# by instances generated by LanguageModelingReader | ||
for token_level in tensor_dict.get("inputs", {}): | ||
tensor_dict["inputs"][token_level] = tensor_dict["inputs"][token_level].squeeze(0) | ||
for token_level in tensor_dict.get("forward_targets", {}): | ||
tensor_dict["forward_targets"][token_level] = tensor_dict[ | ||
"forward_targets"][token_level].squeeze(0) | ||
for token_level in tensor_dict.get("backward_targets", {}): | ||
tensor_dict["backward_targets"][token_level] = tensor_dict[ | ||
"backward_targets"][token_level].squeeze(0) | ||
yield tensor_dict | ||
|
||
|
||
def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]: | ||
# Set shuffle to False it is True | ||
if shuffle: | ||
logger.info("LanguageModelingIterator does not shuffle instances.") | ||
shuffle = False | ||
yield from super()._create_batches(instances=instances, | ||
shuffle=shuffle) |
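Because each Instance from the reader wraps a whole batch in a ListField, the tensors leave the underlying BasicIterator with an extra leading dimension of size 1, which the squeeze(0) calls above remove. A tiny illustration with made-up shapes, assuming a single-id token indexer so each value is a 2D tensor of token ids:

import torch

# Shape produced by the base iterator for one reader "instance":
# (1, batch_size, sequence_length), where the leading 1 is the iterator's
# own batch dimension and the other sizes come from the dataset reader.
token_ids = torch.randint(0, 100, (1, 20, 35))

# After squeezing, the model sees the usual (batch_size, sequence_length).
print(token_ids.squeeze(0).shape)  # torch.Size([20, 35])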
Review comment: I think setting lazy=False defeats the purpose of using fuzz_truncated_bptt_size=True during training. The point of using random sequence lengths is to prevent batches from ending with the same tokens each epoch; however, this will not happen if lazy=False, since _read will only be called once. It probably makes more sense to let users set lazy themselves and raise an error if fuzz_truncated_bptt_size=True and lazy=False.
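A sketch of the check this comment suggests (not part of the PR; the class name ContiguousTextReader is hypothetical): expose lazy to callers and reject the combination that would silently disable the fuzzing.

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers.dataset_reader import DatasetReader


class ContiguousTextReader(DatasetReader):  # hypothetical, for illustration only
    def __init__(self,
                 fuzz_truncated_bptt_size: bool = True,
                 lazy: bool = False) -> None:
        # A non-lazy reader calls _read once and caches the instances, so the
        # randomized (fuzzed) sequence lengths would repeat in every epoch.
        if fuzz_truncated_bptt_size and not lazy:
            raise ConfigurationError(
                    "fuzz_truncated_bptt_size=True requires lazy=True, because "
                    "a non-lazy reader only reads (and fuzzes) the data once.")
        super().__init__(lazy=lazy)
        self._fuzz_truncated_bptt_size = fuzz_truncated_bptt_size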