This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

[WIP] Language Modeling of Contiguous Text #2414

Closed · wants to merge 19 commits
181 changes: 122 additions & 59 deletions allennlp/data/dataset_readers/language_modeling.py
@@ -1,97 +1,160 @@
from typing import Dict
from typing import Dict, List
import logging

from overrides import overrides
import numpy

from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data.tokenizers import WordTokenizer
from allennlp.data.tokenizers import Token, WordTokenizer
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.token_indexers.token_indexer import TokenIndexer
from allennlp.data.fields import TextField
from allennlp.data.fields import ListField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@DatasetReader.register("language_modeling")
class LanguageModelingReader(DatasetReader):
"""
Reads a text file and converts it into a ``Dataset`` suitable for training a language model.

Note that there's one issue that needs to be fixed before this is actually usable for language
modeling - the way start and end tokens for sentences are handled is not correct; we need to
add a sentence splitter before this will be done right.
Reads a text file and converts it into a ``Dataset`` suitable for training
a language model.

Parameters
----------
tokens_per_instance : ``int``, optional (default=``None``)
If this is ``None``, we will have each training instance be a single sentence. If this is
not ``None``, we will instead take all sentences, including their start and stop tokens,
line them up, and split the tokens into groups of this number, for more efficient training
of the language model.
batch_size : ``int``, optional (default=``20``)
Batch size to use in language modeling.
truncated_bptt_size : ``int``, optional (default=``35``)
The sequence length to use for truncated backpropagation through time.
fuzz_truncated_bptt_size : ``bool``, optional (default=``True``)
If True, randomly perturb the truncated_bptt_size between batches.
bidirectional : ``bool``, optional (default=``False``)
If True, generate instances for bidirectional language modeling.
tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
We use this ``Tokenizer`` for the text. See :class:`Tokenizer`.
token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
We use this to define the input representation for the text. See :class:`TokenIndexer`.
Note that the `output` representation will always be single token IDs - if you've specified
a ``SingleIdTokenIndexer`` here, we use the first one you specify. Otherwise, we create
one with default parameters.
We use this to define the input representation for the text.
See :class:`TokenIndexer`.
start_tokens : ``List[str]``, optional (default=``None``)
These are prepended to each line read by the dataset reader.
end_tokens : ``List[str]``, optional (default=``["</S>"]``)
These are appended to each line read by the dataset reader.
"""
def __init__(self,
tokens_per_instance: int = None,
batch_size: int = 20,
truncated_bptt_size: int = 35,
fuzz_truncated_bptt_size: bool = True,
bidirectional: bool = False,
tokenizer: Tokenizer = None,
token_indexers: Dict[str, TokenIndexer] = None,
lazy: bool = False) -> None:
super().__init__(lazy)
start_tokens: List[str] = None,
end_tokens: List[str] = ["</S>"]) -> None:
super().__init__(lazy=False)
Contributor: I think setting lazy=False defeats the purpose of using fuzz_truncated_bptt_size=True during training. The point of using random sequence lengths is to prevent batches from ending with the same tokens each epoch, but this will not happen if lazy=False, since _read will only be called once. It probably makes more sense to let users set lazy themselves and raise an error if fuzz_truncated_bptt_size=True and lazy=False.
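
A minimal sketch of the check this comment suggests (the helper name is hypothetical and not part of this diff; only ConfigurationError comes from the codebase):

from allennlp.common.checks import ConfigurationError

def check_lazy_and_fuzz(lazy: bool, fuzz_truncated_bptt_size: bool) -> None:
    # Hypothetical guard illustrating the reviewer's suggestion: with lazy=False,
    # _read runs only once, so the randomly drawn sequence lengths are frozen after
    # the first pass and fuzzing accomplishes nothing.
    if fuzz_truncated_bptt_size and not lazy:
        raise ConfigurationError(
            "fuzz_truncated_bptt_size=True has no effect when lazy=False, "
            "because _read (and its randomization) only runs once.")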

self._batch_size = batch_size
if truncated_bptt_size < 2:
raise ConfigurationError("truncated_bptt_size cannot be less than 2.")
self._truncated_bptt_size = truncated_bptt_size
self._fuzz_truncated_bptt_size = fuzz_truncated_bptt_size
self._bidirectional = bidirectional
self._tokenizer = tokenizer or WordTokenizer()
self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
self._tokens_per_instance = tokens_per_instance

# No matter how you want to represent the input, we'll always represent the output as a
# single token id. This code lets you learn a language model that concatenates word
# embeddings with character-level encoders, in order to predict the word token that comes
# next.
self._output_indexer: Dict[str, TokenIndexer] = None
for name, indexer in self._token_indexers.items():
if isinstance(indexer, SingleIdTokenIndexer):
self._output_indexer = {name: indexer}
break
else:
self._output_indexer = {"tokens": SingleIdTokenIndexer()}
self._start_tokens = [Token(st) for st in (start_tokens or [])]
self._end_tokens = [Token(et) for et in (end_tokens or [])]
# Cache the batched tokens we read from data files.
self._all_batched_file_tokens = {}

@overrides
def _read(self, file_path: str):
# if `file_path` is a URL, redirect to the cache
file_path = cached_path(file_path)
if file_path not in self._all_batched_file_tokens:
logger.info('Loading data from %s', file_path)
# if `file_path` is a URL, redirect to the cache
file_path = cached_path(file_path)

with open(file_path, "r") as text_file:
instance_strings = text_file.readlines()
# Read the contents of the file into one long list of tokens,
# adding start and/or end tokens as necessary.
file_tokens = []
file_tokens.extend(self._start_tokens)
Contributor: Shouldn't this be in the for loop?

Contributor Author: Yup, you're correct (sorry, I haven't looked at this PR in a while). In general, though, language modeling of contiguous text doesn't use any start tokens and only uses EOS tokens (e.g., see /~https://github.com/salesforce/awd-lstm-lm/blob/master/data.py#L34-L54).

Contributor Author: Right, but this dataset reader is not for training an LM like ELMo. ELMo is trained on shuffled sentences, which is why it needs a start token for each sentence. This dataset reader is for contiguous text (e.g., books), like the corpora the OpenAI GPT was trained on and the PTB LM benchmark most folks use. In this setting, people don't typically use start tokens.

Contributor: Ah, got it, thanks -- I assumed that stateful (RNN) language models were always trained in contiguous fashion. The shuffling might help generalization performance on sentence-level tasks (where the model doesn't learn to rely on an initialized hidden state), but I'm wondering how the same model with and without contiguous training would perform in transfer learning. Thanks!
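
To make the contiguous-text setup concrete, here is an illustrative sketch of the split that both the referenced awd-lstm data.py and the reader's _read below perform (the function name and example are mine, not part of this diff):

from typing import List, TypeVar

T = TypeVar("T")

def batchify(tokens: List[T], batch_size: int) -> List[List[T]]:
    # Trim the token stream so it divides evenly, then cut it into
    # batch_size contiguous streams that are read in parallel.
    tokens_per_stream = len(tokens) // batch_size
    trimmed = tokens[:tokens_per_stream * batch_size]
    return [trimmed[i:i + tokens_per_stream]
            for i in range(0, len(trimmed), tokens_per_stream)]

# batchify(list("abcdefghij"), 3) -> [['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']]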

with open(file_path, "r") as text_file:
for line in text_file:
tokenized_line = self._tokenizer.tokenize(line)
file_tokens.extend(tokenized_line)
file_tokens.extend(self._end_tokens)

if self._tokens_per_instance is not None:
all_text = " ".join([x.replace("\n", " ").strip() for x in instance_strings])
tokenized_text = self._tokenizer.tokenize(all_text)
num_tokens = self._tokens_per_instance + 1
tokenized_strings = []
logger.info("Creating dataset from all text in file: %s", file_path)
for index in Tqdm.tqdm(range(0, len(tokenized_text) - num_tokens, num_tokens - 1)):
tokenized_strings.append(tokenized_text[index:(index + num_tokens)])
# Divide file_tokens into batch_size lists
# Work out how we can evenly split the dataset into batch_size parts
total_num_tokens_per_batch = len(file_tokens) // self._batch_size
if total_num_tokens_per_batch == 0:
# TODO (nfliu): figure out if this is the desired behavior
raise ValueError(f"There are {len(file_tokens)} tokens in the file, "
f"but batch size is {self._batch_size}. "
"batch size must be less than or equal to number of "
"tokens in the file.")

# Trim off the remainder from file_tokens, so we can evenly divide it
# into batch_size lists.
file_tokens_for_even_split = file_tokens[:total_num_tokens_per_batch *
self._batch_size]
# Evenly divide the data into batch_size lists.
batched_file_tokens = [
file_tokens_for_even_split[i:i + total_num_tokens_per_batch] for i in
range(0, len(file_tokens_for_even_split), total_num_tokens_per_batch)]
# Cache the tokens of the dataset we've just read.
self._all_batched_file_tokens[file_path] = batched_file_tokens
else:
tokenized_strings = [self._tokenizer.tokenize(s) for s in instance_strings]
batched_file_tokens = self._all_batched_file_tokens[file_path]

for tokenized_string in tokenized_strings:
input_field = TextField(tokenized_string[:-1], self._token_indexers)
output_field = TextField(tokenized_string[1:], self._output_indexer)
yield Instance({'input_tokens': input_field,
'output_tokens': output_field})
# Iterate over the batched_file_tokens, yielding batches
# If bidirectional, we start at index 1 (so the first instance has
# backward targets). Else, we start at index 0.
batch_start_index = 1 if self._bidirectional else 0
# The max value of batch_start_index is len(batched_file_tokens[0]) - 2,
# leaving room for the target even when the final batch is size 1.
while batch_start_index < len(batched_file_tokens[0]) - 1:
if self._fuzz_truncated_bptt_size:
# This randomization is taken from the code for training the AWD-LSTM.
# (matrices of size (batch_size, truncated_bptt_size))
fuzzy_truncated_bptt_size = (
self._truncated_bptt_size if numpy.random.random() < 0.95 else
self._truncated_bptt_size / 2.)
# Prevent excessively small or negative sequence length
sequence_length = max(5,
int(numpy.random.normal(fuzzy_truncated_bptt_size, 5)))
# There's a very small chance that it could select a very long sequence
# length, resulting in OOM. So we cap it at no more than
# self._truncated_bptt_size + 10
sequence_length = min(sequence_length, self._truncated_bptt_size + 10)
else:
sequence_length = self._truncated_bptt_size

@overrides
def text_to_instance(self, sentence: str) -> Instance: # type: ignore
# pylint: disable=arguments-differ
tokenized_string = self._tokenizer.tokenize(sentence)
input_field = TextField(tokenized_string[:-1], self._token_indexers)
output_field = TextField(tokenized_string[1:], self._output_indexer)
return Instance({'input_tokens': input_field, 'output_tokens': output_field})
# We need to constrain the sequence_length to ensure that
# the forward targets don't reach beyond the length of our dataset
sequence_length = min(sequence_length,
len(batched_file_tokens[0]) - batch_start_index - 1)
batch_inputs = [single_batch[batch_start_index:batch_start_index + sequence_length]
for single_batch in batched_file_tokens]
batch_forward_targets = [single_batch[batch_start_index + 1:batch_start_index + 1 + sequence_length]
for single_batch in batched_file_tokens]
input_field = ListField([TextField(single_batch, self._token_indexers) for
single_batch in batch_inputs])
forward_targets_field = ListField([TextField(single_batch, self._token_indexers) for
single_batch in batch_forward_targets])
if not self._bidirectional:
yield Instance({
"inputs": input_field,
"forward_targets": forward_targets_field
})
else:
batch_backward_targets = [single_batch[batch_start_index - 1:batch_start_index - 1 + sequence_length]
for single_batch in batched_file_tokens]
backward_targets_field = ListField([TextField(single_batch, self._token_indexers) for
single_batch in batch_backward_targets])
yield Instance({
"inputs": input_field,
"forward_targets": forward_targets_field,
"backward_targets": backward_targets_field
})
batch_start_index += sequence_length
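
A hedged usage sketch of the reader as written in this diff; "corpus.txt" is a placeholder path, and the field names follow the code above:

from allennlp.data.dataset_readers.language_modeling import LanguageModelingReader

# Instantiate the reader with the defaults documented in the docstring above.
reader = LanguageModelingReader(batch_size=20,
                                truncated_bptt_size=35,
                                fuzz_truncated_bptt_size=True,
                                bidirectional=False)
for instance in reader.read("corpus.txt"):
    # Each instance holds ListFields of batch_size TextFields:
    # "inputs" and "forward_targets" (plus "backward_targets" when bidirectional=True).
    print(instance.fields["inputs"], instance.fields["forward_targets"])
    break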
1 change: 1 addition & 0 deletions allennlp/data/iterators/__init__.py
@@ -7,4 +7,5 @@
from allennlp.data.iterators.basic_iterator import BasicIterator
from allennlp.data.iterators.bucket_iterator import BucketIterator
from allennlp.data.iterators.homogeneous_batch_iterator import HomogeneousBatchIterator
from allennlp.data.iterators.language_modeling_iterator import LanguageModelingIterator
from allennlp.data.iterators.multiprocess_iterator import MultiprocessIterator
79 changes: 79 additions & 0 deletions allennlp/data/iterators/language_modeling_iterator.py
@@ -0,0 +1,79 @@
from typing import Iterable, Iterator
import logging

from allennlp.data.instance import Instance
from allennlp.data.iterators.basic_iterator import BasicIterator
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
from allennlp.data.dataset import Batch

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@DataIterator.register("language_modeling")
class LanguageModelingIterator(BasicIterator):
rloganiv (Contributor), Jun 17, 2019: I think it makes sense to give this object a more general name - it is useful for any problem where batching is done in the DatasetReader (e.g. the problem in #2828). Maybe something like StraightThroughIterator or TrivialIterator?

"""
An iterator used for language modeling of contiguous text.
This is essentially the same as the BasicIterator, but shuffling
is turned off, the batch size is set to 1, and maximum_samples_per_batch
is not set.

Parameters
----------
instances_per_epoch : ``int``, optional, (default = None)
If specified, each epoch will consist of precisely this many instances.
If not specified, each epoch will consist of a single pass through the dataset.
max_instances_in_memory : ``int``, optional, (default = None)
If specified, the iterator will load this many instances at a time into an
in-memory list and then produce batches from one such list at a time. This
could be useful if your instances are read lazily from disk.
cache_instances : ``bool``, optional, (default = False)
If true, the iterator will cache the tensorized instances in memory.
If false, it will do the tensorization anew each iteration.
track_epoch : ``bool``, optional, (default = False)
If true, each instance will get a ``MetadataField`` containing the epoch number.
"""

def __init__(self,
instances_per_epoch: int = None,
max_instances_in_memory: int = None,
cache_instances: bool = False,
track_epoch: bool = False) -> None:
super().__init__(batch_size=1,
instances_per_epoch=instances_per_epoch,
max_instances_in_memory=max_instances_in_memory,
cache_instances=cache_instances,
track_epoch=track_epoch,
maximum_samples_per_batch=None)

def __call__(self,
instances: Iterable[Instance],
num_epochs: int = None,
shuffle: bool = True) -> Iterator[TensorDict]:
# Set shuffle to False if it is True
# TODO (nfliu): verify this is the right thing to do here.
if shuffle:
logger.info("LanguageModelingIterator does not shuffle instances.")
shuffle = False
for tensor_dict in super().__call__(instances=instances,
num_epochs=num_epochs,
shuffle=shuffle):
# Remove the singleton batch dimension from the tensor dict produced
# for instances generated by the LanguageModelingReader
for token_level in tensor_dict.get("inputs", {}):
tensor_dict["inputs"][token_level] = tensor_dict["inputs"][token_level].squeeze(0)
for token_level in tensor_dict.get("forward_targets", {}):
tensor_dict["forward_targets"][token_level] = tensor_dict[
"forward_targets"][token_level].squeeze(0)
for token_level in tensor_dict.get("backward_targets", {}):
tensor_dict["backward_targets"][token_level] = tensor_dict[
"backward_targets"][token_level].squeeze(0)
yield tensor_dict


def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
# Set shuffle to False if it is True
if shuffle:
logger.info("LanguageModelingIterator does not shuffle instances.")
shuffle = False
yield from super()._create_batches(instances=instances,
shuffle=shuffle)
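
A hypothetical end-to-end sketch pairing the reader and iterator added in this PR; "corpus.txt" is a placeholder path and the variable names are mine:

from allennlp.data.dataset_readers.language_modeling import LanguageModelingReader
from allennlp.data.iterators import LanguageModelingIterator
from allennlp.data.vocabulary import Vocabulary

reader = LanguageModelingReader(batch_size=20, truncated_bptt_size=35)
instances = list(reader.read("corpus.txt"))
vocab = Vocabulary.from_instances(instances)
iterator = LanguageModelingIterator()
iterator.index_with(vocab)
for tensor_dict in iterator(instances, num_epochs=1):
    # After the squeeze above, each tensor is (batch_size, sequence_length)
    # rather than (1, batch_size, sequence_length).
    inputs = tensor_dict["inputs"]["tokens"]
    forward_targets = tensor_dict["forward_targets"]["tokens"]
    break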