Merge branch 'transformer-embedder' of /~https://github.com/matt-gardner/allennlp into matt-gardner-transformer-embedder
Ana committed Aug 27, 2019
2 parents 993034f + 70e92e8 commit 07bdc4a
Showing 4 changed files with 50 additions and 0 deletions.
1 change: 1 addition & 0 deletions allennlp/modules/token_embedders/__init__.py
@@ -16,3 +16,4 @@
LanguageModelTokenEmbedder
from allennlp.modules.token_embedders.bag_of_word_counts_token_embedder import BagOfWordCountsTokenEmbedder
from allennlp.modules.token_embedders.pass_through_token_embedder import PassThroughTokenEmbedder
from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder
26 changes: 26 additions & 0 deletions allennlp/modules/token_embedders/pretrained_transformer_embedder.py
@@ -0,0 +1,26 @@
from overrides import overrides
from pytorch_transformers.modeling_auto import AutoModel
import torch

from allennlp.modules.token_embedders.token_embedder import TokenEmbedder


@TokenEmbedder.register("pretrained_transformer")
class PretrainedTransformerEmbedder(TokenEmbedder):
"""
Uses a pretrained model from ``pytorch-transformers`` as a ``TokenEmbedder``.
"""
def __init__(self, model_name: str) -> None:
super().__init__()
self.transformer_model = AutoModel.from_pretrained(model_name)
# I'm not sure if this works for all models; open an issue on github if you find a case
# where it doesn't work.
self.output_dim = self.transformer_model.config.hidden_size

@overrides
def get_output_dim(self):
return self.output_dim

def forward(self, token_ids: torch.LongTensor) -> torch.Tensor: # type: ignore
# pylint: disable=arguments-differ
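        # The underlying AutoModel returns a tuple; its first element is the last layer's
        # hidden states, with shape (batch_size, num_tokens, hidden_size).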
return self.transformer_model(token_ids)[0]
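A minimal usage sketch for the new embedder, outside of any AllenNLP configuration. The model name and the randomly generated token ids below are illustrative only; the first call downloads the pretrained weights via ``AutoModel.from_pretrained``.

import torch

from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Instantiate directly; any name accepted by AutoModel.from_pretrained should work here.
embedder = PretrainedTransformerEmbedder(model_name="bert-base-uncased")
print(embedder.get_output_dim())  # 768 for bert-base-uncased

# A batch of 2 sequences of 6 wordpiece ids (random placeholder values).
token_ids = torch.randint(0, 100, (2, 6))
embeddings = embedder(token_ids)
print(embeddings.size())  # torch.Size([2, 6, 768])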
16 changes: 16 additions & 0 deletions allennlp/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
@@ -0,0 +1,16 @@
# pylint: disable=no-self-use,invalid-name
import torch

from allennlp.common import Params
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder
from allennlp.common.testing import AllenNlpTestCase

class TestPretrainedTransformerEmbedder(AllenNlpTestCase):
def test_forward_runs_when_initialized_from_params(self):
# This code just passes things off to pytorch-transformers, so we only have a very simple
# test.
params = Params({'model_name': 'bert-base-uncased'})
embedder = PretrainedTransformerEmbedder.from_params(params)
tensor = torch.randint(0, 100, (1, 4))
output = embedder(tensor)
assert tuple(output.size()) == (1, 4, 768)
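To go from raw text to the tensor shape the test asserts, the embedder can be paired with the matching tokenizer from ``pytorch-transformers``. A sketch, assuming ``AutoTokenizer`` is available in the installed ``pytorch-transformers`` version and that ``encode`` returns a flat list of wordpiece ids:

import torch
from pytorch_transformers import AutoTokenizer

from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
embedder = PretrainedTransformerEmbedder(model_name=model_name)

# Wordpiece ids for one sentence, wrapped in a batch dimension of size 1.
ids = tokenizer.encode("AllenNLP now wraps pytorch-transformers models.")
embeddings = embedder(torch.tensor([ids]))
print(embeddings.size())  # (1, num_wordpieces, 768)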
7 changes: 7 additions & 0 deletions doc/api/allennlp.modules.token_embedders.rst
@@ -17,6 +17,7 @@ allennlp.modules.token_embedders
* :ref:`LanguageModelTokenEmbedder<language-model-token-embedder>`
* :ref:`BagOfWordCountsTokenEmbedder<bag-of-words-counts-token-embedder>`
* :ref:`PassThroughTokenEmbedder<pass-through-token-embedder>`
* :ref:`PretrainedTransformerEmbedder<pretrained-transformer-embedder>`

.. _token-embedder:
.. automodule:: allennlp.modules.token_embedders.token_embedder
@@ -77,3 +78,9 @@ allennlp.modules.token_embedders
:members:
:undoc-members:
:show-inheritance:

.. _pretrained-transformer-embedder:
.. automodule:: allennlp.modules.token_embedders.pretrained_transformer_embedder
:members:
:undoc-members:
:show-inheritance:
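Because the class is registered under the name ``pretrained_transformer``, it can also be built through the ``TokenEmbedder`` registry from a config-style ``Params`` object. A sketch, assuming the generic ``from_params`` machinery dispatches on the ``type`` key as usual:

from allennlp.common import Params
from allennlp.modules.token_embedders import TokenEmbedder

params = Params({"type": "pretrained_transformer", "model_name": "bert-base-uncased"})
embedder = TokenEmbedder.from_params(params)
assert embedder.get_output_dim() == 768  # hidden size of bert-base-uncased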
