Decompose LanguageModel contextualizer into forward_ and backward_ contextualizer #2438
@@ -1,4 +1,5 @@
from typing import Dict, List, Tuple, Union
import warnings

import torch
import numpy as np

@@ -74,6 +75,23 @@ class LanguageModel(Model):
contextualizer: ``Seq2SeqEncoder``
Used to "contextualize" the embeddings. As described above,
this encoder must not cheat by peeking ahead.

.. deprecated:: 0.8.2
``contextualizer`` was deprecated in version 0.8.2. It was
replaced with two more flexible arguments: ``forward_contextualizer``
and ``backward_contextualizer``, in order to enable bidirectional
language modeling of contiguous text. It will be removed in version 0.10.

forward_contextualizer: ``Seq2SeqEncoder``
Used to "contextualize" the embeddings for a forward-direction LM.
As described above, this encoder must not cheat by peeking ahead.
backward_contextualizer: ``Seq2SeqEncoder``
Used to "contextualize" the embeddings for a backward-direction LM.
The contextualizer should operate from left to right; the order of the
text in the backward inputs is assumed to have been flipped (e.g., by your
DatasetReader). If provided, the size of its output must match that of
the ``forward_contextualizer``.
As described above, this encoder must not cheat by peeking ahead.
dropout: ``float``, optional (default: None)
If specified, dropout is applied to the contextualized embeddings before computation of
the softmax. The contextualized embeddings themselves are returned without dropout.

@@ -91,7 +109,9 @@ class LanguageModel(Model):
def __init__(self,
vocab: Vocabulary,
text_field_embedder: TextFieldEmbedder,
contextualizer: Seq2SeqEncoder,
contextualizer: Seq2SeqEncoder = None,
forward_contextualizer: Seq2SeqEncoder = None,
backward_contextualizer: Seq2SeqEncoder = None,
dropout: float = None,
num_samples: int = None,
sparse_embeddings: bool = False,

@@ -100,22 +120,98 @@ def __init__(self,
super().__init__(vocab)
self._text_field_embedder = text_field_embedder

if contextualizer.is_bidirectional() is not bidirectional:
# Only true when contextualizer is non-None and bidirectional is True
self._use_contextualizer_arg = False
if contextualizer is not None and (forward_contextualizer is not None or
backward_contextualizer is not None):
raise ConfigurationError(
"Bidirectionality of contextualizer must match bidirectionality of "
"language model. "
f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
f"language model bidirectional: {bidirectional}")

self._contextualizer = contextualizer
"Cannot provide both contextualizer and either "
"forward_contextualizer or backward_contextualizer.")

if contextualizer is not None:
warnings.warn("``contextualizer`` was deprecated in version 0.8.2. It was "
"replaced with two more flexible arguments: "
"``forward_contextualizer`` and ``backward_contextualizer``. "
"It will be removed in version 0.10.",
DeprecationWarning)
if contextualizer.is_bidirectional() is not bidirectional:
raise ConfigurationError(
"Bidirectionality of contextualizer must match bidirectionality "
"of language model. "
f"Contextualizer bidirectional: {contextualizer.is_bidirectional()}, "
f"language model bidirectional: {bidirectional}")
if contextualizer.is_bidirectional():
warnings.warn(
"When using a bidirectional contextualizer, it's crucial that "
"the contextualizer does not cheat by looking ahead. For "
"instance, if you're using a multi-layer bidirectional RNN "
"here, the model is cheating because layers >= 2 use opposite "
"direction inputs (a single-layer bidirectional RNN is "
"thus fine). See the BidirectionalLanguageModelTransformer "
"for an example of how to properly ensure that a multilayer "
"bidirectional contextualizer doesn't inadvertently cheat, or "
"provide values for the forward_contextualizer and "
"backward_contextualizer arguments instead.")
self._use_contextualizer_arg = True
else:
# Unidirectional LM with unidirectional contextualizer, so just set
# forward_contextualizer to contextualizer.
forward_contextualizer = contextualizer
contextualizer = None
# If self._use_contextualizer_arg is True, this is non-None. Else, it is None.
self._contextualizer = contextualizer

# ``contextualizer`` logic handled, do error checking for
# forward_contextualizer and backward_contextualizer
if bidirectional and (forward_contextualizer is None or
backward_contextualizer is None):
if not self._use_contextualizer_arg:
raise ConfigurationError(
"LanguageModel bidirectional is True, but did not "
"provide forward_contextualizer and backward_contextualizer. "
f"Got forward_contextualizer: {forward_contextualizer} and "
f"backward_contextualizer: {backward_contextualizer}")
if not self._use_contextualizer_arg and forward_contextualizer is None:
raise ConfigurationError(
"The forward_contextualizer argument is required.")
if not bidirectional and backward_contextualizer is not None:
raise ConfigurationError(
"LanguageModel bidirectional is False, so "
"backward_contextualizer should not be provided. "
f"Got backward_contextualizer: {backward_contextualizer}")
# Ensure that forward_contextualizer and backward_contextualizer
# are unidirectional
if forward_contextualizer and forward_contextualizer.is_bidirectional():
raise ConfigurationError("forward_contextualizer should not be "
"bidirectional.")
if backward_contextualizer and backward_contextualizer.is_bidirectional():
raise ConfigurationError("backward_contextualizer should not be "
"bidirectional.")

self._forward_contextualizer = forward_contextualizer
self._backward_contextualizer = backward_contextualizer
self._bidirectional = bidirectional

# The dimension for making predictions just in the forward
# (or backward) direction.
# They must be the same. TODO (nfliu): relax this assumption
if self._bidirectional:
self._forward_dim = contextualizer.get_output_dim() // 2
if self._use_contextualizer_arg:
self._forward_dim = self._contextualizer.get_output_dim() // 2
else:
if (self._forward_contextualizer.get_output_dim() !=
self._backward_contextualizer.get_output_dim()):
raise ConfigurationError(
"forward_contextualizer and backward_contextualizer "
"must have the same output dimension. "
"forward_contextualizer output dimension is "
f"{self._forward_contextualizer.get_output_dim()}, while "
"backward_contextualizer output dimension is "
f"{self._backward_contextualizer.get_output_dim()}")
self._forward_dim = self._forward_contextualizer.get_output_dim()
else:
self._forward_dim = contextualizer.get_output_dim()
# If bidirectional is False, self._use_contextualizer_arg is False.
self._forward_dim = self._forward_contextualizer.get_output_dim()

# TODO(joelgrus): more sampled softmax configuration options, as needed.
if num_samples is not None:

@@ -264,9 +360,34 @@ def forward(self, # type: ignore
embeddings = self._text_field_embedder(source)

# Either the top layer or all layers.
contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = self._contextualizer(
embeddings, mask
)
contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = None
if self._use_contextualizer_arg:
contextual_embeddings = self._contextualizer(embeddings, mask)
else:
contextual_embeddings = self._forward_contextualizer(embeddings, mask)
if self._bidirectional:
backward_contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = (
self._backward_contextualizer(embeddings, mask))
# Concatenate the backward contextual embeddings to the
# forward contextual embeddings
if (isinstance(contextual_embeddings, list) and
[Inline review comments on this line:]

@brendan-ai2 i'm struggling to test this case (where

Blargh. :/ The contextualizer can be made to return all layers for embedding. There basically isn't a real API in place right now as this is specific to the transformer contextualizer. We definitely need to figure out how this plays with having more general contextualizers... It is an important feature.

Yeah, seems hard to get the per-layer outputs of a multilayer pytorch LSTM / other contextualizers in a flexible fashion...

Off the top of my head, a solution would be to modify /~https://github.com/allenai/allennlp/blob/master/allennlp/modules/seq2seq_encoders/seq2seq_encoder.py#L5 so that it takes a constructor arg "return_all_layers" (like with lazy in the

I think the main trouble is that the PyTorch LSTM doesn't return the outputs for all layers and all timesteps?

Excellent point, @matt-gardner. @nelson-liu, I should elaborate a bit. The pure object-oriented approach would be to have a

Let's figure out how/whether to obtain the extra layers and then proceed with caution as Matt suggests. :)

Sorry, I missed this mention. We don't need to handle the case that the contextualizer returns multiple layers here, because it will never do that during training. I think really the problem here has stemmed from the fact we have been passing around the

Is the right API here this:

```python
class Contextualizer(Seq2SeqEncoder):
    # Idea 1:
    def forward(sequence: torch.Tensor[batch, sequence, embedding],
                return_all_layers = False) -> Union[torch.Tensor[batch, sequence, embedding], List[of the same thing]]
    # this way all Contextualizers can still function as `Seq2SeqEncoders` by default,
    # which is useful for training and perhaps for downstream use, say if you wanted to
    # fine tune an encoder and didn't want to do an elmo mixture. However, all
    # `ContextualTokenEmbedders` could call `Contextualizer.forward(return_all_layers=True)`.

    # Second idea:
    def forward(sequence: torch.Tensor[batch, sequence, embedding]) -> torch.Tensor[batch, sequence, embedding]
    def get_layers(sequence: torch.Tensor) -> List[torch.Tensor[batch, sequence, embedding]]:
    # A separate method which basically implements the functionality above.
    # This would get around the type problems of `forward` possibly returning lists
    # of tensors when used as a `Seq2SeqEncoder`.
```

This doesn't handle Nelson's case that the contextualizers might be stateful during training. I haven't thought hard about that because i'm not 100% convinced that we need to support it as part of a concrete API yet, but i'm happy to be convinced!

more thoughts later, but:

How else would you do language modeling of contiguous text? At the very least, it'd be useful to add the (left-to-right) LM functionality to AllenNLP so people can easily train on contiguous-text datasets like the PTB or wikitext. I think having the bidirectional contiguous text LM would be useful as well, since I'm actually curious to run the experiment and find out whether contiguous text matters (I suspect it does, or at least doesn't hurt). Bidirectionality is definitely crucial to getting the best contextual representations.

Would #2716 help here?
isinstance(backward_contextual_embeddings, list)):
if len(contextual_embeddings) != len(backward_contextual_embeddings):
raise ValueError("Contextualizers produced outputs of different lengths")
for embedding_index, backward_embedding in enumerate(backward_contextual_embeddings):
contextual_embeddings[embedding_index] = torch.cat(
[contextual_embeddings[embedding_index], backward_embedding],
dim=-1)
elif (isinstance(contextual_embeddings, torch.Tensor) and
isinstance(backward_contextual_embeddings, torch.Tensor)):
contextual_embeddings = torch.cat(
[contextual_embeddings, backward_contextual_embeddings], dim=-1)
else:
raise ValueError("forward and backward contextualizer returned "
"different types. Output of forward_contextualizer "
f"has type {type(contextual_embeddings)}, while "
"output of backward_contextualizer has type "
f"{type(backward_contextual_embeddings)}")

return_dict = {}

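The inline discussion above notes that a multi-layer `torch.nn.LSTM` only returns the top layer's outputs at each timestep, which is what makes a `return_all_layers`-style API awkward for RNN contextualizers. As a rough illustration of one workaround (not part of this PR; masking and packing are omitted, and the `PerLayerLSTM` name is made up for this sketch), single-layer LSTMs can be stacked by hand so that every layer's full output is retained:

```python
from typing import List

import torch


class PerLayerLSTM(torch.nn.Module):
    """Stacks single-layer LSTMs so that each layer's output at every timestep
    is available, unlike torch.nn.LSTM(num_layers=n), which only returns the
    outputs of the top layer."""

    def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
        super().__init__()
        self._layers = torch.nn.ModuleList(
                [torch.nn.LSTM(input_size if i == 0 else hidden_size,
                               hidden_size, num_layers=1, batch_first=True)
                 for i in range(num_layers)])

    def forward(self, inputs: torch.Tensor) -> List[torch.Tensor]:
        outputs = []
        hidden = inputs
        for lstm in self._layers:
            hidden, _ = lstm(hidden)   # (batch, timesteps, hidden_size)
            outputs.append(hidden)
        return outputs                 # one tensor per layer


encoder = PerLayerLSTM(input_size=16, hidden_size=7, num_layers=3)
per_layer_outputs = encoder(torch.randn(2, 20, 16))
print([tuple(t.shape) for t in per_layer_outputs])  # [(2, 20, 7), (2, 20, 7), (2, 20, 7)]
```

An encoder along these lines could back either of the `return_all_layers` / `get_layers` ideas sketched in the thread, at the cost of re-implementing the stacking (and inter-layer dropout) that `torch.nn.LSTM` normally handles internally.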
@@ -0,0 +1,21 @@
local config = import "experiment_unsampled.jsonnet";

config + {
"model"+: {
contextualizer :: super.contextualizer,
"forward_contextualizer": {
"type": "lstm",
"input_size": 16,
"hidden_size": 7,
"num_layers": 3,
"dropout": 0.1
},
"backward_contextualizer": {
"type": "gru",
"input_size": 16,
"hidden_size": 7,
"num_layers": 3,
"dropout": 0.1
}
}
}
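In this override, `contextualizer :: super.contextualizer` re-declares the inherited `contextualizer` field as a hidden jsonnet field, so the deprecated argument is dropped from the rendered config and only the new forward/backward contextualizers reach the model. For reference, a rough programmatic equivalent of the config above might look like the following sketch (module paths per AllenNLP 0.8.x against this PR's branch; the vocabulary and embedder setup is illustrative only and would normally come from your training data):

```python
import torch

from allennlp.data import Vocabulary
from allennlp.models.language_model import LanguageModel
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()  # placeholder; normally built from instances
embedder = BasicTextFieldEmbedder(
        {"tokens": Embedding(num_embeddings=vocab.get_vocab_size("tokens"),
                             embedding_dim=16)})

# Unidirectional encoders, mirroring the "lstm" / "gru" blocks in the config.
forward_contextualizer = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(input_size=16, hidden_size=7, num_layers=3,
                      dropout=0.1, batch_first=True))
backward_contextualizer = PytorchSeq2SeqWrapper(
        torch.nn.GRU(input_size=16, hidden_size=7, num_layers=3,
                     dropout=0.1, batch_first=True))

model = LanguageModel(vocab=vocab,
                      text_field_embedder=embedder,
                      forward_contextualizer=forward_contextualizer,
                      backward_contextualizer=backward_contextualizer,
                      bidirectional=True)
```

Both encoders output dimension 7 here, which satisfies the new check that `forward_contextualizer` and `backward_contextualizer` produce matching sizes.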
@@ -0,0 +1,14 @@
local config = import "experiment_unidirectional_unsampled.jsonnet";

config + {
"model"+: {
contextualizer :: super.contextualizer,
"forward_contextualizer": {
"type": "lstm",
"input_size": 16,
"hidden_size": 7,
"num_layers": 3,
"dropout": 0.1
}
}
}
[Review comments:]

Question: To me this seems like we are ripping out quite a good api for one which is more complicated. We already have a trained BidirectionalLanguageModel which uses a bidirectional transformer, which is very useful to many people (e.g. github issues / Swabha is using it in her research), which makes me unsure that deprecating a key component of it is a good idea. Do we know that bidirectional language modeling of contiguous text is worth it? I can't imagine a downstream task where you actually need unlimited context (and even if there was one, how that would practically work). Is it worth ripping this into a separate repo, making the changes and double checking that contiguous text does something useful?

Additionally, we now have a BidirectionalLanguageModel subclass which is subsumed by this code (I think). This is "bad code smell" to me when we have a subclass which only differs in constructor arguments and we only have a single example of how those arguments might differ (it triggers my MultiClassMultipleChoiceMemoryNetwork(MemoryNetwork, MultiClass, MultipleChoice) sense from DeepQa).

Leaving the API decision to others, here's some motivation for enabling modeling of longer contexts: https://ai.googleblog.com/2019/01/transformer-xl-unleashing-potential-of.html. They basically make a stateful transformer, showing pretty big gains.
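On the statefulness point: below is a minimal sketch (not part of this PR) of carrying hidden state across consecutive batches of contiguous text, using the `stateful` flag that AllenNLP's `PytorchSeq2SeqWrapper` already exposes. It assumes an iterator that serves chunks of each text stream in order, so the state saved after one chunk is the correct left context for the next.

```python
import torch

from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

# A unidirectional contextualizer that keeps its LSTM state between calls.
contextualizer = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(input_size=16, hidden_size=7, num_layers=3, batch_first=True),
        stateful=True)

chunk1 = torch.randn(2, 20, 16)              # (batch, timesteps, embedding_dim)
chunk2 = torch.randn(2, 20, 16)              # the next 20 tokens of the same two streams
mask = torch.ones(2, 20, dtype=torch.long)   # no padding in this toy example

out1 = contextualizer(chunk1, mask)  # starts from a zero state
out2 = contextualizer(chunk2, mask)  # continues from chunk1's final state

contextualizer.reset_states()        # reset at document or epoch boundaries
```

This is RNN-style truncated backpropagation through time rather than the Transformer-XL mechanism in the linked post, but it gives a concrete sense of what "stateful" contextualizers over contiguous text would look like in the current API.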