
Adding a LanguageModelHead abstraction (#3200)
* Modules and docs

* Added tests

* Docstrings

* pylint

* moved linear layer to tests

* add todos about caching

* fix import...

* doc
matt-gardner authored Aug 27, 2019
1 parent 370d512 commit 8c06c4b
Showing 11 changed files with 193 additions and 0 deletions.
1 change: 1 addition & 0 deletions allennlp/modules/__init__.py
@@ -24,3 +24,4 @@
from allennlp.modules.input_variational_dropout import InputVariationalDropout
from allennlp.modules.bimpm_matching import BiMpmMatching
from allennlp.modules.residual_with_layer_dropout import ResidualWithLayerDropout
from allennlp.modules.language_model_heads import LanguageModelHead
3 changes: 3 additions & 0 deletions allennlp/modules/language_model_heads/__init__.py
@@ -0,0 +1,3 @@
from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead
from allennlp.modules.language_model_heads.bert import BertLanguageModelHead
from allennlp.modules.language_model_heads.gpt2 import Gpt2LanguageModelHead
37 changes: 37 additions & 0 deletions allennlp/modules/language_model_heads/bert.py
@@ -0,0 +1,37 @@
from overrides import overrides
from pytorch_transformers import BertConfig, BertForMaskedLM
import torch

from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('bert')
class BertLanguageModelHead(LanguageModelHead):
"""
Loads just the LM head from ``pytorch_transformers.BertForMaskedLM``. It was easiest to load
the entire model and then pull out only the head, so this is a bit slower than it could be,
but for practical use in a model, the few seconds of extra loading time are probably not a big
deal.
"""
def __init__(self, model_name: str) -> None:
super().__init__()
config = BertConfig.from_pretrained(model_name)
self.input_dim = config.hidden_size
self.output_dim = config.vocab_size
# TODO(mattg): It's possible that we could use some kind of cache like we have in
# allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel. That way, we
# would only load the BERT weights once. Though, it's not clear how to do that here, as we
# need to load `BertForMaskedLM`, not just `BertModel`...
bert_model = BertForMaskedLM.from_pretrained(model_name)
self.bert_lm_head = bert_model.cls # pylint: disable=no-member

@overrides
def get_input_dim(self) -> int:
return self.input_dim

@overrides
def get_output_dim(self) -> int:
return self.output_dim

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.bert_lm_head(hidden_states)
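
The TODO above mentions reusing loaded weights across instances, like ``allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel`` does. A minimal sketch of what such a cache could look like, assuming a hypothetical module-level dict and helper (``_LM_CACHE``, ``load_masked_lm``) that are not part of this commit:

from typing import Dict

from pytorch_transformers import BertForMaskedLM

# Hypothetical module-level cache, keyed by pretrained model name, so repeated
# BertLanguageModelHead instances would share a single loaded BertForMaskedLM.
_LM_CACHE: Dict[str, BertForMaskedLM] = {}


def load_masked_lm(model_name: str) -> BertForMaskedLM:
    # Load each pretrained model at most once and reuse it afterwards.
    if model_name not in _LM_CACHE:
        _LM_CACHE[model_name] = BertForMaskedLM.from_pretrained(model_name)
    return _LM_CACHE[model_name]

The constructor above could then call ``load_masked_lm(model_name)`` instead of ``BertForMaskedLM.from_pretrained(model_name)``, which is roughly what the existing ``PretrainedBertModel`` cache does for ``BertModel``.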
37 changes: 37 additions & 0 deletions allennlp/modules/language_model_heads/gpt2.py
@@ -0,0 +1,37 @@
from overrides import overrides
from pytorch_transformers import GPT2Config, GPT2LMHeadModel
import torch

from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('gpt2')
class Gpt2LanguageModelHead(LanguageModelHead):
"""
Loads just the LM head from ``pytorch_transformers.GPT2LMHeadModel``. It was easiest to load
the entire model and then pull out only the head, so this is a bit slower than it could be,
but for practical use in a model, the few seconds of extra loading time are probably not a big
deal.
"""
def __init__(self, model_name: str) -> None:
super().__init__()
config = GPT2Config.from_pretrained(model_name)
self.input_dim = config.hidden_size
self.output_dim = config.vocab_size
# TODO(mattg): It's possible that we could use some kind of cache like we have in
# allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel. That way, we
# would only load the GPT2 weights once. Though, it's not clear how to do that here, as we
# need to load `GPT2LMHeadModel`, not just `GPT2Model`...
gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
self.gpt2_lm_head = gpt2_model.lm_head # pylint: disable=no-member

@overrides
def get_input_dim(self) -> int:
return self.input_dim

@overrides
def get_output_dim(self) -> int:
return self.output_dim

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.gpt2_lm_head(hidden_states)
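
Since both classes register themselves under the ``LanguageModelHead`` base class, a configuration can switch between them just by changing the ``type`` key. A small illustration of that, using the same ``from_params`` pattern as the tests below (not part of this commit, and it downloads both pretrained models):

import torch

from allennlp.common import Params
from allennlp.modules import LanguageModelHead

# The same construction code handles either registered head; only the "type"
# and "model_name" values differ between the BERT and GPT-2 variants.
for params in [Params({"type": "bert", "model_name": "bert-base-uncased"}),
               Params({"type": "gpt2", "model_name": "gpt2"})]:
    head = LanguageModelHead.from_params(params)
    hidden_states = torch.rand(1, head.get_input_dim())
    logits = head(hidden_states)
    assert logits.size(-1) == head.get_output_dim()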
19 changes: 19 additions & 0 deletions allennlp/modules/language_model_heads/language_model_head.py
@@ -0,0 +1,19 @@
import torch

from allennlp.common import Registrable


class LanguageModelHead(torch.nn.Module, Registrable):
"""
A ``LanguageModelHead`` encapsulates a function that goes from some hidden state to logits over
a vocabulary.
"""
def get_input_dim(self) -> int:
raise NotImplementedError

def get_output_dim(self) -> int:
raise NotImplementedError

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # type: ignore
# pylint: disable=arguments-differ
raise NotImplementedError
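
For context on how this interface is meant to be consumed, here is a sketch of a masked-LM-style loss computed on top of an arbitrary ``LanguageModelHead``. The function name, shapes, and loss choice are illustrative assumptions, not something defined in this commit:

import torch
import torch.nn.functional as F

from allennlp.modules import LanguageModelHead


def masked_lm_loss(head: LanguageModelHead,
                   hidden_states: torch.Tensor,
                   target_ids: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch_size, num_targets, head.get_input_dim())
    # target_ids: (batch_size, num_targets)
    # The head maps hidden states to logits over its vocabulary.
    logits = head(hidden_states)  # (batch_size, num_targets, head.get_output_dim())
    return F.cross_entropy(logits.view(-1, head.get_output_dim()),
                           target_ids.view(-1))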
Empty file.
19 changes: 19 additions & 0 deletions allennlp/tests/modules/language_model_heads/bert_test.py
@@ -0,0 +1,19 @@
# pylint: disable=invalid-name,no-self-use,protected-access
import torch

from allennlp.common import Params
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.modules.language_model_heads import LanguageModelHead, BertLanguageModelHead


class TestBertLanguageModelHead(AllenNlpTestCase):
def test_can_init_and_run(self):
# The LM head just wraps a module loaded from pytorch_transformers; here we are mostly testing
# that we can initialize the expected model `from_params` and run it.
head = LanguageModelHead.from_params(Params({"type": "bert", "model_name": "bert-base-uncased"}))
assert isinstance(head, BertLanguageModelHead)
assert head.get_input_dim() == 768
assert head.get_output_dim() == 30522
tensor = torch.rand(1, 768)
logits = head(tensor)
assert tuple(logits.size()) == (1, 30522)
19 changes: 19 additions & 0 deletions allennlp/tests/modules/language_model_heads/gpt2_test.py
@@ -0,0 +1,19 @@
# pylint: disable=invalid-name,no-self-use,protected-access
import torch

from allennlp.common import Params
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.modules.language_model_heads import LanguageModelHead, Gpt2LanguageModelHead


class TestGpt2LanguageModelHead(AllenNlpTestCase):
def test_can_init_and_run(self):
# The LM head just wraps a module loaded from pytorch_transformers; here we are mostly testing
# that we can initialize the expected model `from_params` and run it.
head = LanguageModelHead.from_params(Params({"type": "gpt2", "model_name": "gpt2"}))
assert isinstance(head, Gpt2LanguageModelHead)
assert head.get_input_dim() == 768
assert head.get_output_dim() == 50257
tensor = torch.rand(1, 768)
logits = head(tensor)
assert tuple(logits.size()) == (1, 50257)
35 changes: 35 additions & 0 deletions allennlp/tests/modules/language_model_heads/linear.py
@@ -0,0 +1,35 @@
from overrides import overrides
import torch

from allennlp.data import Vocabulary
from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('linear')
class LinearLanguageModelHead(LanguageModelHead):
"""
Uses ``torch.nn.Linear`` as a language model head. It does nothing else fancy. This is intended
largely for testing code with small models and simple components. For actually training a
language model you would likely want something nicer, such as tying weights with an input
embedding or an adaptive softmax. But if you find this class useful for something you're doing
and want it moved into the repo, open an issue on GitHub.
"""
def __init__(self,
vocab: Vocabulary,
input_dim: int,
vocab_namespace: str) -> None:
super().__init__()
self.input_dim = input_dim
self.output_dim = vocab.get_vocab_size(vocab_namespace)
self.linear = torch.nn.Linear(self.input_dim, self.output_dim)

@overrides
def get_input_dim(self) -> int:
return self.input_dim

@overrides
def get_output_dim(self) -> int:
return self.output_dim

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.linear(hidden_states)
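
A small usage sketch for this test-only head, constructing it directly with a vocabulary built in code. The import path assumes the file stays at ``allennlp/tests/modules/language_model_heads/linear.py``; everything else here is illustrative:

import torch

from allennlp.data import Vocabulary
from allennlp.tests.modules.language_model_heads.linear import LinearLanguageModelHead

# Build a tiny vocabulary; the head's output size will include the padding and
# OOV entries that Vocabulary adds to the namespace.
vocab = Vocabulary()
vocab.add_token_to_namespace("the", namespace="tokens")
vocab.add_token_to_namespace("fox", namespace="tokens")

head = LinearLanguageModelHead(vocab, input_dim=10, vocab_namespace="tokens")
logits = head(torch.rand(2, 10))
assert logits.size() == (2, head.get_output_dim())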
22 changes: 22 additions & 0 deletions doc/api/allennlp.modules.language_model_heads.rst
@@ -0,0 +1,22 @@
allennlp.modules.language_model_heads
=====================================

.. automodule:: allennlp.modules.language_model_heads
:members:
:undoc-members:
:show-inheritance:

.. automodule:: allennlp.modules.language_model_heads.language_model_head
:members:
:undoc-members:
:show-inheritance:

.. automodule:: allennlp.modules.language_model_heads.bert
:members:
:undoc-members:
:show-inheritance:

.. automodule:: allennlp.modules.language_model_heads.gpt2
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions doc/api/allennlp.modules.rst
@@ -14,6 +14,7 @@ allennlp.modules
allennlp.modules.lstm_cell_with_projection
allennlp.modules.elmo
allennlp.modules.elmo_lstm
allennlp.modules.language_model_heads
allennlp.modules.conditional_random_field
allennlp.modules.feedforward
allennlp.modules.highway