Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Added a TokenEmbedder for use with pytorch-transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-gardner committed Aug 26, 2019
1 parent 0e872a0 commit 6ec74aa
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
1 change: 1 addition & 0 deletions allennlp/modules/token_embedders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
LanguageModelTokenEmbedder
from allennlp.modules.token_embedders.bag_of_word_counts_token_embedder import BagOfWordCountsTokenEmbedder
from allennlp.modules.token_embedders.pass_through_token_embedder import PassThroughTokenEmbedder
from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pytorch_transformers.modeling_auto import AutoModel
import torch

from allennlp.modules.token_embedders.token_embedder import TokenEmbedder


@TokenEmbedder.register("pretrained_transformer")
class PretrainedTransformerEmbedder(TokenEmbedder):
"""
Uses a pretrained model from ``pytorch-transformers`` as a ``TokenEmbedder``.
"""
def __init__(self, model_name: str) -> None:
super().__init__()
self.transformer_model = AutoModel.from_pretrained(model_name)

def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
return self.transformer_model(token_ids)[0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# pylint: disable=no-self-use,invalid-name
import torch

from allennlp.common import Params
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder
from allennlp.common.testing import AllenNlpTestCase

class TestPretrainedTransformerEmbedder(AllenNlpTestCase):
def test_forward_runs_when_initialized_from_params(self):
# This code just passes things off to pytorch-transformers, so we only have a very simple
# test.
params = Params({'model_name': 'bert-base-uncased'})
embedder = PretrainedTransformerEmbedder.from_params(params)
tensor = torch.randint(0, 100, (1, 4))
output = embedder(tensor)
assert tuple(output.size()) == (1, 4, 768)

0 comments on commit 6ec74aa

Please sign in to comment.