This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Decompose LanguageModel contextualizer into forward_ and backward_ contextualizer #2438

Closed
wants to merge 17 commits
147 changes: 134 additions & 13 deletions allennlp/models/language_model.py
@@ -1,4 +1,5 @@
from typing import Dict, List, Tuple, Union
import warnings

import torch
import numpy as np
@@ -74,6 +75,23 @@ class LanguageModel(Model):
contextualizer: ``Seq2SeqEncoder``
Used to "contextualize" the embeddings. As described above,
this encoder must not cheat by peeking ahead.

.. deprecated:: 0.8.2
Contributor:
Question: To me this seems like we are ripping out quite a good API for one which is more complicated. We already have a trained BidirectionalLanguageModel which uses a bidirectional transformer, which is very useful to many people (e.g. GitHub issues; Swabha is using it in her research), which makes me unsure that deprecating a key component of it is a good idea. Do we know that bidirectional language modeling of contiguous text is worth it? I can't imagine a downstream task where you actually need unlimited context (and even if there was one, how would that practically work?). Is it worth ripping this into a separate repo, making the changes, and double-checking that contiguous text does something useful?

Additionally, we now have a BidirectionalLanguageModel subclass which is subsumed by this code (I think). This is "bad code smell" to me when we have a subclass which only differs in constructor arguments and we only have a single example of how those arguments might differ (it triggers my MultiClassMultipleChoiceMemoryNetwork(MemoryNetwork, MultiClass, MultipleChoice) sense from DeepQa).
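
(For context, the kind of thin subclass being described is roughly the sketch below. This is illustrative only: the class name and argument list are made up here, not copied from the real BidirectionalLanguageModel.)

from allennlp.data import Vocabulary
from allennlp.models.language_model import LanguageModel
from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder


class BidirectionalLanguageModelSketch(LanguageModel):
    """Hypothetical subclass that differs from LanguageModel only in the
    constructor arguments it fixes (here, ``bidirectional=True``)."""

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 contextualizer: Seq2SeqEncoder,
                 dropout: float = None) -> None:
        super().__init__(vocab=vocab,
                         text_field_embedder=text_field_embedder,
                         contextualizer=contextualizer,
                         dropout=dropout,
                         bidirectional=True)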

Contributor:

Leaving the API decision to others, here's some motivation for enabling modeling of longer contexts: https://ai.googleblog.com/2019/01/transformer-xl-unleashing-potential-of.html. They basically make a stateful transformer, showing pretty big gains.

``contextualizer`` was deprecated in version 0.8.2. It was
replaced with two more flexible arguments, ``forward_contextualizer``
and ``backward_contextualizer``, in order to enable bidirectional
language modeling of contiguous text. It will be removed in version 0.10.

forward_contextualizer: ``Seq2SeqEncoder``
Used to "contextualize" the embeddings for a forward-direction LM.
As described above, this encoder must not cheat by peeking ahead.
backward_contextualizer: ``Seq2SeqEncoder``
Used to "contextualize" the embeddings for a backward-direction LM.
The contextualizer should operate from left to right; the order of the
text in the backward inputs is assumed to have been flipped (e.g., by your
DatasetReader). If provided, the size of its output must match that of
the ``forward_contextualizer``.
As described above, this encoder must not cheat by peeking ahead.
dropout: ``float``, optional (default: None)
If specified, dropout is applied to the contextualized embeddings before computation of
the softmax. The contextualized embeddings themselves are returned without dropout.
@@ -91,7 +109,9 @@ class LanguageModel(Model):
def __init__(self,
vocab: Vocabulary,
text_field_embedder: TextFieldEmbedder,
contextualizer: Seq2SeqEncoder,
contextualizer: Seq2SeqEncoder = None,
forward_contextualizer: Seq2SeqEncoder = None,
backward_contextualizer: Seq2SeqEncoder = None,
dropout: float = None,
num_samples: int = None,
sparse_embeddings: bool = False,
@@ -100,22 +120,98 @@ def __init__(self,
super().__init__(vocab)
self._text_field_embedder = text_field_embedder

if contextualizer.is_bidirectional() is not bidirectional:
# Only true when contextualizer is non-None and bidirectional is True
self._use_contextualizer_arg = False
if contextualizer is not None and (forward_contextualizer is not None or
backward_contextualizer is not None):
raise ConfigurationError(
"Bidirectionality of contextualizer must match bidirectionality of "
"language model. "
f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
f"language model bidirectional: {bidirectional}")

self._contextualizer = contextualizer
"Cannot provide both contextualizer and either "
"forward_contextualizer or backward_contextualizer.")

if contextualizer is not None:
warnings.warn("``contextualizer`` was deprecated in version 0.8.2. It was "
"replaced with two more flexible arguments: "
"``forward_contextualizer`` and ``backward_contextualizer``. "
"It will be removed in version 0.10 .",
DeprecationWarning)
if contextualizer.is_bidirectional() is not bidirectional:
raise ConfigurationError(
"Bidirectionality of contextualizer must match bidirectionality "
"of language model. "
f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
f"language model bidirectional: {bidirectional}")
if contextualizer.is_bidirectional():
warnings.warn(
"When using a bidirectional contextualizer, it's crucial that "
"the contextualizer does not cheat by looking ahead. For "
"instance, if you're using a multi-layer bidirectional RNN "
"here, the model is cheating because layers >= 2 use opposite "
"direection inputs (a single-layer bidirectional RNN is "
"thus fine). See the BidirectionalLanguageModelTransformer "
"for an example of how to properly ensure that a multilayer "
"bidirectional contextualizer doesn't inadvertently cheat, or "
"provide values for the forward_contextualizer and "
"backward_contextualizer arguments instead.")
self._use_contextualizer_arg = True
else:
# Unidirectional LM with unidirectional contextualizer, so just set
# forward_contextualizer to contextualizer.
forward_contextualizer = contextualizer
contextualizer = None
# If self._use_contextualizer_arg is True, this is non-None. Else, it is None.
self._contextualizer = contextualizer

# ``contextualizer`` logic handled, do error checking for
# forward_contextualizer and backward_contextualizer
if bidirectional and (forward_contextualizer is None or
backward_contextualizer is None):
if not self._use_contextualizer_arg:
raise ConfigurationError(
"LanguageModel bidirectional is True, but did not "
"provide forward_contextualizer and backward_contextualizer. "
f"Got forward_contextualizer: {forward_contextualizer} and "
f"backward_contextualizer: {backward_contextualizer}")
if not self._use_contextualizer_arg and forward_contextualizer is None:
raise ConfigurationError(
"The forward_contextualizer argument is required.")
if not bidirectional and backward_contextualizer is not None:
raise ConfigurationError(
"LanguageModel bidirectional is False, so "
"backward_contextualizer should not be provided."
f"Got backward_contextualizer: {backward_contextualizer}")
# Ensure that forward_contextualizer and backward_contextualizer
# are unidirectional
if forward_contextualizer and forward_contextualizer.is_bidirectional():
raise ConfigurationError("forward_contextualizer should not be "
"bidirectional.")
if backward_contextualizer and backward_contextualizer.is_bidirectional():
raise ConfigurationError("backward_contextualizer should not be "
"bidirectional.")

self._forward_contextualizer = forward_contextualizer
self._backward_contextualizer = backward_contextualizer
self._bidirectional = bidirectional

# The dimension for making predictions just in the forward
# (or backward) direction.
# They must be the same. TODO (nfliu): relax this assumption
if self._bidirectional:
self._forward_dim = contextualizer.get_output_dim() // 2
if self._use_contextualizer_arg:
self._forward_dim = self._contextualizer.get_output_dim() // 2
else:
if (self._forward_contextualizer.get_output_dim() !=
self._backward_contextualizer.get_output_dim()):
raise ConfigurationError(
"forward_contextualizer and backward_contextualizer "
"must have the same output dimension. "
"forward_contextualizer output dimension is "
f"{self._forward_contextualizer.get_output_dim()}, while"
"backward_contextualizer output dimension is "
f"{self._forward_contextualizer.get_output_dim()}")
self._forward_dim = self._forward_contextualizer.get_output_dim()
else:
self._forward_dim = contextualizer.get_output_dim()
# If bidirectional is False, self._use_contextualizer_arg is False.
self._forward_dim = self._forward_contextualizer.get_output_dim()

# TODO(joelgrus): more sampled softmax configuration options, as needed.
if num_samples is not None:
@@ -264,9 +360,34 @@ def forward(self, # type: ignore
embeddings = self._text_field_embedder(source)

# Either the top layer or all layers.
contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = self._contextualizer(
embeddings, mask
)
contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = None
if self._use_contextualizer_arg:
contextual_embeddings = self._contextualizer(embeddings, mask)
else:
contextual_embeddings = self._forward_contextualizer(embeddings, mask)
if self._bidirectional:
backward_contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = (
self._backward_contextualizer(embeddings, mask))
# Concatenate the backward contextual embeddings to the
# forward contextual embeddings
if (isinstance(contextual_embeddings, list) and
Contributor (author):
@brendan-ai2 i'm struggling to test this case (where contextual_embeddings and backward_contextual_embeddings are lists), since I'm not sure when they would return lists :) any pointers?

Contributor:
Blargh. :/ The contextualizer can be made to return all layers for embedding.

/~https://github.com/allenai/allennlp/blob/master/allennlp/modules/token_embedders/language_model_token_embedder.py#L61

There basically isn't a real API in place right now as this is specific to the transformer contextualizer. We definitely need to figure out how this plays with having more general contextualizers... It is an important feature.

Contributor (author):
Yeah, seems hard to get the per-layer outputs of a multilayer pytorch LSTM / other contextualizers in a flexible fashion...

Contributor:
Off the top of my head, a solution would be to modify /~https://github.com/allenai/allennlp/blob/master/allennlp/modules/seq2seq_encoders/seq2seq_encoder.py#L5 so that it takes a constructor arg "return_all_layers" (like with lazy in the DatasetReader) and then have the subclasses do the appropriate thing. Maybe also add forward as an explicit method in Seq2SeqEncoder as well in order to clearly document the behavior and the type.
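
(A rough sketch of that constructor-arg idea, assuming Seq2SeqEncoder can simply grow an optional flag; the names and defaults below are illustrative, not existing AllenNLP API.)

from typing import List, Union

import torch
from allennlp.modules import Seq2SeqEncoder


class AllLayersSeq2SeqEncoder(Seq2SeqEncoder):
    """Hypothetical encoder constructed with a return_all_layers flag,
    analogous to ``lazy`` on DatasetReader."""

    def __init__(self, return_all_layers: bool = False) -> None:
        super().__init__()
        self._return_all_layers = return_all_layers

    def forward(self,
                inputs: torch.Tensor,
                mask: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        # Subclasses return a single (batch, timesteps, dim) tensor by default,
        # or a list with one such tensor per layer when return_all_layers=True.
        raise NotImplementedError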

Contributor (author):
I think the main trouble is that the PyTorch LSTM doesn't return the outputs for all layers and all timesteps?
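
(One workaround, sketched here rather than taken from this PR, is to stack single-layer ``nn.LSTM`` modules and keep each layer's full output sequence. This ignores the masking and packed-sequence handling that the real AllenNLP encoders do.)

from typing import List, Tuple

import torch
from torch import nn


class PerLayerLSTM(nn.Module):
    """Stack of single-layer LSTMs that keeps every layer's output sequence."""

    def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
        super().__init__()
        input_sizes = [input_size] + [hidden_size] * (num_layers - 1)
        self._layers = nn.ModuleList(
            [nn.LSTM(size, hidden_size, num_layers=1, batch_first=True)
             for size in input_sizes])

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        per_layer_outputs: List[torch.Tensor] = []
        output = inputs
        for lstm in self._layers:
            output, _ = lstm(output)  # (batch, timesteps, hidden_size)
            per_layer_outputs.append(output)
        # Top-layer output plus the full list of per-layer outputs.
        return output, per_layer_outputs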

Contributor:
Excellent point, @matt-gardner.

@nelson-liu, I should elaborate a bit. The pure object-oriented approach would be to have a MultiLayerSeq2SeqEncoder that subclasses the existing Seq2SeqEncoder. Annoyingly this would break anyone that's using a vanilla Seq2SeqEncoder currently. Another issue is that when many of these features are independent there is a risk of the class hierarchy becoming very deep resulting in, say, MultiLayerSeq2SeqEncoderWithFooAndBar, MultiLayerSeq2SeqEncoderWithFooAndNotBar, etc. There are a few ways to guard against this. First, one can make judicious use of defaults -- at the risk of weakening (if not breaking) the abstraction as Matt points out. Another option would be to use mixins. (Though that would also break existing users.) You can also adopt a sort of manual approach. This might involve exposing is_multi_layer for user code to query. This is ugly, but can be quite flexible and aid in backwards compatibility.
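
(For concreteness, the "manual" is_multi_layer option might look roughly like the following, again purely illustrative and not an existing AllenNLP class; it assumes such encoders also accept a return_all_layers argument in forward.)

from typing import List, Union

import torch
from allennlp.modules import Seq2SeqEncoder


def contextualize(encoder: Seq2SeqEncoder,
                  inputs: torch.Tensor,
                  mask: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
    # User code queries the optional capability before asking for all layers,
    # so plain Seq2SeqEncoders keep working unchanged.
    if getattr(encoder, "is_multi_layer", None) and encoder.is_multi_layer():
        return encoder(inputs, mask, return_all_layers=True)
    return encoder(inputs, mask)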

Contributor:
Let's figure out how/whether to obtain the extra layers and then proceed with caution as Matt suggests. :)

Contributor:
Sorry, I missed this mention. We don't need to handle the case that the contextualizer returns multiple layers here, because it will never do that during training.

I think the problem here has really stemmed from the fact that we have been passing around the return_all_layers argument as a constructor parameter and not a runtime forward parameter (or as a separate method on a Contextualizer subclass of Seq2SeqEncoder).

Is the right API here this:

class Seq2SeqEncoder:
    ...

class Contextualizer(Seq2SeqEncoder):

    # Idea 1:
    def forward(self,
                sequence: torch.Tensor,  # (batch, sequence, embedding)
                return_all_layers: bool = False
                ) -> Union[torch.Tensor, List[torch.Tensor]]:
        ...

    # This way all Contextualizers can still function as `Seq2SeqEncoders` by default,
    # which is useful for training and perhaps for downstream use, say if you wanted to
    # fine-tune an encoder and didn't want to do an ELMo mixture. However, all
    # `ContextualTokenEmbedders` could call `Contextualizer.forward(return_all_layers=True)`.

    # Second idea:
    def forward(self,
                sequence: torch.Tensor  # (batch, sequence, embedding)
                ) -> torch.Tensor:
        ...

    def get_layers(self, sequence: torch.Tensor) -> List[torch.Tensor]:  # each (batch, sequence, embedding)
        ...

    # A separate method which basically implements the functionality above.
    # This would get around the type problems of `forward` possibly returning lists
    # of tensors when used as a `Seq2SeqEncoder`.

This doesn't handle Nelson's case that the contextualizers might be stateful during training. I haven't thought hard about that because i'm not 100% convinced that we need to support it as part of a concrete API yet, but i'm happy to be convinced!

nelson-liu (Contributor, author), Feb 14, 2019:

more thoughts later, but:

> This doesn't handle Nelson's case that the contextualizers might be stateful during training. I haven't thought hard about that because i'm not 100% convinced that we need to support it as part of a concrete API yet, but i'm happy to be convinced!

How else would you do language modeling of contiguous text? At the very least, it'd be useful to add the (left-to-right) LM functionality to AllenNLP so people can easily train on contiguous-text datasets like the PTB or wikitext. I think having the bidirectional contiguous text LM would be useful as well, since I'm actually curious to run the experiment and find out whether contiguous text matters (I suspect it does, or at least doesn't hurt). Bidirectionality is definitely crucial to getting the best contextual representations.

Contributor:
Would #2716 help here?

isinstance(backward_contextual_embeddings, list)):
if len(contextual_embeddings) != len(backward_contextual_embeddings):
raise ValueError("Contextualizers produced outputs of different lengths")
for embedding_index, backward_embedding in enumerate(backward_contextual_embeddings):
contextual_embeddings[embedding_index] = torch.cat(
[contextual_embeddings[embedding_index], backward_embedding],
dim=-1)
elif (isinstance(contextual_embeddings, torch.Tensor) and
isinstance(backward_contextual_embeddings, torch.Tensor)):
contextual_embeddings = torch.cat(
[contextual_embeddings, backward_contextual_embeddings], dim=-1)
else:
raise ValueError("forward and backward contextualizer returned "
"different types. Output of forward_contextualizer "
f"has type f{type(contextual_embeddings)}, while"
"output of backward_contextualizer has type"
f"f{type(backward_contextual_embeddings)}")

return_dict = {}

@@ -0,0 +1,21 @@
local config = import "experiment_unsampled.jsonnet";

config + {
"model"+: {
contextualizer :: super.contextualizer,
"forward_contextualizer": {
"type": "lstm",
"input_size": 16,
"hidden_size": 7,
"num_layers": 3,
"dropout": 0.1
},
"backward_contextualizer": {
"type": "gru",
"input_size": 16,
"hidden_size": 7,
"num_layers": 3,
"dropout": 0.1
}
}
}
@@ -0,0 +1,14 @@
local config = import "experiment_unidirectional_unsampled.jsonnet";

config + {
"model"+: {
contextualizer :: super.contextualizer,
"forward_contextualizer": {
"type": "lstm",
"input_size": 16,
"hidden_size": 7,
"num_layers": 3,
"dropout": 0.1
}
}
}
@@ -47,7 +47,7 @@
"contextualizer": {
"type": "lstm",
"bidirectional": true,
"num_layers": 3,
"num_layers": 1,
"input_size": 16,
"hidden_size": 7,
}
72 changes: 67 additions & 5 deletions allennlp/tests/models/language_model_test.py
@@ -50,9 +50,10 @@ def test_mismatching_contextualizer_unidirectionality_throws_configuration_error
params = Params.from_file(self.param_file)
# Make the contextualizer unidirectionality wrong - it should be
# False to match the language model.
params["model"]["contextualizer"]["bidirectional"] = (not self.bidirectional)
with pytest.raises(ConfigurationError):
Model.from_params(vocab=self.vocab, params=params.get("model"))
if "contextualizer" in params["model"]:
params["model"]["contextualizer"]["bidirectional"] = (not self.bidirectional)
with pytest.raises(ConfigurationError):
Model.from_params(vocab=self.vocab, params=params.get("model"))

class TestUnidirectionalLanguageModelUnsampled(TestUnidirectionalLanguageModel):
def setUp(self):
@@ -77,8 +78,34 @@ def test_unidirectional_language_model_can_train_save_and_load(self):
# they are not used.
self.ensure_model_can_train_save_and_load(
self.param_file, gradients_to_ignore={
"_contextualizer.feedforward_layer_norm_0.gamma",
"_contextualizer.feedforward_layer_norm_0.beta"})
"_forward_contextualizer.feedforward_layer_norm_0.gamma",
"_forward_contextualizer.feedforward_layer_norm_0.beta"})

class TestUnidirectionalLanguageModelForwardContextualizer(TestUnidirectionalLanguageModel):
def setUp(self):
super().setUp()

self.set_up_model(self.FIXTURES_ROOT / 'language_model' /
'experiment_unidirectional_forward.jsonnet',
self.FIXTURES_ROOT / 'language_model' / 'sentences.txt')

def test_unidirectional_no_forward_contextualizer_throws_configuration_error(self):
params = Params.from_file(self.param_file)
params["model"].pop("forward_contextualizer")
with pytest.raises(ConfigurationError):
Model.from_params(vocab=self.vocab, params=params.get("model"))

def test_unidirectional_with_backward_contextualizer_throws_configuration_error(self):
params = Params.from_file(self.param_file)
params["model"]["backward_contextualizer"] = {
"type": "gru",
"input_size": 16,
"hidden_size": 7,
"num_layers": 3,
"dropout": 0.1
}
with pytest.raises(ConfigurationError):
Model.from_params(vocab=self.vocab, params=params.get("model"))

class TestBidirectionalLanguageModel(TestUnidirectionalLanguageModel):
def setUp(self):
@@ -104,3 +131,38 @@ def setUp(self):

self.set_up_model(self.FIXTURES_ROOT / 'language_model' / 'experiment_transformer.jsonnet',
self.FIXTURES_ROOT / 'language_model' / 'sentences.txt')

class TestBidirectionalLanguageModelForwardBackward(TestBidirectionalLanguageModel):
def setUp(self):
super().setUp()

self.set_up_model(self.FIXTURES_ROOT / 'language_model' /
'experiment_forward_backward.jsonnet',
self.FIXTURES_ROOT / 'language_model' / 'sentences.txt')

def test_no_backward_contextualizer_throws_configuration_error(self):
params = Params.from_file(self.param_file)
# Remove the backward contextualizer, leaving only the forward
params["model"].pop("backward_contextualizer")
with pytest.raises(ConfigurationError):
Model.from_params(vocab=self.vocab, params=params.get("model"))

def test_bidirectional_backward_contextualizer_throws_configuration_error(self):
params = Params.from_file(self.param_file)
# Set bidirectional to true in backward
params["model"]["backward_contextualizer"]["bidirectional"] = True
with pytest.raises(ConfigurationError):
Model.from_params(vocab=self.vocab, params=params.get("model"))

def test_bidirectional_forward_contextualizer_throws_configuration_error(self):
params = Params.from_file(self.param_file)
# Set bidirectional to true in forward
params["model"]["forward_contextualizer"]["bidirectional"] = True
with pytest.raises(ConfigurationError):
Model.from_params(vocab=self.vocab, params=params.get("model"))

def test_bidirectional_contextualizer_mismatched_output_throws_configuration_error(self):
params = Params.from_file(self.param_file)
params["model"]["forward_contextualizer"]["hidden_size"] = 8
with pytest.raises(ConfigurationError):
Model.from_params(vocab=self.vocab, params=params.get("model"))