diff --git a/allennlp/data/dataset_readers/language_modeling.py b/allennlp/data/dataset_readers/language_modeling.py index 6b404cb1453..d9c4cfb7a51 100644 --- a/allennlp/data/dataset_readers/language_modeling.py +++ b/allennlp/data/dataset_readers/language_modeling.py @@ -1,97 +1,160 @@ -from typing import Dict +from typing import Dict, List import logging from overrides import overrides +import numpy +from allennlp.common.checks import ConfigurationError from allennlp.common.file_utils import cached_path -from allennlp.common.tqdm import Tqdm from allennlp.data.instance import Instance from allennlp.data.tokenizers.tokenizer import Tokenizer -from allennlp.data.tokenizers import WordTokenizer +from allennlp.data.tokenizers import Token, WordTokenizer from allennlp.data.dataset_readers.dataset_reader import DatasetReader from allennlp.data.token_indexers.token_indexer import TokenIndexer -from allennlp.data.fields import TextField +from allennlp.data.fields import ListField, TextField from allennlp.data.token_indexers import SingleIdTokenIndexer - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @DatasetReader.register("language_modeling") class LanguageModelingReader(DatasetReader): """ - Reads a text file and converts it into a ``Dataset`` suitable for training a language model. - - Note that there's one issue that needs to be fixed before this is actually usable for language - modeling - the way start and end tokens for sentences are handled is not correct; we need to - add a sentence splitter before this will be done right. + Reads a text file and converts it into a ``Dataset`` suitable for training + a language model. Parameters ---------- - tokens_per_instance : ``int``, optional (default=``None``) - If this is ``None``, we will have each training instance be a single sentence. If this is - not ``None``, we will instead take all sentences, including their start and stop tokens, - line them up, and split the tokens into groups of this number, for more efficient training - of the language model. + batch_size : ``int``, optional (default=``20``) + Batch size to use in language modeling. + truncated_bptt_size : ``int``, optional (default=``35``) + The sequence length to use for truncated backpropagation through time. + fuzz_truncated_bptt_size : ``bool``, optional (default=``True``) + If True, randomly perturb the truncated_bptt_size between batches. + bidirectional : ``bool``, optional (default=``False``) + If True, generate instances for bidirectional language modeling. tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``) We use this ``Tokenizer`` for the text. See :class:`Tokenizer`. token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``) - We use this to define the input representation for the text. See :class:`TokenIndexer`. - Note that the `output` representation will always be single token IDs - if you've specified - a ``SingleIdTokenIndexer`` here, we use the first one you specify. Otherwise, we create - one with default parameters. + We use this to define the input representation for the text. + See :class:`TokenIndexer`. + start_tokens : ``List[str]``, optional (default=``None``) + These are prepended to each line read by the dataset reader. + end_tokens : ``List[str]``, optional (default=``[""]``) + These are appended to each line read by the dataset reader. 
""" def __init__(self, - tokens_per_instance: int = None, + batch_size: int = 20, + truncated_bptt_size: int = 35, + fuzz_truncated_bptt_size: bool = True, + bidirectional: bool = False, tokenizer: Tokenizer = None, token_indexers: Dict[str, TokenIndexer] = None, - lazy: bool = False) -> None: - super().__init__(lazy) + start_tokens: List[str] = None, + end_tokens: List[str] = [""]) -> None: + super().__init__(lazy=False) + self._batch_size = batch_size + if truncated_bptt_size < 2: + raise ConfigurationError("truncated_bptt_size cannot be less than 2.") + self._truncated_bptt_size = truncated_bptt_size + self._fuzz_truncated_bptt_size = fuzz_truncated_bptt_size + self._bidirectional = bidirectional self._tokenizer = tokenizer or WordTokenizer() self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} - self._tokens_per_instance = tokens_per_instance - # No matter how you want to represent the input, we'll always represent the output as a - # single token id. This code lets you learn a language model that concatenates word - # embeddings with character-level encoders, in order to predict the word token that comes - # next. - self._output_indexer: Dict[str, TokenIndexer] = None - for name, indexer in self._token_indexers.items(): - if isinstance(indexer, SingleIdTokenIndexer): - self._output_indexer = {name: indexer} - break - else: - self._output_indexer = {"tokens": SingleIdTokenIndexer()} + self._start_tokens = [Token(st) for st in (start_tokens or [])] + self._end_tokens = [Token(et) for et in (end_tokens or [])] + # Cache the batched tokens we read from data files. + self._all_batched_file_tokens = {} @overrides def _read(self, file_path: str): - # if `file_path` is a URL, redirect to the cache - file_path = cached_path(file_path) + if file_path not in self._all_batched_file_tokens: + logger.info('Loading data from %s', file_path) + # if `file_path` is a URL, redirect to the cache + file_path = cached_path(file_path) - with open(file_path, "r") as text_file: - instance_strings = text_file.readlines() + # Read the contents of the file into one long list of tokens, + # adding start and/or end tokens as necessary. + file_tokens = [] + file_tokens.extend(self._start_tokens) + with open(file_path, "r") as text_file: + for line in text_file: + tokenized_line = self._tokenizer.tokenize(line) + file_tokens.extend(tokenized_line) + file_tokens.extend(self._end_tokens) - if self._tokens_per_instance is not None: - all_text = " ".join([x.replace("\n", " ").strip() for x in instance_strings]) - tokenized_text = self._tokenizer.tokenize(all_text) - num_tokens = self._tokens_per_instance + 1 - tokenized_strings = [] - logger.info("Creating dataset from all text in file: %s", file_path) - for index in Tqdm.tqdm(range(0, len(tokenized_text) - num_tokens, num_tokens - 1)): - tokenized_strings.append(tokenized_text[index:(index + num_tokens)]) + # Divide file_tokens into batch_size lists + # Work out how we can evenly split the dataset into batch_size parts + total_num_tokens_per_batch = len(file_tokens) // self._batch_size + if total_num_tokens_per_batch == 0: + # TODO (nfliu): figure out if this is the desired behavior + raise ValueError(f"There are {len(file_tokens)} tokens in the file, " + f"but batch size is {self._batch_size}. " + "batch size must be less than or equal to number of " + "tokens in the file.") + + # Trim off the remainder from file_tokens, so we can evenly divide it + # into batch_size lists. 
+ file_tokens_for_even_split = file_tokens[:total_num_tokens_per_batch * + self._batch_size] + # Evenly divide the data into batch_size lists. + batched_file_tokens = [ + file_tokens_for_even_split[i:i + total_num_tokens_per_batch] for i in + range(0, len(file_tokens_for_even_split), total_num_tokens_per_batch)] + # Cache the tokens of the dataset we've just read. + self._all_batched_file_tokens[file_path] = batched_file_tokens else: - tokenized_strings = [self._tokenizer.tokenize(s) for s in instance_strings] + batched_file_tokens = self._all_batched_file_tokens[file_path] - for tokenized_string in tokenized_strings: - input_field = TextField(tokenized_string[:-1], self._token_indexers) - output_field = TextField(tokenized_string[1:], self._output_indexer) - yield Instance({'input_tokens': input_field, - 'output_tokens': output_field}) + # Iterate over the batched_file_tokens, yielding batches + # If bidirectional, we start at index 1 (so the first instance) has + # backward targets. Else, we start at index 0. + batch_start_index = 1 if self._bidirectional else 0 + # The max value of batch_start_index is len(batched_file_tokens[0]) - 2, + # leaving room for the target even when the final batch is size 1. + while batch_start_index < len(batched_file_tokens[0]) - 1: + if self._fuzz_truncated_bptt_size: + # This randomization is taken from the code for training the AWD-LSTM. + # (matrices of size (batch_size, truncated_bptt_size)) + fuzzy_truncated_bptt_size = ( + self._truncated_bptt_size if numpy.random.random() < 0.95 else + self._truncated_bptt_size / 2.) + # Prevent excessively small or negative sequence length + sequence_length = max(5, + int(numpy.random.normal(fuzzy_truncated_bptt_size, 5))) + # There's a very small chance that it could select a very long sequence + # length, resulting in OOM. 
So we cap it at no more than + # self._truncated_bptt_size + 10 + sequence_length = min(sequence_length, self._truncated_bptt_size + 10) + else: + sequence_length = self._truncated_bptt_size - @overrides - def text_to_instance(self, sentence: str) -> Instance: # type: ignore - # pylint: disable=arguments-differ - tokenized_string = self._tokenizer.tokenize(sentence) - input_field = TextField(tokenized_string[:-1], self._token_indexers) - output_field = TextField(tokenized_string[1:], self._output_indexer) - return Instance({'input_tokens': input_field, 'output_tokens': output_field}) + # We need to constrain the sequence_length to ensure that + # the forward targets don't reach beyond the length of our dataset + sequence_length = min(sequence_length, + len(batched_file_tokens[0]) - batch_start_index - 1) + batch_inputs = [single_batch[batch_start_index:batch_start_index + sequence_length] + for single_batch in batched_file_tokens] + batch_forward_targets = [single_batch[batch_start_index + 1:batch_start_index + 1 + sequence_length] + for single_batch in batched_file_tokens] + input_field = ListField([TextField(single_batch, self._token_indexers) for + single_batch in batch_inputs]) + forward_targets_field = ListField([TextField(single_batch, self._token_indexers) for + single_batch in batch_forward_targets]) + if not self._bidirectional: + yield Instance({ + "inputs": input_field, + "forward_targets": forward_targets_field + }) + else: + batch_backward_targets = [single_batch[batch_start_index - 1:batch_start_index - 1 + sequence_length] + for single_batch in batched_file_tokens] + backward_targets_field = ListField([TextField(single_batch, self._token_indexers) for + single_batch in batch_backward_targets]) + yield Instance({ + "inputs": input_field, + "forward_targets": forward_targets_field, + "backward_targets": backward_targets_field + }) + batch_start_index += sequence_length diff --git a/allennlp/data/iterators/__init__.py b/allennlp/data/iterators/__init__.py index 498cd5b83a1..1f90bad1206 100644 --- a/allennlp/data/iterators/__init__.py +++ b/allennlp/data/iterators/__init__.py @@ -7,4 +7,5 @@ from allennlp.data.iterators.basic_iterator import BasicIterator from allennlp.data.iterators.bucket_iterator import BucketIterator from allennlp.data.iterators.homogeneous_batch_iterator import HomogeneousBatchIterator +from allennlp.data.iterators.language_modeling_iterator import LanguageModelingIterator from allennlp.data.iterators.multiprocess_iterator import MultiprocessIterator diff --git a/allennlp/data/iterators/language_modeling_iterator.py b/allennlp/data/iterators/language_modeling_iterator.py new file mode 100644 index 00000000000..f592c3908c3 --- /dev/null +++ b/allennlp/data/iterators/language_modeling_iterator.py @@ -0,0 +1,79 @@ +from typing import Iterable, Iterator +import logging + +from allennlp.data.instance import Instance +from allennlp.data.iterators.basic_iterator import BasicIterator +from allennlp.data.iterators.data_iterator import DataIterator, TensorDict +from allennlp.data.dataset import Batch + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +@DataIterator.register("language_modeling") +class LanguageModelingIterator(BasicIterator): + """ + An iterator used for language modeling of contiguous text. + This is essentially the same as the BasicIterator, but shuffling + is turned off, the batch size is set to 1, and maximum_samples_per_batch + is not set. 
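To make the batching scheme in `_read` above concrete, here is a minimal, self-contained sketch of the same idea: split one long token stream into `batch_size` contiguous streams, then slide truncated-BPTT windows over all streams in parallel, with forward targets shifted one token to the right. The helper names (`batchify`, `fuzzed_length`, `bptt_windows`) are illustrative only and are not part of this PR.

# Illustrative sketch only; helper names are hypothetical and not part of the PR.
from typing import Iterator, List, Tuple

import numpy


def batchify(tokens: List[str], batch_size: int) -> List[List[str]]:
    # Trim the remainder so the stream divides evenly into batch_size streams.
    tokens_per_stream = len(tokens) // batch_size
    tokens = tokens[:tokens_per_stream * batch_size]
    return [tokens[i:i + tokens_per_stream]
            for i in range(0, len(tokens), tokens_per_stream)]


def fuzzed_length(bptt_size: int) -> int:
    # AWD-LSTM style perturbation: usually bptt_size, occasionally half of it,
    # plus Gaussian noise, clamped to a sensible range.
    base = bptt_size if numpy.random.random() < 0.95 else bptt_size / 2.
    return min(max(5, int(numpy.random.normal(base, 5))), bptt_size + 10)


def bptt_windows(streams: List[List[str]],
                 bptt_size: int) -> Iterator[Tuple[List[List[str]], List[List[str]]]]:
    # Slide windows over every stream in parallel; forward targets are the
    # inputs shifted one position to the right, so the last token of each
    # stream is never used as an input.
    start = 0
    while start < len(streams[0]) - 1:
        length = min(bptt_size, len(streams[0]) - start - 1)
        inputs = [stream[start:start + length] for stream in streams]
        targets = [stream[start + 1:start + 1 + length] for stream in streams]
        yield inputs, targets
        start += length


tokens = "This is a sentence for language modelling .".split()
for inputs, targets in bptt_windows(batchify(tokens, batch_size=2), bptt_size=2):
    print(inputs, targets)
# [['This', 'is'], ['for', 'language']] [['is', 'a'], ['language', 'modelling']]
# [['a'], ['modelling']] [['sentence'], ['.']]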
+ + Parameters + ---------- + instances_per_epoch : ``int``, optional, (default = None) + If specified, each epoch will consist of precisely this many instances. + If not specified, each epoch will consist of a single pass through the dataset. + max_instances_in_memory : ``int``, optional, (default = None) + If specified, the iterator will load this many instances at a time into an + in-memory list and then produce batches from one such list at a time. This + could be useful if your instances are read lazily from disk. + cache_instances : ``bool``, optional, (default = False) + If true, the iterator will cache the tensorized instances in memory. + If false, it will do the tensorization anew each iteration. + track_epoch : ``bool``, optional, (default = False) + If true, each instance will get a ``MetadataField`` containing the epoch number. + """ + + def __init__(self, + instances_per_epoch: int = None, + max_instances_in_memory: int = None, + cache_instances: bool = False, + track_epoch: bool = False) -> None: + super().__init__(batch_size=1, + instances_per_epoch=instances_per_epoch, + max_instances_in_memory=max_instances_in_memory, + cache_instances=cache_instances, + track_epoch=track_epoch, + maximum_samples_per_batch=None) + + def __call__(self, + instances: Iterable[Instance], + num_epochs: int = None, + shuffle: bool = True) -> Iterator[TensorDict]: + # Set shuffle to False it is True + # TODO (nfliu): verify this is the right thing to do here. + if shuffle: + logger.info("LanguageModelingIterator does not shuffle instances.") + shuffle = False + for tensor_dict in super().__call__(instances=instances, + num_epochs=num_epochs, + shuffle=shuffle): + # Remove singleton dimensions from tensor dict produced + # by instances generated by LanguageModelingReader + for token_level in tensor_dict.get("inputs", {}): + tensor_dict["inputs"][token_level] = tensor_dict["inputs"][token_level].squeeze(0) + for token_level in tensor_dict.get("forward_targets", {}): + tensor_dict["forward_targets"][token_level] = tensor_dict[ + "forward_targets"][token_level].squeeze(0) + for token_level in tensor_dict.get("backward_targets", {}): + tensor_dict["backward_targets"][token_level] = tensor_dict[ + "backward_targets"][token_level].squeeze(0) + yield tensor_dict + + + def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]: + # Set shuffle to False it is True + if shuffle: + logger.info("LanguageModelingIterator does not shuffle instances.") + shuffle = False + yield from super()._create_batches(instances=instances, + shuffle=shuffle) diff --git a/allennlp/models/language_model.py b/allennlp/models/language_model.py index 0698fade287..5940c54a2ca 100644 --- a/allennlp/models/language_model.py +++ b/allennlp/models/language_model.py @@ -222,7 +222,10 @@ def num_layers(self) -> int: "does not report how many layers it has.") def forward(self, # type: ignore - source: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]: + inputs: Dict[str, torch.LongTensor] = None, + forward_targets: Dict[str, torch.LongTensor] = None, + backward_targets: Dict[str, torch.LongTensor] = None, + source: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]: """ Computes the averaged forward (and backward, if language model is bidirectional) LM loss from the batch. @@ -233,8 +236,29 @@ def forward(self, # type: ignore Parameters ---------- - tokens: ``torch.Tensor``, required. - The output of ``Batch.as_tensor_dict()`` for a batch of sentences. 
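Because the reader above already packs a full (batch_size, truncated_bptt_size) block into each Instance, the iterator runs with an outer batch size of 1 and strips the redundant leading dimension. A rough shape-level sketch of that squeeze follows; the example shapes are assumptions for illustration, not taken from a real run.

# Shape-level sketch of the squeeze done by LanguageModelingIterator.
import torch

tensor_dict = {
    "inputs": {"tokens": torch.randint(0, 100, (1, 4, 2))},           # (1, batch_size, bptt)
    "forward_targets": {"tokens": torch.randint(0, 100, (1, 4, 2))},
}
for field_name in ("inputs", "forward_targets", "backward_targets"):
    for token_level in tensor_dict.get(field_name, {}):
        tensor_dict[field_name][token_level] = tensor_dict[field_name][token_level].squeeze(0)

print(tensor_dict["inputs"]["tokens"].shape)  # torch.Size([4, 2])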
+ inputs: ``torch.Tensor``, optional. + The output of ``Batch.as_tensor_dict()`` for a batch of sentences, e.g. + as generated by the LanguageModelingReader. This tensor represents the + input tokens to be contextualized. If this is provided, + ``source`` cannot be provided. + forward_targets: ``torch.Tensor``, optional. + The output of ``Batch.as_tensor_dict()`` for a batch of sentences, e.g. + as generated by the LanguageModelingReader. This tensor represents the + forward-direction targets corresponding to the ``inputs``. If this is + provided, ``inputs`` must be provided (and ``source`` thus cannot be + provided). + backward_targets: ``torch.Tensor``, optional. + The output of ``Batch.as_tensor_dict()`` for a batch of sentences, e.g. + as generated by the LanguageModelingReader. This tensor represents the + backward-direction targets corresponding to the ``inputs``. If this is + provided, ``inputs`` must be provided (and ``source`` thus cannot be + provided). + source: ``torch.Tensor``, optional. + The output of ``Batch.as_tensor_dict()`` for a batch of sentences, e.g. + as generated by the SimpleLanguageModelingDatasetReader. + If calculating the loss, we infer the prediction targets from this input, + which is assumed to have a START and STOP symbol at the end of each instance. + If this is provided, ``inputs`` or ``targets`` cannot be provided. Returns ------- @@ -258,10 +282,69 @@ def forward(self, # type: ignore (batch_size, timesteps) mask for the embeddings """ # pylint: disable=arguments-differ - mask = get_text_field_mask(source) + if source is not None: + # If source is provided, inputs, forward_targets, and backward_targets + # should not be provided. + if inputs is not None: + raise ValueError("Received non-None values for both inputs and " + "source. Only one should be provided.") + if forward_targets is not None: + raise ValueError("Received non-None values for 'source' and " + "'forward_targets'. When using 'source', the labels " + "are automatically inferred, so forward_targets " + "should not provided. If you wish to use " + "'forward_targets', provide the corresponding 'inputs' " + "instead of using 'source'.") + if backward_targets is not None: + raise ValueError("Received non-None values for 'source' and " + "'backward_targets'. When using 'source', the labels " + "are automatically inferred, so backward_targets " + "should not provided. If you wish to use " + "'backward_targets', provide the corresponding 'inputs' " + "instead of using 'source'.") + + # Calculate the forward_targets and backward_targets (if applicable) using + # the source. + token_ids = source.get("tokens") + if token_ids is not None: + # Get the targets if we have target tokens. + forward_targets = torch.zeros_like(token_ids) + forward_targets[:, 0:-1] = token_ids[:, 1:] + if self._bidirectional: + backward_targets = torch.zeros_like(token_ids) + backward_targets[:, 1:] = token_ids[:, 0:-1] + else: + backward_targets = None + # Set "source" to "inputs", for consistency below. + inputs = source + elif inputs is not None: + # Get forward and backward targets from forward_targets + # and backward_targets, if provided. + if forward_targets is not None: + forward_targets = forward_targets.get("tokens") + if self._bidirectional and backward_targets is not None: + backward_targets = backward_targets.get("tokens") + if self._bidirectional: + # Error if one of forward_targets or backward_targets is provided, + # but not both (or none). 
+ if bool(forward_targets is not None) ^ bool(backward_targets is not None): + raise ValueError("Bidirectional LanguageModel received only one " + "of 'forward_targets' or 'backward_targets'. " + "Both 'forward_targets' and 'backward_targets' are " + "necessary to calculate loss in the bidirectional " + "case.") + else: + # Error if backward_targets is provided, since this LM is not bidirectional. + if backward_targets is not None: + raise ValueError("LanguageModel is not bidirectional, but " + "backward_targets were provided. In an unidirectional " + "LanguageModel, only forward_targets should be " + "provided.") + + mask = get_text_field_mask(inputs) # shape (batch_size, timesteps, embedding_size) - embeddings = self._text_field_embedder(source) + embeddings = self._text_field_embedder(inputs) # Either the top layer or all layers. contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = self._contextualizer( @@ -271,20 +354,11 @@ def forward(self, # type: ignore return_dict = {} # If we have target tokens, calculate the loss. - token_ids = source.get("tokens") - if token_ids is not None: + # We only check forward_targets here, since they must always be provided + # regardless of the bidirectionality of the language model. + if forward_targets is not None: assert isinstance(contextual_embeddings, torch.Tensor) - # Use token_ids to compute targets - forward_targets = torch.zeros_like(token_ids) - forward_targets[:, 0:-1] = token_ids[:, 1:] - - if self._bidirectional: - backward_targets = torch.zeros_like(token_ids) - backward_targets[:, 1:] = token_ids[:, 0:-1] - else: - backward_targets = None - # add dropout contextual_embeddings_with_dropout = self._dropout(contextual_embeddings) diff --git a/allennlp/tests/data/dataset_readers/language_modeling_dataset_test.py b/allennlp/tests/data/dataset_readers/language_modeling_dataset_test.py index a175ef794cc..d1f55d7280c 100644 --- a/allennlp/tests/data/dataset_readers/language_modeling_dataset_test.py +++ b/allennlp/tests/data/dataset_readers/language_modeling_dataset_test.py @@ -1,31 +1,243 @@ # pylint: disable=no-self-use,invalid-name -import pytest +import numpy from allennlp.data.dataset_readers import LanguageModelingReader from allennlp.common.util import ensure_list from allennlp.common.testing import AllenNlpTestCase class TestLanguageModelingDatasetReader: - @pytest.mark.parametrize("lazy", (True, False)) - def test_read_from_file(self, lazy): - reader = LanguageModelingReader(tokens_per_instance=3, lazy=lazy) + def test_read_from_file_no_fuzz_is_deterministic(self): + """ + The dataset is split into 4 batches, becoming: - instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'language_modeling.txt')) - # The last potential instance is left out, which is ok, because we don't have an end token - # in here, anyway. 
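The `source` code path above infers the prediction targets by shifting the token ids; a minimal standalone sketch of that shift is below (the example ids are made up).

# Sketch of how targets are inferred from `source` token ids: forward targets
# are the ids shifted left by one, backward targets shifted right by one,
# padded with zeros where no target exists.
import torch

token_ids = torch.tensor([[2, 7, 5, 9, 3]])          # (batch_size, timesteps)

forward_targets = torch.zeros_like(token_ids)
forward_targets[:, 0:-1] = token_ids[:, 1:]           # predict the next token

backward_targets = torch.zeros_like(token_ids)        # only used when bidirectional
backward_targets[:, 1:] = token_ids[:, 0:-1]          # predict the previous token

print(forward_targets)   # tensor([[7, 5, 9, 3, 0]])
print(backward_targets)  # tensor([[0, 2, 7, 5, 9]])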
- assert len(instances) == 5 + [[This, is, a, sentence], + [for, language, modelling, .], + [, Here, 's, another], + [one, for, extra, language]] - assert [t.text for t in instances[0].fields["input_tokens"].tokens] == ["This", "is", "a"] - assert [t.text for t in instances[0].fields["output_tokens"].tokens] == ["is", "a", "sentence"] + Then, since our truncated bptt size is 2, our the inputs of our first batch + consists of: + [[This, is], + [for, language], + [, Here], + [one, for]] - assert [t.text for t in instances[1].fields["input_tokens"].tokens] == ["sentence", "for", "language"] - assert [t.text for t in instances[1].fields["output_tokens"].tokens] == ["for", "language", "modelling"] + The second batch consists of: + [[a], + [modelling], + ['s], + [extra]] - assert [t.text for t in instances[2].fields["input_tokens"].tokens] == ["modelling", ".", "Here"] - assert [t.text for t in instances[2].fields["output_tokens"].tokens] == [".", "Here", "'s"] + Note that the second batch has a shorter sequence length of 1, and we do not + predict labels for the final words in the batch. + """ + # Results should be identical if we run twice. + for _ in range(2): + reader = LanguageModelingReader(batch_size=4, + truncated_bptt_size=2, + fuzz_truncated_bptt_size=False) + instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / + 'language_modeling.txt')) + # This should match the batch size + assert len(instances[0].fields["inputs"].field_list) == 4 + assert len(instances[0].fields["forward_targets"].field_list) == 4 + # This is the number of batches generated + assert len(instances) == 2 - assert [t.text for t in instances[3].fields["input_tokens"].tokens] == ["'s", "another", "one"] - assert [t.text for t in instances[3].fields["output_tokens"].tokens] == ["another", "one", "for"] + first_instance_inputs = [["This", "is"], + ["for", "language"], + ["", "Here"], + ["one", "for"]] + first_instance_forward_targets = [["is", "a"], + ["language", "modelling"], + ["Here", "'s"], + ["for", "extra"]] - assert [t.text for t in instances[4].fields["input_tokens"].tokens] == ["for", "extra", "language"] - assert [t.text for t in instances[4].fields["output_tokens"].tokens] == ["extra", "language", "modelling"] + first_instance_generated_inputs = [ + [x.text for x in instances[0].fields["inputs"].field_list[i].tokens] for + i in range(len(instances[0].fields["inputs"].field_list))] + assert first_instance_generated_inputs == first_instance_inputs + first_instance_generated_forward_targets = [ + [x.text for x in instances[0].fields["forward_targets"].field_list[i].tokens] for + i in range(len(instances[0].fields["forward_targets"].field_list))] + assert first_instance_generated_forward_targets == first_instance_forward_targets + + second_instance_inputs = [["a"], + ["modelling"], + ["'s"], + ["extra"]] + second_instance_forward_targets = [["sentence"], + ["."], + ["another"], + ["language"]] + second_instance_generated_inputs = [ + [x.text for x in instances[1].fields["inputs"].field_list[i].tokens] for + i in range(len(instances[1].fields["inputs"].field_list))] + assert second_instance_generated_inputs == second_instance_inputs + second_instance_generated_forward_targets = [ + [x.text for x in instances[1].fields["forward_targets"].field_list[i].tokens] for + i in range(len(instances[1].fields["forward_targets"].field_list))] + assert second_instance_generated_forward_targets == second_instance_forward_targets + + def test_read_from_file(self): + """ + The dataset is split into 2 
batches, becoming: + + [[This, is, a, sentence, for, language, modelling, ., ], + [Here, 's, another, one, for, extra, language, modelling, .]] + + Our truncated bptt size is 2, but fuzz_truncated_bptt_size is True. So + the sequence length is randomly perturbed, becoming 5. As a result, + the inputs are: + + [[This, is, a, sentence, for], + [Here, 's, another, one, for]] + + The second batch consists of: + [[language, modelling, .], + [extra, language, modelling]] + + Note that the second batch has a shorter sequence length of 3, and we do not + predict labels for the final words in the batch. + """ + numpy.random.seed(seed=0) + reader = LanguageModelingReader(batch_size=2, truncated_bptt_size=2) + instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / + 'language_modeling.txt')) + # This should match the batch size + assert len(instances[0].fields["inputs"].field_list) == 2 + # This is the number of batches generated + assert len(instances) == 2 + + first_instance_inputs = [["This", "is", "a", "sentence", "for"], + ["Here", "'s", "another", "one", "for"]] + first_instance_forward_targets = [["is", "a", "sentence", "for", "language"], + ["'s", "another", "one", "for", "extra"]] + first_instance_generated_inputs = [ + [x.text for x in instances[0].fields["inputs"].field_list[i].tokens] for + i in range(len(instances[0].fields["inputs"].field_list))] + assert first_instance_generated_inputs == first_instance_inputs + first_instance_generated_forward_targets = [ + [x.text for x in instances[0].fields["forward_targets"].field_list[i].tokens] for + i in range(len(instances[0].fields["forward_targets"].field_list))] + assert first_instance_generated_forward_targets == first_instance_forward_targets + + second_instance_inputs = [["language", "modelling", "."], + ["extra", "language", "modelling"]] + second_instance_forward_targets = [["modelling", ".", ""], + ["language", "modelling", "."]] + second_instance_generated_inputs = [ + [x.text for x in instances[1].fields["inputs"].field_list[i].tokens] for + i in range(len(instances[1].fields["inputs"].field_list))] + assert second_instance_generated_inputs == second_instance_inputs + second_instance_generated_forward_targets = [ + [x.text for x in instances[1].fields["forward_targets"].field_list[i].tokens] for + i in range(len(instances[1].fields["forward_targets"].field_list))] + assert second_instance_generated_forward_targets == second_instance_forward_targets + + +class TestBidirectionalLanguageModelingDatasetReader: + def test_read_from_file_no_fuzz_is_deterministic(self): + """ + The dataset is split into 4 batches, becoming: + + [[This, is, a, sentence], + [for, language, modelling, .], + [, Here, 's, another], + [one, for, extra, language]] + """ + # Results should be identical if we run twice. 
+ for _ in range(2): + reader = LanguageModelingReader(batch_size=4, + truncated_bptt_size=2, + fuzz_truncated_bptt_size=False, + bidirectional=True) + instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / + 'language_modeling.txt')) + # This should match the batch size + assert len(instances[0].fields["inputs"].field_list) == 4 + assert len(instances[0].fields["forward_targets"].field_list) == 4 + # This is the number of batches generated + assert len(instances) == 1 + + first_instance_inputs = [["is", "a"], + ["language", "modelling"], + ["Here", "'s"], + ["for", "extra"]] + first_instance_forward_targets = [["a", "sentence"], + ["modelling", "."], + ["'s", "another"], + ["extra", "language"]] + first_instance_backward_targets = [["This", "is"], + ["for", "language"], + ["", "Here"], + ["one", "for"]] + + first_instance_generated_inputs = [ + [x.text for x in instances[0].fields["inputs"].field_list[i].tokens] for + i in range(len(instances[0].fields["inputs"].field_list))] + assert first_instance_generated_inputs == first_instance_inputs + first_instance_generated_forward_targets = [ + [x.text for x in instances[0].fields["forward_targets"].field_list[i].tokens] for + i in range(len(instances[0].fields["forward_targets"].field_list))] + assert first_instance_generated_forward_targets == first_instance_forward_targets + first_instance_generated_backward_targets = [ + [x.text for x in instances[0].fields["backward_targets"].field_list[i].tokens] for + i in range(len(instances[0].fields["backward_targets"].field_list))] + assert first_instance_generated_backward_targets == first_instance_backward_targets + + def test_read_from_file(self): + """ + The dataset is split into 2 batches, becoming: + + [[This, is, a, sentence, for, language, modelling, ., ], + [Here, 's, another, one, for, extra, language, modelling, .]] + """ + numpy.random.seed(seed=0) + reader = LanguageModelingReader(batch_size=2, + truncated_bptt_size=2, + bidirectional=True) + instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / + 'language_modeling.txt')) + # This should match the batch size + assert len(instances[0].fields["inputs"].field_list) == 2 + # This is the number of batches generated + assert len(instances) == 2 + + first_instance_inputs = [["is", "a", "sentence", "for", "language"], + ["'s", "another", "one", "for", "extra"]] + first_instance_forward_targets = [["a", "sentence", "for", "language", "modelling"], + ["another", "one", "for", "extra", "language"]] + first_instance_backward_targets = [["This", "is", "a", "sentence", "for"], + ["Here", "'s", "another", "one", "for"]] + first_instance_generated_inputs = [ + [x.text for x in instances[0].fields["inputs"].field_list[i].tokens] for + i in range(len(instances[0].fields["inputs"].field_list))] + assert first_instance_generated_inputs == first_instance_inputs + first_instance_generated_forward_targets = [ + [x.text for x in instances[0].fields["forward_targets"].field_list[i].tokens] for + i in range(len(instances[0].fields["forward_targets"].field_list))] + assert first_instance_generated_forward_targets == first_instance_forward_targets + first_instance_generated_backward_targets = [ + [x.text for x in instances[0].fields["backward_targets"].field_list[i].tokens] for + i in range(len(instances[0].fields["backward_targets"].field_list))] + assert first_instance_generated_backward_targets == first_instance_backward_targets + + second_instance_inputs = [["modelling", "."], + ["language", "modelling"]] + 
second_instance_forward_targets = [[".", ""], + ["modelling", "."]] + second_instance_backward_targets = [["language", "modelling"], + ["extra", "language"]] + second_instance_generated_inputs = [ + [x.text for x in instances[1].fields["inputs"].field_list[i].tokens] for + i in range(len(instances[1].fields["inputs"].field_list))] + assert second_instance_generated_inputs == second_instance_inputs + second_instance_generated_forward_targets = [ + [x.text for x in instances[1].fields["forward_targets"].field_list[i].tokens] for + i in range(len(instances[1].fields["forward_targets"].field_list))] + assert second_instance_generated_forward_targets == second_instance_forward_targets + second_instance_generated_backward_targets = [ + [x.text for x in instances[1].fields["backward_targets"].field_list[i].tokens] for + i in range(len(instances[1].fields["backward_targets"].field_list))] + assert second_instance_generated_backward_targets == second_instance_backward_targets diff --git a/allennlp/tests/fixtures/language_model/experiment_contiguous.jsonnet b/allennlp/tests/fixtures/language_model/experiment_contiguous.jsonnet new file mode 100644 index 00000000000..5517ad0f505 --- /dev/null +++ b/allennlp/tests/fixtures/language_model/experiment_contiguous.jsonnet @@ -0,0 +1,8 @@ +local config = import "experiment_contiguous_unsampled.jsonnet"; + +config + { + "model"+: { + "num_samples": 10, + "sparse_embeddings": true + } +} diff --git a/allennlp/tests/fixtures/language_model/experiment_contiguous_transformer.jsonnet b/allennlp/tests/fixtures/language_model/experiment_contiguous_transformer.jsonnet new file mode 100644 index 00000000000..466ea21b491 --- /dev/null +++ b/allennlp/tests/fixtures/language_model/experiment_contiguous_transformer.jsonnet @@ -0,0 +1,16 @@ +local config = import "experiment_contiguous_unsampled.jsonnet"; + +config + { + "model"+: { + "num_samples": 10, + "sparse_embeddings": true, + "contextualizer": { + "type": "bidirectional_language_model_transformer", + "input_dim": 16, + "hidden_dim": 7, + "num_layers": 3, + "dropout": 0.1, + "input_dropout": 0.1 + } + } +} diff --git a/allennlp/tests/fixtures/language_model/experiment_contiguous_unsampled.jsonnet b/allennlp/tests/fixtures/language_model/experiment_contiguous_unsampled.jsonnet new file mode 100644 index 00000000000..a1a60343c6d --- /dev/null +++ b/allennlp/tests/fixtures/language_model/experiment_contiguous_unsampled.jsonnet @@ -0,0 +1,19 @@ +local config = import "experiment_unsampled.jsonnet"; + +config + { + "dataset_reader"+: { + "type": "language_modeling", + "batch_size": 2, + "bidirectional": true, + }, + "model"+: { + "contextualizer" +: { + // This is necessary for contiguous text LMs + "stateful": true + } + }, + "iterator"+: { + "type": "language_modeling", + batch_size :: super.batch_size + } +} diff --git a/allennlp/tests/fixtures/language_model/experiment_unidirectional_contiguous.jsonnet b/allennlp/tests/fixtures/language_model/experiment_unidirectional_contiguous.jsonnet new file mode 100644 index 00000000000..db5f48ec0c0 --- /dev/null +++ b/allennlp/tests/fixtures/language_model/experiment_unidirectional_contiguous.jsonnet @@ -0,0 +1,8 @@ +local config = import "experiment_unidirectional_contiguous_unsampled.jsonnet"; + +config + { + "model"+: { + "num_samples": 10, + "sparse_embeddings": true + } +} diff --git a/allennlp/tests/fixtures/language_model/experiment_unidirectional_contiguous_transformer.jsonnet 
b/allennlp/tests/fixtures/language_model/experiment_unidirectional_contiguous_transformer.jsonnet new file mode 100644 index 00000000000..816b9ef8284 --- /dev/null +++ b/allennlp/tests/fixtures/language_model/experiment_unidirectional_contiguous_transformer.jsonnet @@ -0,0 +1,18 @@ +local config = import "experiment_unidirectional_contiguous_unsampled.jsonnet"; + +config + { + "model"+: { + "num_samples": 10, + "sparse_embeddings": true, + "contextualizer": { + "type": "stacked_self_attention", + "input_dim": 16, + "hidden_dim": 20, + "projection_dim": 6, + "feedforward_hidden_dim": 5, + "num_attention_heads": 3, + "num_layers": 3, + "dropout_prob": 0.1 + } + } +} diff --git a/allennlp/tests/fixtures/language_model/experiment_unidirectional_contiguous_unsampled.jsonnet b/allennlp/tests/fixtures/language_model/experiment_unidirectional_contiguous_unsampled.jsonnet new file mode 100644 index 00000000000..bbec95702ef --- /dev/null +++ b/allennlp/tests/fixtures/language_model/experiment_unidirectional_contiguous_unsampled.jsonnet @@ -0,0 +1,20 @@ +local config = import "experiment_unsampled.jsonnet"; + +config + { + "dataset_reader"+: { + "type": "language_modeling", + "batch_size": 2, + }, + "model"+: { + "bidirectional": false, + "contextualizer" +: { + "bidirectional": false, + // This is necessary for contiguous text LMs + "stateful": true + } + }, + "iterator"+: { + "type": "language_modeling", + batch_size :: super.batch_size + } +} diff --git a/allennlp/tests/fixtures/language_model/experiment_unsampled.jsonnet b/allennlp/tests/fixtures/language_model/experiment_unsampled.jsonnet index 1a17b111fbe..1bcbd9d52c3 100644 --- a/allennlp/tests/fixtures/language_model/experiment_unsampled.jsonnet +++ b/allennlp/tests/fixtures/language_model/experiment_unsampled.jsonnet @@ -47,7 +47,7 @@ "contextualizer": { "type": "lstm", "bidirectional": true, - "num_layers": 3, + "num_layers": 1, "input_size": 16, "hidden_size": 7, } diff --git a/allennlp/tests/models/language_model_test.py b/allennlp/tests/models/language_model_test.py index 8d8782db2a2..44871807260 100644 --- a/allennlp/tests/models/language_model_test.py +++ b/allennlp/tests/models/language_model_test.py @@ -80,6 +80,88 @@ def test_unidirectional_language_model_can_train_save_and_load(self): "_contextualizer.feedforward_layer_norm_0.gamma", "_contextualizer.feedforward_layer_norm_0.beta"}) +class TestUnidirectionalContiguousLanguageModel(ModelTestCase): + def setUp(self): + super().setUp() + + self.expected_embedding_shape = (2, 6, 7) + self.bidirectional = False + + self.set_up_model(self.FIXTURES_ROOT / 'language_model' / + 'experiment_unidirectional_contiguous.jsonnet', + self.FIXTURES_ROOT / 'language_model' / 'sentences.txt') + + # pylint: disable=no-member + def test_unidirectional_language_model_can_train_save_and_load(self): + self.ensure_model_can_train_save_and_load(self.param_file) + + def test_forward_pass_runs_correctly(self): + training_tensors = self.dataset.as_tensor_dict() + # Note: The dataset itself generates an extra singleton dimension in the + # first dimension. This dimension is squeezed out in the + # LanguageModelingIterator, but we need to do it manually here. 
+ for token_level in training_tensors.get("inputs", {}): + training_tensors["inputs"][token_level] = training_tensors["inputs"][token_level].squeeze(0) + for token_level in training_tensors.get("forward_targets", {}): + training_tensors["forward_targets"][token_level] = training_tensors[ + "forward_targets"][token_level].squeeze(0) + for token_level in training_tensors.get("backward_targets", {}): + training_tensors["backward_targets"][token_level] = training_tensors[ + "backward_targets"][token_level].squeeze(0) + + result = self.model(**training_tensors) + + assert set(result) == {"loss", "forward_loss", "backward_loss", "lm_embeddings", + "noncontextual_token_embeddings", "mask", "batch_weight"} + + embeddings = result["lm_embeddings"] + assert tuple(embeddings.shape) == self.expected_embedding_shape + + loss = result["loss"].item() + forward_loss = result["forward_loss"].item() + if self.bidirectional: + backward_loss = result["backward_loss"].item() + np.testing.assert_almost_equal(loss, (forward_loss + backward_loss) / 2, + decimal=3) + else: + np.testing.assert_almost_equal(loss, forward_loss, decimal=3) + assert result["backward_loss"] is None + + def test_mismatching_contextualizer_unidirectionality_throws_configuration_error(self): + params = Params.from_file(self.param_file) + # Make the contextualizer unidirectionality wrong - it should be + # False to match the language model. + params["model"]["contextualizer"]["bidirectional"] = (not self.bidirectional) + with pytest.raises(ConfigurationError): + Model.from_params(vocab=self.vocab, params=params.get("model")) + +class TestUnidirectionalContiguousLanguageModelUnsampled(TestUnidirectionalContiguousLanguageModel): + def setUp(self): + super().setUp() + self.set_up_model(self.FIXTURES_ROOT / 'language_model' / + 'experiment_unidirectional_contiguous_unsampled.jsonnet', + self.FIXTURES_ROOT / 'language_model' / 'sentences.txt') + +class TestUnidirectionalContiguousLanguageModelTransformer(TestUnidirectionalContiguousLanguageModel): + def setUp(self): + super().setUp() + + self.expected_embedding_shape = (2, 6, 20) + + self.set_up_model(self.FIXTURES_ROOT / 'language_model' / + 'experiment_unidirectional_contiguous_transformer.jsonnet', + self.FIXTURES_ROOT / 'language_model' / 'sentences.txt') + + # pylint: disable=no-member + def test_unidirectional_language_model_can_train_save_and_load(self): + # Ignore layer 0 feedforward layer norm parameters, since + # they are not used. 
+ self.ensure_model_can_train_save_and_load( + self.param_file, gradients_to_ignore={ + "_contextualizer.feedforward_layer_norm_0.gamma", + "_contextualizer.feedforward_layer_norm_0.beta"}) + + class TestBidirectionalLanguageModel(TestUnidirectionalLanguageModel): def setUp(self): super().setUp() @@ -104,3 +186,28 @@ def setUp(self): self.set_up_model(self.FIXTURES_ROOT / 'language_model' / 'experiment_transformer.jsonnet', self.FIXTURES_ROOT / 'language_model' / 'sentences.txt') + +class TestBidirectionalContiguousLanguageModel(TestUnidirectionalContiguousLanguageModel): + def setUp(self): + super().setUp() + + self.expected_embedding_shape = (2, 5, 14) + self.bidirectional = True + + self.set_up_model(self.FIXTURES_ROOT / 'language_model' / 'experiment_contiguous.jsonnet', + self.FIXTURES_ROOT / 'language_model' / 'sentences.txt') + +class TestBidirectionalContiguousLanguageModelUnsampled(TestBidirectionalContiguousLanguageModel): + def setUp(self): + super().setUp() + self.set_up_model(self.FIXTURES_ROOT / 'language_model' / 'experiment_contiguous_unsampled.jsonnet', + self.FIXTURES_ROOT / 'language_model' / 'sentences.txt') + +class TestBidirectionalContiguousLanguageModelTransformer(TestBidirectionalContiguousLanguageModel): + def setUp(self): + super().setUp() + + self.expected_embedding_shape = (2, 5, 32) + + self.set_up_model(self.FIXTURES_ROOT / 'language_model' / 'experiment_contiguous_transformer.jsonnet', + self.FIXTURES_ROOT / 'language_model' / 'sentences.txt')
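For reference, a minimal end-to-end sketch of how the new reader and iterator fit together, assuming the fixture path used in the tests above and the standard `Vocabulary.from_instances` / `index_with` APIs of this AllenNLP version; the model wiring is omitted.

from allennlp.data.dataset_readers import LanguageModelingReader
from allennlp.data.iterators import LanguageModelingIterator
from allennlp.data.vocabulary import Vocabulary

reader = LanguageModelingReader(batch_size=2, truncated_bptt_size=35, bidirectional=True)
instances = reader.read('allennlp/tests/fixtures/data/language_modeling.txt')

vocab = Vocabulary.from_instances(instances)
iterator = LanguageModelingIterator()
iterator.index_with(vocab)

for tensor_dict in iterator(instances, num_epochs=1):
    # Each tensor has already been squeezed to (batch_size, sequence_length).
    print(tensor_dict["inputs"]["tokens"].shape)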