Transformer toolkit updates (#5270)
* Fix duplicate line

* Easy access to the output dimension of an activation layer

* Take and ignore an attention mask in TransformerEmbeddings

* Make it so a pooler can be derived from a huggingface module

* Pooler that can load from a transformer module

* Changelog

* Update transformer_embeddings.py

* Productivity through formatting

* Don't break positional arguments

* Some more module names

* Remove _get_input_arguments()

Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com>
dirkgr and AkshitaB authored Jun 21, 2021
1 parent 6af9069 commit c8b8ed3
Showing 5 changed files with 27 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `on_backward` training callback which allows for control over backpropagation and gradient manipulation.
- Added `AdversarialBiasMitigator`, a Model wrapper to adversarially mitigate biases in predictions produced by a pretrained model for a downstream task.
- Added `which_loss` parameter to `ensure_model_can_train_save_and_load` in `ModelTestCase` to specify which loss to test.
- The activation layer in the transformer toolkit can now be queried for its output dimension.
- `TransformerEmbeddings` now takes, but ignores, a parameter for the attention mask. This is needed for compatibility with some other modules that get called the same way and use the mask.
- `TransformerPooler` can now be instantiated from a pretrained transformer module, just like the other modules in the transformer toolkit.

### Fixed

3 changes: 3 additions & 0 deletions allennlp/modules/transformer/activation_layer.py
@@ -24,6 +24,9 @@ def __init__(
        self.act_fn = activation
        self.pool = pool

    def get_output_dim(self) -> int:
        return self.dense.out_features

    def forward(self, hidden_states):
        if self.pool:
            hidden_states = hidden_states[:, 0]
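For illustration, a minimal usage sketch of the new accessor (not part of this commit). The positional constructor arguments and the `pool` keyword are assumptions inferred from the pooler code further down, which passes `hidden_size`, `intermediate_size`, and `activation` in that order:

```python
import torch

from allennlp.modules.transformer.activation_layer import ActivationLayer

# Hypothetical sizes: the layer wraps a Linear(hidden_size -> intermediate_size),
# so get_output_dim() should report that Linear's out_features.
layer = ActivationLayer(768, 3072, "relu", pool=False)
assert layer.get_output_dim() == 3072

hidden_states = torch.randn(2, 5, 768)   # (batch, seq_len, hidden_size)
output = layer(hidden_states)            # last dim becomes get_output_dim()
assert output.shape[-1] == layer.get_output_dim()
```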
6 changes: 4 additions & 2 deletions allennlp/modules/transformer/transformer_embeddings.py
@@ -104,7 +104,7 @@ class TransformerEmbeddings(Embeddings):
    Optionally apply a linear transform after the dropout, projecting to `output_size`.
    """

-    _pretrained_relevant_module = ["embeddings", "bert.embeddings"]
+    _pretrained_relevant_module = ["embeddings", "bert.embeddings", "roberta.embeddings"]
    _pretrained_mapping = {
        "LayerNorm": "layer_norm",
        "word_embeddings": "embeddings.word_embeddings",
@@ -113,7 +113,6 @@ class TransformerEmbeddings(Embeddings):
        # Albert is a special case. A linear projection is applied to the embeddings,
        # but that linear transformation lives in the encoder.
        "albert.embeddings.LayerNorm": "layer_norm",
-        "albert.embeddings.LayerNorm": "layer_norm",
        "albert.embeddings.word_embeddings": "embeddings.word_embeddings",
        "albert.embeddings.position_embeddings": "embeddings.position_embeddings",
        "albert.embeddings.token_type_embeddings": "embeddings.token_type_embeddings",
@@ -163,12 +162,15 @@ def forward(  # type: ignore
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        """
        # Parameters

        input_ids : `torch.Tensor`
            Shape `batch_size x seq_len`
        attention_mask : `torch.Tensor`
            Shape `batch_size x seq_len`. This parameter is ignored, but it is here for compatibility.
        token_type_ids : `torch.Tensor`, optional
            Shape `batch_size x seq_len`
        position_ids : `torch.Tensor`, optional
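A hedged sketch (not from the diff) of why the ignored mask is convenient: the embeddings can now be called exactly like mask-consuming modules. It assumes `from_pretrained_module` accepts a Hugging Face model name and that `TransformerEmbeddings` is exported from `allennlp.modules.transformer`:

```python
import torch

from allennlp.modules.transformer import TransformerEmbeddings

# Assumed entry point: load only the embeddings sub-module of a pretrained BERT.
embeddings = TransformerEmbeddings.from_pretrained_module("bert-base-uncased")
embeddings.eval()  # disable dropout so the two calls below are comparable

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # arbitrary token ids
attention_mask = torch.ones_like(input_ids)

# The mask is accepted for compatibility but has no effect on the output.
with torch.no_grad():
    out_with_mask = embeddings(input_ids, attention_mask=attention_mask)
    out_without = embeddings(input_ids)
assert torch.allclose(out_with_mask, out_without)
```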
17 changes: 16 additions & 1 deletion allennlp/modules/transformer/transformer_pooler.py
@@ -1,11 +1,26 @@
from typing import Union, TYPE_CHECKING

import torch

from allennlp.common import FromParams
from allennlp.modules.transformer.activation_layer import ActivationLayer

if TYPE_CHECKING:
    from transformers.configuration_utils import PretrainedConfig


class TransformerPooler(ActivationLayer, FromParams):

    _pretrained_relevant_module = ["pooler", "bert.pooler", "roberta.pooler"]

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        activation: Union[str, torch.nn.Module] = "relu",
    ):
-        super().__init__(hidden_size, intermediate_size, "relu", pool=True)
+        super().__init__(hidden_size, intermediate_size, activation, pool=True)

    @classmethod
    def _from_config(cls, config: "PretrainedConfig", **kwargs):
        return cls(config.hidden_size, config.hidden_size, "tanh")  # BERT has this hardcoded
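A sketch of what the new hooks enable (again illustrative, with the loading entry point assumed rather than shown in this diff): the pooler can be pulled out of a pretrained checkpoint by name, or built from a `transformers` config via the new `_from_config`:

```python
from transformers import AutoConfig

from allennlp.modules.transformer import TransformerPooler

# Assumed entry point: the new _pretrained_relevant_module lets the loader locate
# "bert.pooler" (or "roberta.pooler") inside the full Hugging Face model.
pooler = TransformerPooler.from_pretrained_module("bert-base-uncased")

# Construction from a config alone (no pretrained weights), via the new classmethod.
config = AutoConfig.from_pretrained("bert-base-uncased")
fresh_pooler = TransformerPooler._from_config(config)

# Both map hidden_size -> hidden_size, so the output dimension is the model's hidden size.
assert fresh_pooler.get_output_dim() == config.hidden_size
```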
2 changes: 1 addition & 1 deletion allennlp/modules/transformer/transformer_stack.py
@@ -56,7 +56,7 @@ class TransformerStack(TransformerModule, FromParams):
"""

_pretrained_mapping = {"layer": "layers"}
_pretrained_relevant_module = ["encoder", "bert.encoder"]
_pretrained_relevant_module = ["encoder", "bert.encoder", "roberta.encoder"]

def __init__(
self,
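Correspondingly for the encoder (same assumption about the loading entry point): with `"roberta.encoder"` in the relevant-module list, a RoBERTa checkpoint can seed a `TransformerStack` directly:

```python
from allennlp.modules.transformer import TransformerStack

# Assumed call; before this change the loader only looked for "encoder" /
# "bert.encoder" in the pretrained model's module tree.
encoder = TransformerStack.from_pretrained_module("roberta-base")
```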
