From 030e28c9e11d85ab674442d635d0e2f66548f211 Mon Sep 17 00:00:00 2001
From: Ana
Date: Mon, 26 Aug 2019 18:46:32 -0700
Subject: [PATCH] Revert "Revert "Merge branch 'matt-gardner-transformer-embedder'""

This reverts commit 6e1e3713155c74e041cde8bb8f23535b5c84ec9c.
---
 allennlp/modules/token_embedders/__init__.py  |  1 +
 .../pretrained_transformer_embedder.py        | 26 +++++++++++++++++++
 .../pretrained_transformer_embedder_test.py   | 16 ++++++++++++
 doc/api/allennlp.modules.token_embedders.rst  |  7 +++++
 4 files changed, 50 insertions(+)
 create mode 100644 allennlp/modules/token_embedders/pretrained_transformer_embedder.py
 create mode 100644 allennlp/tests/modules/token_embedders/pretrained_transformer_embedder_test.py

diff --git a/allennlp/modules/token_embedders/__init__.py b/allennlp/modules/token_embedders/__init__.py
index 797cedeb58b..260bc40c12a 100644
--- a/allennlp/modules/token_embedders/__init__.py
+++ b/allennlp/modules/token_embedders/__init__.py
@@ -16,3 +16,4 @@
         LanguageModelTokenEmbedder
 from allennlp.modules.token_embedders.bag_of_word_counts_token_embedder import BagOfWordCountsTokenEmbedder
 from allennlp.modules.token_embedders.pass_through_token_embedder import PassThroughTokenEmbedder
+from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder
diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
new file mode 100644
index 00000000000..dd48938a426
--- /dev/null
+++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py
@@ -0,0 +1,26 @@
+from overrides import overrides
+from pytorch_transformers.modeling_auto import AutoModel
+import torch
+
+from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
+
+
+@TokenEmbedder.register("pretrained_transformer")
+class PretrainedTransformerEmbedder(TokenEmbedder):
+    """
+    Uses a pretrained model from ``pytorch-transformers`` as a ``TokenEmbedder``.
+    """
+    def __init__(self, model_name: str) -> None:
+        super().__init__()
+        self.transformer_model = AutoModel.from_pretrained(model_name)
+        # I'm not sure if this works for all models; open an issue on github if you find a case
+        # where it doesn't work.
+        self.output_dim = self.transformer_model.config.hidden_size
+
+    @overrides
+    def get_output_dim(self):
+        return self.output_dim
+
+    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:  # type: ignore
+        # pylint: disable=arguments-differ
+        return self.transformer_model(token_ids)[0]
diff --git a/allennlp/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/allennlp/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
new file mode 100644
index 00000000000..e32a1919492
--- /dev/null
+++ b/allennlp/tests/modules/token_embedders/pretrained_transformer_embedder_test.py
@@ -0,0 +1,16 @@
+# pylint: disable=no-self-use,invalid-name
+import torch
+
+from allennlp.common import Params
+from allennlp.modules.token_embedders import PretrainedTransformerEmbedder
+from allennlp.common.testing import AllenNlpTestCase
+
+class TestPretrainedTransformerEmbedder(AllenNlpTestCase):
+    def test_forward_runs_when_initialized_from_params(self):
+        # This code just passes things off to pytorch-transformers, so we only have a very simple
+        # test.
+        params = Params({'model_name': 'bert-base-uncased'})
+        embedder = PretrainedTransformerEmbedder.from_params(params)
+        tensor = torch.randint(0, 100, (1, 4))
+        output = embedder(tensor)
+        assert tuple(output.size()) == (1, 4, 768)
diff --git a/doc/api/allennlp.modules.token_embedders.rst b/doc/api/allennlp.modules.token_embedders.rst
index 0a5cdd32e24..36945107af5 100644
--- a/doc/api/allennlp.modules.token_embedders.rst
+++ b/doc/api/allennlp.modules.token_embedders.rst
@@ -17,6 +17,7 @@ allennlp.modules.token_embedders
 * :ref:`LanguageModelTokenEmbedder`
 * :ref:`BagOfWordsCountsTokenEmbedder`
 * :ref:`PassThroughTokenEmbedder`
+* :ref:`PretrainedTransformerEmbedder`
 
 .. _token-embedder:
 .. automodule:: allennlp.modules.token_embedders.token_embedder
@@ -77,3 +78,9 @@ allennlp.modules.token_embedders
     :members:
     :undoc-members:
     :show-inheritance:
+
+.. _pretrained-transformer-embedder:
+.. automodule:: allennlp.modules.token_embedders.pretrained_transformer_embedder
+    :members:
+    :undoc-members:
+    :show-inheritance:
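A minimal usage sketch of the new embedder, assuming only what the patch above defines (the PretrainedTransformerEmbedder constructor, forward, and get_output_dim); the 768-dimensional output applies to bert-base-uncased specifically, as in the test:

    import torch

    from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

    # Construct the embedder directly; pytorch-transformers downloads the weights by model name.
    embedder = PretrainedTransformerEmbedder(model_name='bert-base-uncased')

    # A batch of one sequence with four wordpiece ids (normally produced by the matching tokenizer).
    token_ids = torch.randint(0, 100, (1, 4))

    # forward() returns the transformer's final hidden states: shape (1, 4, 768) for bert-base-uncased.
    embeddings = embedder(token_ids)
    assert embeddings.size(-1) == embedder.get_output_dim()

Because the class is registered as "pretrained_transformer", the same embedder can also be built from a configuration file via from_params, which is the path the test exercises.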