This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
Commit
Adding a LanguageModelHead abstraction (#3200)
* Modules and docs
* Added tests
* Docstrings
* pylint
* moved linear layer to tests
* add todos about caching
* fix import...
* doc
1 parent 370d512, commit 8c06c4b
Showing 11 changed files with 193 additions and 0 deletions.
allennlp/modules/language_model_heads/__init__.py
@@ -0,0 +1,3 @@
from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead
from allennlp.modules.language_model_heads.bert import BertLanguageModelHead
from allennlp.modules.language_model_heads.gpt2 import Gpt2LanguageModelHead
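Importing this package pulls in the modules below, so their ``@LanguageModelHead.register(...)`` decorators run and the heads become constructible by their registered names. A minimal sketch of checking this, assuming AllenNLP's standard ``Registrable`` API (not part of this commit):

from allennlp.modules.language_model_heads import LanguageModelHead

# After the import, 'bert' and 'gpt2' should show up among the registered head types,
# which is what lets the tests below build them with LanguageModelHead.from_params.
print(LanguageModelHead.list_available())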
allennlp/modules/language_model_heads/bert.py
@@ -0,0 +1,37 @@
from overrides import overrides
from pytorch_transformers import BertConfig, BertForMaskedLM
import torch

from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('bert')
class BertLanguageModelHead(LanguageModelHead):
    """
    Loads just the LM head from ``pytorch_transformers.BertForMaskedLM``.  It was easiest to load
    the entire model before only pulling out the head, so this is a bit slower than it could be,
    but for practical use in a model, the few seconds of extra loading time is probably not a big
    deal.
    """
    def __init__(self, model_name: str) -> None:
        super().__init__()
        config = BertConfig.from_pretrained(model_name)
        self.input_dim = config.hidden_size
        self.output_dim = config.vocab_size
        # TODO(mattg): It's possible that we could use some kind of cache like we have in
        # allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel.  That way, we
        # would only load the BERT weights once.  Though, it's not clear how to do that here, as we
        # need to load `BertForMaskedLM`, not just `BertModel`...
        bert_model = BertForMaskedLM.from_pretrained(model_name)
        self.bert_lm_head = bert_model.cls  # pylint: disable=no-member

    @overrides
    def get_input_dim(self) -> int:
        return self.input_dim

    @overrides
    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.bert_lm_head(hidden_states)
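The TODO in ``__init__`` above points at caching the loaded weights the way ``PretrainedBertModel`` caches ``BertModel``. A minimal sketch of what such a cache could look like, purely illustrative and not part of this commit (the class name here is hypothetical):

from typing import Dict

from pytorch_transformers import BertForMaskedLM


class CachedBertForMaskedLM:
    """Hypothetical module-level cache so heads sharing a model_name load the weights only once."""
    _cache: Dict[str, BertForMaskedLM] = {}

    @classmethod
    def load(cls, model_name: str) -> BertForMaskedLM:
        if model_name not in cls._cache:
            cls._cache[model_name] = BertForMaskedLM.from_pretrained(model_name)
        return cls._cache[model_name]

# BertLanguageModelHead.__init__ could then call CachedBertForMaskedLM.load(model_name)
# instead of BertForMaskedLM.from_pretrained(model_name).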
allennlp/modules/language_model_heads/gpt2.py
@@ -0,0 +1,37 @@
from overrides import overrides
from pytorch_transformers import GPT2Config, GPT2LMHeadModel
import torch

from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('gpt2')
class Gpt2LanguageModelHead(LanguageModelHead):
    """
    Loads just the LM head from ``pytorch_transformers.GPT2LMHeadModel``.  It was easiest to load
    the entire model before only pulling out the head, so this is a bit slower than it could be,
    but for practical use in a model, the few seconds of extra loading time is probably not a big
    deal.
    """
    def __init__(self, model_name: str) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained(model_name)
        self.input_dim = config.hidden_size
        self.output_dim = config.vocab_size
        # TODO(mattg): It's possible that we could use some kind of cache like we have in
        # allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel.  That way, we
        # would only load the GPT2 weights once.  Though, it's not clear how to do that here, as we
        # need to load `GPT2LMHeadModel`, not just `GPT2Model`...
        gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
        self.gpt2_lm_head = gpt2_model.lm_head  # pylint: disable=no-member

    @overrides
    def get_input_dim(self) -> int:
        return self.input_dim

    @overrides
    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.gpt2_lm_head(hidden_states)
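Since the heads are registered under 'bert' and 'gpt2', a model that accepts a ``LanguageModelHead`` can pick one from its configuration by name. A hedged sketch of the relevant JSON/Jsonnet fragment; the surrounding model and the "language_model_head" key name are assumptions for illustration, not taken from this commit:

    "language_model_head": {
        "type": "gpt2",
        "model_name": "gpt2"
    }

or, for the BERT head:

    "language_model_head": {
        "type": "bert",
        "model_name": "bert-base-uncased"
    }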
allennlp/modules/language_model_heads/language_model_head.py (19 additions, 0 deletions)
@@ -0,0 +1,19 @@
import torch

from allennlp.common import Registrable


class LanguageModelHead(torch.nn.Module, Registrable):
    """
    A ``LanguageModelHead`` encapsulates a function that goes from some hidden state to logits over
    a vocabulary.
    """
    def get_input_dim(self) -> int:
        raise NotImplementedError

    def get_output_dim(self) -> int:
        raise NotImplementedError

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:  # type: ignore
        # pylint: disable=arguments-differ
        raise NotImplementedError
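As the docstring says, a ``LanguageModelHead`` maps hidden states to logits over a vocabulary, so a model can apply it to whatever contextual representations it computes. A minimal usage sketch (the function and shapes here are illustrative, not code from this commit):

import torch

from allennlp.modules.language_model_heads import LanguageModelHead


def lm_logits(head: LanguageModelHead, hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch_size, num_positions, head.get_input_dim())
    logits = head(hidden_states)
    # logits: (batch_size, num_positions, head.get_output_dim()); these can be fed to a
    # cross-entropy loss against gold token ids at the positions being predicted.
    return logits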
Empty file.
Tests for BertLanguageModelHead
@@ -0,0 +1,19 @@
# pylint: disable=invalid-name,no-self-use,protected-access
import torch

from allennlp.common import Params
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.modules.language_model_heads import LanguageModelHead, BertLanguageModelHead


class TestBertLanguageModelHead(AllenNlpTestCase):
    def test_can_init_and_run(self):
        # The LM head code reads a module from somewhere else; we're basically just testing here
        # that we can initialize the expected model `from_params`.
        head = LanguageModelHead.from_params(Params({"type": "bert", "model_name": "bert-base-uncased"}))
        assert isinstance(head, BertLanguageModelHead)
        assert head.get_input_dim() == 768
        assert head.get_output_dim() == 30522
        tensor = torch.rand(1, 768)
        logits = head(tensor)
        assert tuple(logits.size()) == (1, 30522)
Tests for Gpt2LanguageModelHead
@@ -0,0 +1,19 @@
# pylint: disable=invalid-name,no-self-use,protected-access
import torch

from allennlp.common import Params
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.modules.language_model_heads import LanguageModelHead, Gpt2LanguageModelHead


class TestGpt2LanguageModelHead(AllenNlpTestCase):
    def test_can_init_and_run(self):
        # The LM head code reads a module from somewhere else; we're basically just testing here
        # that we can initialize the expected model `from_params`.
        head = LanguageModelHead.from_params(Params({"type": "gpt2", "model_name": "gpt2"}))
        assert isinstance(head, Gpt2LanguageModelHead)
        assert head.get_input_dim() == 768
        assert head.get_output_dim() == 50257
        tensor = torch.rand(1, 768)
        logits = head(tensor)
        assert tuple(logits.size()) == (1, 50257)
Test-only linear head (LinearLanguageModelHead, moved into the tests per the commit message)
@@ -0,0 +1,35 @@
from overrides import overrides
import torch

from allennlp.data import Vocabulary
from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('linear')
class LinearLanguageModelHead(LanguageModelHead):
    """
    Uses ``torch.nn.Linear`` as a language model head.  Does nothing else fancy.  This was intended
    largely for testing code with small models and simple components.  It's likely that you would
    want something nicer for actually training a language model, such as tying weights with an
    input embedding, or an adaptive softmax, or something.  But, if you find this class useful for
    something you're doing and want it moved into the repo, open an issue on github.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 input_dim: int,
                 vocab_namespace: str) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = vocab.get_vocab_size(vocab_namespace)
        self.linear = torch.nn.Linear(self.input_dim, self.output_dim)

    @overrides
    def get_input_dim(self) -> int:
        return self.input_dim

    @overrides
    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.linear(hidden_states)
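The docstring above suggests tying the output projection to an input embedding when actually training a language model. A minimal sketch of a tied head, hypothetical and not part of this commit (it assumes you already have a ``torch.nn.Embedding`` for the same vocabulary):

import torch

from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


class TiedLinearLanguageModelHead(LanguageModelHead):
    """Hypothetical head that reuses an input embedding's weight matrix as the output projection."""
    def __init__(self, embedding: torch.nn.Embedding) -> None:
        super().__init__()
        self.embedding = embedding

    def get_input_dim(self) -> int:
        return self.embedding.embedding_dim

    def get_output_dim(self) -> int:
        return self.embedding.num_embeddings

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states @ embedding.weight.T gives logits of shape (..., vocab_size).
        return torch.nn.functional.linear(hidden_states, self.embedding.weight)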
Docs (Sphinx API page for allennlp.modules.language_model_heads)
@@ -0,0 +1,22 @@
allennlp.modules.language_model_heads
=====================================

.. automodule:: allennlp.modules.language_model_heads
   :members:
   :undoc-members:
   :show-inheritance:

.. automodule:: allennlp.modules.language_model_heads.language_model_head
   :members:
   :undoc-members:
   :show-inheritance:

.. automodule:: allennlp.modules.language_model_heads.bert
   :members:
   :undoc-members:
   :show-inheritance:

.. automodule:: allennlp.modules.language_model_heads.gpt2
   :members:
   :undoc-members:
   :show-inheritance: