From cf113d705b9054d329c67cf9bb29cbc3f191015d Mon Sep 17 00:00:00 2001 From: Pete Date: Mon, 17 May 2021 12:25:42 -0700 Subject: [PATCH] Changes and improvements to how we initialize transformer modules from pretrained models (#5200) * updates * rename 'load_state_dict' -> 'read_state_dict' * fix TransformerStack * more fixes * fix embeddings * fix toolkit tests * fix self attention * fix bimodal encoder tests * fix more tests * fix T5! * fixes * fix backbone * fix * fixes * fix * doc fixes * name changes * patch models branch temporarily * update CHANGELOG * change default dist loading strategy to 'MEM_EFFICIENT' for T5 * fix distilbert test * always use memory efficient distributed loading strategy * Update .github/workflows/ci.yml Co-authored-by: Pete Co-authored-by: Akshita Bhagia --- CHANGELOG.md | 4 + allennlp/commands/diff.py | 6 +- allennlp/common/testing/distributed_test.py | 9 +- allennlp/common/util.py | 12 + allennlp/models/model.py | 2 +- .../modules/backbones/vilbert_backbone.py | 52 +- allennlp/modules/transformer/__init__.py | 2 +- .../modules/transformer/bimodal_attention.py | 5 +- .../transformer/bimodal_connection_layer.py | 2 +- .../modules/transformer/bimodal_encoder.py | 110 +--- allennlp/modules/transformer/layer_norm.py | 7 + allennlp/modules/transformer/output_layer.py | 5 +- .../transformer/positional_encoding.py | 3 + .../modules/transformer/self_attention.py | 70 +-- allennlp/modules/transformer/t5.py | 66 ++- .../transformer/transformer_embeddings.py | 66 ++- .../modules/transformer/transformer_layer.py | 76 +-- .../modules/transformer/transformer_module.py | 495 ++++++++++------ .../modules/transformer/transformer_stack.py | 85 +-- allennlp/nn/util.py | 217 ++++++- scripts/py2md.py | 7 + .../transformer/activation_layer_test.py | 38 +- .../transformer/bimodal_attention_test.py | 103 ++-- .../transformer/bimodal_encoder_test.py | 181 +++--- .../transformer/self_attention_test.py | 193 ++----- tests/modules/transformer/toolkit_test.py | 68 ++- .../transformer_embeddings_test.py | 539 +++++++++--------- .../transformer/transformer_layer_test.py | 529 ++++++++--------- .../transformer/transformer_module_test.py | 81 +-- .../transformer/transformer_stack_test.py | 253 +++----- tests/nn/util_test.py | 80 ++- 31 files changed, 1708 insertions(+), 1658 deletions(-) create mode 100644 allennlp/modules/transformer/layer_norm.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a55610dea5e..92308cd9c2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use `dist_reduce_sum` in distributed metrics. - Allow Google Cloud Storage paths in `cached_path` ("gs://..."). +- Renamed `nn.util.load_state_dict()` to `read_state_dict` to avoid confusion with `torch.nn.Module.load_state_dict()`. +- `TransformerModule.from_pretrained_module` now only accepts a pretrained model ID (e.g. "bert-base-case") instead of + an actual `torch.nn.Module`. Other parameters to this method have changed as well. - Print the first batch to the console by default. - Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0). 
@@ -18,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `TaskSuite` base class and command line functionality for running [`checklist`](/~https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module. - Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files. +- Added `nn.util.distributed_device()` helper function. - Added `allennlp.nn.util.load_state_dict` helper function. - Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers such as the `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`. diff --git a/allennlp/commands/diff.py b/allennlp/commands/diff.py index 6d86f7db76f..35738ca2237 100644 --- a/allennlp/commands/diff.py +++ b/allennlp/commands/diff.py @@ -19,7 +19,7 @@ from allennlp.commands.subcommand import Subcommand from allennlp.common.file_utils import cached_path -from allennlp.nn.util import load_state_dict +from allennlp.nn.util import read_state_dict logger = logging.getLogger(__name__) @@ -249,10 +249,10 @@ def _get_checkpoint_path(checkpoint: str) -> str: def _diff(args: argparse.Namespace): checkpoint_1_path = _get_checkpoint_path(args.checkpoint1) checkpoint_2_path = _get_checkpoint_path(args.checkpoint2) - checkpoint_1 = load_state_dict( + checkpoint_1 = read_state_dict( checkpoint_1_path, strip_prefix=args.strip_prefix_1, strict=False ) - checkpoint_2 = load_state_dict( + checkpoint_2 = read_state_dict( checkpoint_2_path, strip_prefix=args.strip_prefix_2, strict=False ) for step in checkpoint_diff(checkpoint_1, checkpoint_2, args.scale, args.threshold): diff --git a/allennlp/common/testing/distributed_test.py b/allennlp/common/testing/distributed_test.py index 7ef00e2e0e8..2fae00ff635 100644 --- a/allennlp/common/testing/distributed_test.py +++ b/allennlp/common/testing/distributed_test.py @@ -61,12 +61,19 @@ def run_distributed_test( func: `Callable` `func` needs to be global for spawning the processes, so that it can be pickled. + + start_method: `Optional[str]`, optional (default = `None`) + The start method to use for starting the workers. Defaults to "spawn" for GPU + processes and fork otherwise. """ device_ids = device_ids or [-1, -1] check_for_gpu(device_ids) # "fork" start method is the default and should be preferred, except when we're # running the tests on GPU, in which case we need to use "spawn". - start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork" + if "start_method" in kwargs: + start_method = kwargs.pop("start_method") + else: + start_method = "spawn" if any(x >= 0 for x in device_ids) else "fork" nprocs = world_size = len(device_ids) mp.start_processes( init_process, diff --git a/allennlp/common/util.py b/allennlp/common/util.py index db77d795e8d..4db2ef6b5fe 100644 --- a/allennlp/common/util.py +++ b/allennlp/common/util.py @@ -509,6 +509,18 @@ def is_distributed() -> bool: return dist.is_available() and dist.is_initialized() +def is_global_primary() -> bool: + """ + Checks if the distributed process group is the global primary (rank = 0). + If the distributed process group is not available or has not been initialized, + this trivially returns `True`. 
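+
+    For example, `TransformerModule.from_pretrained_module()` uses this check so that
+    only the primary worker reads a pretrained `state_dict` into memory, and the
+    remaining workers wait for those weights to be distributed to them.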
+ """ + if not is_distributed(): + return True + else: + return dist.get_rank() == 0 + + def sanitize_wordpiece(wordpiece: str) -> str: """ Sanitizes wordpieces from BERT, RoBERTa or ALBERT tokenizers. diff --git a/allennlp/models/model.py b/allennlp/models/model.py index 5ff7c967e8e..2800243a6a1 100644 --- a/allennlp/models/model.py +++ b/allennlp/models/model.py @@ -335,7 +335,7 @@ def _load( # Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError # if the state dict is missing keys because we handle this case below. - model_state = util.load_state_dict(weights_file, cuda_device=cuda_device) + model_state = util.read_state_dict(weights_file, cuda_device=cuda_device) missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False) # Modules might define a class variable called `authorized_missing_keys`, diff --git a/allennlp/modules/backbones/vilbert_backbone.py b/allennlp/modules/backbones/vilbert_backbone.py index c1b9d1090b7..0f554a7a1d2 100644 --- a/allennlp/modules/backbones/vilbert_backbone.py +++ b/allennlp/modules/backbones/vilbert_backbone.py @@ -7,7 +7,12 @@ from allennlp.data.fields.text_field import TextFieldTensors from allennlp.data.vocabulary import Vocabulary from allennlp.modules.backbones.backbone import Backbone -from allennlp.modules.transformer import BiModalEncoder, ImageFeatureEmbeddings, Embeddings +from allennlp.modules.transformer import ( + BiModalEncoder, + ImageFeatureEmbeddings, + TransformerEmbeddings, + TransformerPooler, +) logger = logging.getLogger(__name__) @@ -23,7 +28,7 @@ class VilbertBackbone(Backbone): def __init__( self, vocab: Vocabulary, - text_embeddings: Embeddings, + text_embeddings: TransformerEmbeddings, image_embeddings: ImageFeatureEmbeddings, encoder: BiModalEncoder, pooled_output_dim: int, @@ -36,7 +41,6 @@ def __init__( self.text_embeddings = text_embeddings self.image_embeddings = image_embeddings self.encoder = encoder - from allennlp.modules.transformer import TransformerPooler self.t_pooler = TransformerPooler(encoder.hidden_size1, pooled_output_dim) self.v_pooler = TransformerPooler(encoder.hidden_size2, pooled_output_dim) @@ -66,44 +70,7 @@ def from_huggingface_model_name( image_fixed_layer: int, fusion_method: str = "sum", ): - from transformers import AutoModel - - transformer = AutoModel.from_pretrained(model_name) - - from copy import deepcopy - - text_embeddings = deepcopy(transformer.embeddings) - - # Albert (and maybe others?) has this "embedding_size", that's different from "hidden_size". - # To get them to the same dimensionality, it uses a linear transform after the embedding - # layer, which we need to pull out and copy here. - if hasattr(transformer.config, "embedding_size"): - config = transformer.config - - from transformers.models.albert.modeling_albert import AlbertModel - - if isinstance(transformer, AlbertModel): - linear_transform = deepcopy(transformer.encoder.embedding_hidden_mapping_in) - else: - logger.warning( - "Unknown model that uses separate embedding size; weights of the linear " - f"transform will not be initialized. Model type is: {transformer.__class__}" - ) - linear_transform = torch.nn.Linear(config.embedding_dim, config.hidden_dim) - - # We can't just use torch.nn.Sequential here, even though that's basically all this is, - # because Sequential doesn't accept *inputs, only a single argument. 
- - class EmbeddingsShim(torch.nn.Module): - def __init__(self, embeddings: torch.nn.Module, linear_transform: torch.nn.Module): - super().__init__() - self.linear_transform = linear_transform - self.embeddings = embeddings - - def forward(self, *inputs, **kwargs): - return self.linear_transform(self.embeddings(*inputs, **kwargs)) - - text_embeddings = EmbeddingsShim(text_embeddings, linear_transform) + text_embeddings = TransformerEmbeddings.from_pretrained_module(model_name) image_embeddings = ImageFeatureEmbeddings( feature_size=image_feature_dim, @@ -112,7 +79,7 @@ def forward(self, *inputs, **kwargs): ) encoder = BiModalEncoder.from_pretrained_module( - pretrained_module=transformer, + model_name, num_hidden_layers2=image_num_hidden_layers, hidden_size2=image_hidden_size, num_attention_heads2=image_num_attention_heads, @@ -126,6 +93,7 @@ def forward(self, *inputs, **kwargs): fixed_layer1=text_fixed_layer, fixed_layer2=image_fixed_layer, ) + return cls( vocab=vocab, text_embeddings=text_embeddings, diff --git a/allennlp/modules/transformer/__init__.py b/allennlp/modules/transformer/__init__.py index b0b56b90d17..9b944130c7c 100644 --- a/allennlp/modules/transformer/__init__.py +++ b/allennlp/modules/transformer/__init__.py @@ -123,8 +123,8 @@ def forward(self, token_ids: torch.LongTensor, mask: torch.BoolTensor): ``` """ +from allennlp.modules.transformer.layer_norm import LayerNorm from allennlp.modules.transformer.positional_encoding import SinusoidalPositionalEncoding - from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.transformer_embeddings import ( Embeddings, diff --git a/allennlp/modules/transformer/bimodal_attention.py b/allennlp/modules/transformer/bimodal_attention.py index fc6bb4047f9..cc4bf11aa22 100644 --- a/allennlp/modules/transformer/bimodal_attention.py +++ b/allennlp/modules/transformer/bimodal_attention.py @@ -118,10 +118,12 @@ def forward( input_tensor2, attention_mask1=None, attention_mask2=None, - co_attention_mask=None, + co_attention_mask=None, # TODO: is this flag necessary? use_co_attention_mask=False, ): """ + # Parameters + input_tensor1 : `torch.Tensor` Shape `batch_size x seq_len1 x hidden_dim1` where `seq_len1` can be the sequence length @@ -143,7 +145,6 @@ def forward( if you know which words correspond to which regions in the image, this mask can be applied to limit the attention given the bias. use_co_attention_mask : `bool` - # TODO: is this flag necessary? Whether to use co_attention_mask or not, default = `False`. 
""" diff --git a/allennlp/modules/transformer/bimodal_connection_layer.py b/allennlp/modules/transformer/bimodal_connection_layer.py index 5d7e4f7fc88..f9656c2b7a5 100644 --- a/allennlp/modules/transformer/bimodal_connection_layer.py +++ b/allennlp/modules/transformer/bimodal_connection_layer.py @@ -31,7 +31,7 @@ def forward(self, hidden_states1, input_tensor1, hidden_states2, input_tensor2): class BiModalConnectionLayer(TransformerModule, FromParams): - _huggingface_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"} + _pretrained_mapping = {"biAttention": "bimodal_attention", "biOutput": "bimodal_output"} def __init__( self, diff --git a/allennlp/modules/transformer/bimodal_encoder.py b/allennlp/modules/transformer/bimodal_encoder.py index bf5e732e96d..acc993194df 100644 --- a/allennlp/modules/transformer/bimodal_encoder.py +++ b/allennlp/modules/transformer/bimodal_encoder.py @@ -1,14 +1,16 @@ -from typing import Optional, Dict, List, Union +from typing import Optional, List, TYPE_CHECKING + import torch from allennlp.common import FromParams - from allennlp.modules.util import replicate_layers - from allennlp.modules.transformer.transformer_layer import TransformerLayer from allennlp.modules.transformer.bimodal_connection_layer import BiModalConnectionLayer from allennlp.modules.transformer.transformer_module import TransformerModule +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class BiModalEncoder(TransformerModule, FromParams): """ @@ -46,8 +48,9 @@ class BiModalEncoder(TransformerModule, FromParams): in_batch_pairs: `bool` (default = `False`) """ - _huggingface_mapping = {"layer": "layers1"} - _relevant_module = "encoder" + _pretrained_mapping = {"layer": "layers1"} + _pretrained_relevant_module = ["encoder", "bert.encoder"] + _pretrained_allow_missing = [r"^layers2\..*", r"^c_layer\..*"] def __init__( self, @@ -243,93 +246,14 @@ def forward( ) @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - """ - The `pretrained_module` only supplies one of the modalities. 
- """ - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["num_hidden_layers1"] = len(submodules["layers1"]) - - final_kwargs["hidden_size1"] = submodules["layers1.0.attention.self.query"].in_features - final_kwargs["num_attention_heads1"] = submodules[ - "layers1.0.attention.self" - ].num_attention_heads - final_kwargs["attention_dropout1"] = submodules["layers1.0.attention.self.dropout"].p - final_kwargs["hidden_dropout1"] = submodules["layers1.0.attention.output.dropout"].p - final_kwargs["intermediate_size1"] = submodules["layers1.0.intermediate.dense"].out_features - final_kwargs["activation"] = submodules["layers1.0.intermediate"].intermediate_act_fn - + final_kwargs["num_hidden_layers1"] = config.num_hidden_layers + final_kwargs["hidden_size1"] = config.hidden_size + final_kwargs["num_attention_heads1"] = config.num_attention_heads + final_kwargs["attention_dropout1"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout1"] = config.hidden_dropout_prob + final_kwargs["intermediate_size1"] = config.intermediate_size + final_kwargs["activation"] = config.hidden_act final_kwargs.update(**kwargs) - - return final_kwargs - - def _load_from_pretrained_module( - self, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - ignore_absent_parameters: Optional[List] = None, - ): - if source == "huggingface": - ignore_absent_parameters = ["layers2", "c_layer"] - super()._load_from_pretrained_module( - pretrained_module, source, mapping, ignore_absent_parameters - ) - - @classmethod - def from_pretrained_module( # type: ignore - cls, - pretrained_module: Union[str, torch.nn.Module], - num_hidden_layers2: int, - hidden_size2: int, - combined_hidden_size: int, - intermediate_size2: int, - num_attention_heads2: int, - combined_num_attention_heads: int, - attention_dropout2: float, - hidden_dropout2: float, - biattention_id1: List[int], - biattention_id2: List[int], - fixed_layer1: int, - fixed_layer2: int, - fast_mode: bool = False, - with_coattention: bool = True, - in_batch_pairs: bool = False, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - # **kwargs, - ): - """ - The `pretrained_module` only supplies one of the modalities. 
- """ - pretrained_module = cls.get_relevant_module( - pretrained_module, source=source, mapping=mapping - ) - final_kwargs = {} - final_kwargs.update(cls._get_input_arguments(pretrained_module, source, mapping)) - final_kwargs["num_hidden_layers2"] = num_hidden_layers2 - final_kwargs["hidden_size2"] = hidden_size2 - final_kwargs["combined_hidden_size"] = combined_hidden_size - final_kwargs["intermediate_size2"] = intermediate_size2 - final_kwargs["num_attention_heads2"] = num_attention_heads2 - final_kwargs["combined_num_attention_heads"] = combined_num_attention_heads - final_kwargs["attention_dropout2"] = attention_dropout2 - final_kwargs["hidden_dropout2"] = hidden_dropout2 - final_kwargs["biattention_id1"] = biattention_id1 - final_kwargs["biattention_id2"] = biattention_id2 - final_kwargs["fixed_layer1"] = fixed_layer1 - final_kwargs["fixed_layer2"] = fixed_layer2 - final_kwargs["fast_mode"] = fast_mode - final_kwargs["with_coattention"] = with_coattention - final_kwargs["in_batch_pairs"] = in_batch_pairs - - return super().from_pretrained_module(pretrained_module, source, mapping, **final_kwargs) + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/layer_norm.py b/allennlp/modules/transformer/layer_norm.py new file mode 100644 index 00000000000..0302b705c1d --- /dev/null +++ b/allennlp/modules/transformer/layer_norm.py @@ -0,0 +1,7 @@ +import torch + +from allennlp.modules.transformer.transformer_module import TransformerModule + + +class LayerNorm(torch.nn.LayerNorm, TransformerModule): + _pretrained_mapping = {"gamma": "weight", "beta": "bias"} diff --git a/allennlp/modules/transformer/output_layer.py b/allennlp/modules/transformer/output_layer.py index 03dd1f9d5df..ac38a1794b1 100644 --- a/allennlp/modules/transformer/output_layer.py +++ b/allennlp/modules/transformer/output_layer.py @@ -3,16 +3,17 @@ from allennlp.common import FromParams from allennlp.modules.transformer.transformer_module import TransformerModule +from allennlp.modules.transformer.layer_norm import LayerNorm class OutputLayer(TransformerModule, FromParams): - _huggingface_mapping = {"LayerNorm": "layer_norm"} + _pretrained_mapping = {"LayerNorm": "layer_norm"} def __init__(self, input_size: int, hidden_size: int, dropout: float): super().__init__() self.dense = torch.nn.Linear(input_size, hidden_size) - self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) + self.layer_norm = LayerNorm(hidden_size, eps=1e-12) self.dropout = torch.nn.Dropout(dropout) def forward(self, hidden_states, input_tensor): diff --git a/allennlp/modules/transformer/positional_encoding.py b/allennlp/modules/transformer/positional_encoding.py index 1cf63b15c91..b0abc2b91b2 100644 --- a/allennlp/modules/transformer/positional_encoding.py +++ b/allennlp/modules/transformer/positional_encoding.py @@ -42,6 +42,9 @@ def __init__(self, min_timescale: float = 1.0, max_timescale: float = 1.0e4): self.max_timescale = max_timescale def forward(self, input_tensor: torch.Tensor): + """ + Adds a positional encoding to `input_tensor`. + """ # TODO: Another option is to specify the expected size in init, so that we can construct # the positional encoding beforehand, and simply add it to the input tensor in forward. 
_, timesteps, hidden_dim = input_tensor.size() diff --git a/allennlp/modules/transformer/self_attention.py b/allennlp/modules/transformer/self_attention.py index 6db6aba1fad..d464012de81 100644 --- a/allennlp/modules/transformer/self_attention.py +++ b/allennlp/modules/transformer/self_attention.py @@ -1,4 +1,5 @@ -from typing import Optional, Dict +from typing import Optional, TYPE_CHECKING + import torch from allennlp.common import FromParams @@ -6,6 +7,9 @@ from allennlp.modules.transformer.transformer_module import TransformerModule from allennlp.modules.transformer.util import apply_mask +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class SelfAttention(TransformerModule, FromParams): """ @@ -25,8 +29,15 @@ class SelfAttention(TransformerModule, FromParams): Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`. """ - _relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] - _huggingface_mapping = {"layer": "layers"} + _pretrained_relevant_module = ["encoder.layers.0.attention.self", "encoder.layers.0.attention"] + _pretrained_mapping = { + "layer": "layers", + "q_lin": "query", + "k_lin": "key", + "v_lin": "value", + "out_lin": "output", + "transformer": "encoder", + } def __init__( self, @@ -83,6 +94,8 @@ def forward( output_attentions: bool = False, ): """ + # Parameters + query_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` key_states : `torch.Tensor`, optional @@ -133,47 +146,16 @@ def forward( return outputs @classmethod - def _get_mapping( - cls, pretrained_module=None, source="huggingface", mapping: Optional[Dict[str, str]] = None - ): - combined_mapping = {} - if "huggingface" in source: - combined_mapping.update(cls._huggingface_mapping) - if mapping is not None: - combined_mapping.update(mapping) - if pretrained_module is not None: - for name, _ in pretrained_module.named_modules(): - if "q_lin" in name: - combined_mapping["q_lin"] = "query" - combined_mapping["k_lin"] = "key" - combined_mapping["v_lin"] = "value" - combined_mapping["out_lin"] = "output" - combined_mapping["transformer"] = "encoder" - break - return combined_mapping - - @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["hidden_size"] = submodules["query"].in_features - if hasattr(submodules[""], "num_attention_heads"): - final_kwargs["num_attention_heads"] = submodules[""].num_attention_heads - elif hasattr(submodules[""], "n_heads"): - final_kwargs["num_attention_heads"] = submodules[""].n_heads - final_kwargs["output_linear"] = True # Since this is the distilbert case. + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["output_linear"] = hasattr( + config, "n_heads" + ) # Since this is the distilbert case. 
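+        # DistilBERT-style configs expose `attention_dropout` rather than the BERT-style
+        # `attention_probs_dropout_prob`, so check for that attribute name first.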
+ if hasattr(config, "attention_dropout"): + final_kwargs["dropout"] = config.attention_dropout else: - raise AttributeError("Cannot find a relevant attribute for number of heads.") - - final_kwargs["dropout"] = submodules["dropout"].p - + final_kwargs["dropout"] = config.attention_probs_dropout_prob final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/t5.py b/allennlp/modules/transformer/t5.py index 83305487b76..15d34f5b2b1 100644 --- a/allennlp/modules/transformer/t5.py +++ b/allennlp/modules/transformer/t5.py @@ -1,11 +1,11 @@ """ -Adapted from [HuggingFace] +An implementation of [T5](https://api.semanticscholar.org/CorpusID:204838007), adapted from [HuggingFace] (/~https://github.com/huggingface/transformers/blob/4c32f9f26e6a84f0d9843fec8757e6ce640bb44e/src/transformers/models/t5/modeling_t5.py). """ # noqa: E401 import math from dataclasses import dataclass -from typing import Optional, Tuple, List, Union, Dict, Any +from typing import Optional, Tuple, List, Union, Dict, TYPE_CHECKING import torch from torch import nn @@ -14,13 +14,18 @@ from allennlp.common import FromParams, Params, Lazy, Registrable from allennlp.common.checks import ConfigurationError -from allennlp.modules.transformer import TransformerModule +from allennlp.modules.transformer.transformer_module import ( + TransformerModule, +) from allennlp.modules.transformer.util import ( apply_mask, get_extended_attention_mask, ) from allennlp.nn.beam_search import BeamSearch +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + # Unfortunately mypy is insane, so I have to wrap these in unions. FloatT = Union[torch.FloatTensor] IntT = Union[torch.IntTensor] @@ -94,7 +99,7 @@ def forward(self, hidden_states) -> FloatT: class T5LayerFF(TransformerModule, FromParams): - _huggingface_mapping = {"DenseReluDense": "ff_proj"} + _pretrained_mapping = {"DenseReluDense": "ff_proj"} def __init__( self, @@ -376,16 +381,19 @@ class T5LayerSelfAttentionOutput: class T5LayerSelfAttention(TransformerModule, FromParams): - _huggingface_mapping = {"SelfAttention": "self_attention"} + _pretrained_mapping = {"SelfAttention": "self_attention"} def __init__( self, self_attention: Optional[T5Attention] = None, layer_norm: Optional[T5LayerNorm] = None, dropout: float = 0.1, + has_relative_attention_bias: bool = False, ): super().__init__() - self.self_attention = self_attention or T5Attention() + self.self_attention = self_attention or T5Attention( + has_relative_attention_bias=has_relative_attention_bias + ) self.layer_norm = layer_norm or T5LayerNorm(hidden_size=self.self_attention.hidden_size) self.dropout = nn.Dropout(dropout) @@ -427,7 +435,7 @@ class T5LayerCrossAttentionOutput: class T5LayerCrossAttention(TransformerModule, FromParams): - _huggingface_mapping = {"EncDecAttention": "enc_dec_attention"} + _pretrained_mapping = {"EncDecAttention": "enc_dec_attention"} def __init__( self, @@ -618,7 +626,7 @@ class T5StackOutput: class T5Stack(TransformerModule, FromParams): - _huggingface_mapping = {"embed_tokens": "token_embeddings", "block": "blocks"} + _pretrained_mapping = {"embed_tokens": "token_embeddings", "block": "blocks"} def __init__( self, @@ -959,7 +967,18 @@ class T5Output: class T5(TransformerModule, Registrable): - _huggingface_mapping = {"shared": "token_embeddings"} + _pretrained_mapping = {"shared": "token_embeddings"} + _tied_weights = { + "token_embeddings.weight": [ + "encoder.token_embeddings.weight", + 
"decoder.token_embeddings.weight", + "lm_head.weight", + ] + } + # Don't know why HF has this param in their state_dict. It's not used in their model. + _pretrained_ignore = [ + r"^decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight$" + ] default_implementation = "default" @@ -1003,16 +1022,7 @@ def __init__( ) @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ) -> Dict[str, Any]: - from transformers.models.t5 import T5Config - - config: T5Config = pretrained_module.config + def _from_config(cls, config: "PretrainedConfig", **kwargs): attention_kwargs = { "hidden_size": config.d_model, "key_value_proj_dim": config.d_kv, @@ -1039,8 +1049,8 @@ def _get_input_arguments( } ), ) - return { - "encoder": Lazy( + return cls( + encoder=Lazy( T5EncoderStack.basic_encoder, contructor_extras={ "num_blocks": config.num_layers, @@ -1050,7 +1060,7 @@ def _get_input_arguments( "dropout": config.dropout_rate, }, ), - "decoder": Lazy( + decoder=Lazy( T5DecoderStack.basic_decoder, contructor_extras={ "num_blocks": config.num_decoder_layers, @@ -1061,12 +1071,12 @@ def _get_input_arguments( "dropout": config.dropout_rate, }, ), - "decoder_start_token_id": config.decoder_start_token_id, - "pad_token_id": config.pad_token_id, - "eos_token_id": config.eos_token_id, - "vocab_size": config.vocab_size, - "model_dim": config.d_model, - } + decoder_start_token_id=config.decoder_start_token_id, + pad_token_id=config.pad_token_id, + eos_token_id=config.eos_token_id, + vocab_size=config.vocab_size, + model_dim=config.d_model, + ) def _shift_right(self, input_ids, start_value: int): # shift inputs to the right diff --git a/allennlp/modules/transformer/transformer_embeddings.py b/allennlp/modules/transformer/transformer_embeddings.py index 754344d1c0e..3712d9b0a3a 100644 --- a/allennlp/modules/transformer/transformer_embeddings.py +++ b/allennlp/modules/transformer/transformer_embeddings.py @@ -1,11 +1,14 @@ -from typing import Optional, Dict +from typing import Optional, TYPE_CHECKING import torch from allennlp.common import FromParams - +from allennlp.modules.transformer.layer_norm import LayerNorm from allennlp.modules.transformer.transformer_module import TransformerModule +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class Embeddings(TransformerModule, FromParams): """ @@ -38,7 +41,7 @@ def __init__(self, embeddings: torch.nn.ModuleDict, embedding_size: int, dropout ) ) self.embeddings = embeddings - self.layer_norm = torch.nn.LayerNorm(embedding_size, eps=1e-12) + self.layer_norm = LayerNorm(embedding_size, eps=1e-12) self.dropout = torch.nn.Dropout(dropout) def forward(self, *inputs) -> torch.Tensor: @@ -101,13 +104,27 @@ class TransformerEmbeddings(Embeddings): Optionally apply a linear transform after the dropout, projecting to `output_size`. """ - _relevant_module = "embeddings" - _huggingface_mapping = { + _pretrained_relevant_module = ["embeddings", "bert.embeddings"] + _pretrained_mapping = { "LayerNorm": "layer_norm", "word_embeddings": "embeddings.word_embeddings", "position_embeddings": "embeddings.position_embeddings", "token_type_embeddings": "embeddings.token_type_embeddings", + # Albert is a special case. A linear projection is applied to the embeddings, + # but that linear transformation lives in the encoder. 
+ "albert.embeddings.LayerNorm": "layer_norm", + "albert.embeddings.LayerNorm": "layer_norm", + "albert.embeddings.word_embeddings": "embeddings.word_embeddings", + "albert.embeddings.position_embeddings": "embeddings.position_embeddings", + "albert.embeddings.token_type_embeddings": "embeddings.token_type_embeddings", + "albert.encoder.embedding_hidden_mapping_in": "linear_transform", } + _pretrained_ignore = [ + # Ignore these for Albert case. + r"^albert\.pooler\..*", + r"^albert\.encoder\.albert_layer_groups\..*", + r"^predictions\.*", + ] def __init__( self, @@ -149,6 +166,7 @@ def forward( # type: ignore ) -> torch.Tensor: """ + # Parameters input_ids : `torch.Tensor` Shape `batch_size x seq_len` token_type_ids : `torch.Tensor`, optional @@ -182,32 +200,18 @@ def forward( # type: ignore return embeddings @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["vocab_size"] = submodules["embeddings.word_embeddings"].num_embeddings - final_kwargs["embedding_size"] = submodules["embeddings.word_embeddings"].embedding_dim - final_kwargs["pad_token_id"] = submodules["embeddings.word_embeddings"].padding_idx - final_kwargs["max_position_embeddings"] = submodules[ - "embeddings.position_embeddings" - ].num_embeddings - - if "embeddings.token_type_embeddings" in submodules: - final_kwargs["type_vocab_size"] = submodules[ - "embeddings.token_type_embeddings" - ].num_embeddings - + final_kwargs["vocab_size"] = config.vocab_size + # For Albert, the embedding size is different than the hidden size used + # in the model, so a linear transform is applied. + if hasattr(config, "embedding_size"): + final_kwargs["embedding_size"] = config.embedding_size + final_kwargs["output_size"] = config.hidden_size else: - final_kwargs["type_vocab_size"] = 0 - + final_kwargs["embedding_size"] = config.hidden_size + final_kwargs["pad_token_id"] = config.pad_token_id + final_kwargs["max_position_embeddings"] = config.max_position_embeddings + final_kwargs["type_vocab_size"] = config.type_vocab_size final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/transformer_layer.py b/allennlp/modules/transformer/transformer_layer.py index 3282b2dbf14..43a76d33144 100644 --- a/allennlp/modules/transformer/transformer_layer.py +++ b/allennlp/modules/transformer/transformer_layer.py @@ -1,15 +1,16 @@ -from typing import Union, Optional, Dict +from typing import Union, Optional, TYPE_CHECKING import torch from allennlp.common import FromParams - from allennlp.modules.transformer.transformer_module import TransformerModule - from allennlp.modules.transformer.activation_layer import ActivationLayer from allennlp.modules.transformer.self_attention import SelfAttention from allennlp.modules.transformer.output_layer import OutputLayer +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + class AttentionLayer(TransformerModule, FromParams): """ @@ -28,8 +29,8 @@ class AttentionLayer(TransformerModule, FromParams): Dropout probability for the `OutputLayer`. 
""" - _relevant_module = "encoder.layers.0.attention" - _huggingface_mapping = {"layer": "layers"} + _pretrained_relevant_module = "encoder.layer.0.attention" + _pretrained_mapping = {"layer": "layers"} def __init__( self, @@ -52,6 +53,8 @@ def forward( output_attentions: bool = False, ): """ + # Parameters + input_tensor : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` attention_mask : `torch.BoolTensor`, optional @@ -77,25 +80,16 @@ def forward( return outputs @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - final_kwargs["hidden_size"] = submodules["self.query"].in_features - final_kwargs["num_attention_heads"] = submodules["self"].num_attention_heads - final_kwargs["attention_dropout"] = submodules["self.dropout"].p - final_kwargs["hidden_dropout"] = submodules["output.dropout"].p + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout"] = config.hidden_dropout_prob final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) class TransformerLayer(TransformerModule, FromParams): @@ -120,8 +114,8 @@ class TransformerLayer(TransformerModule, FromParams): This is helpful when using the layer in a decoder. """ - _relevant_module = "encoder.layers.0" - _huggingface_mapping = { + _pretrained_relevant_module = "encoder.layer.0" + _pretrained_mapping = { "layer": "layers", "intermediate_act_fn": "act_fn", "crossattention": "cross_attention", @@ -174,6 +168,8 @@ def forward( output_attentions: bool = False, ): """ + # Parameters + hidden_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` attention_mask : `torch.BoolTensor`, optional @@ -218,32 +214,14 @@ def forward( return outputs @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["hidden_size"] = submodules["attention.self.query"].in_features - final_kwargs["num_attention_heads"] = submodules["attention.self"].num_attention_heads - final_kwargs["attention_dropout"] = submodules["attention.self.dropout"].p - final_kwargs["hidden_dropout"] = submodules["attention.output.dropout"].p - final_kwargs["intermediate_size"] = submodules["intermediate.dense"].out_features - - # We require the if block as `act_fn` is a function rather than a module, - # so `_get_mapped_submodules` does not automatically fix this. 
- if source == "huggingface": - final_kwargs["activation"] = getattr(submodules["intermediate"], "intermediate_act_fn") - else: - final_kwargs["activation"] = getattr(submodules["intermediate"], "act_fn") - - final_kwargs["add_cross_attention"] = "cross_attention" in submodules - + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout"] = config.hidden_dropout_prob + final_kwargs["intermediate_size"] = config.intermediate_size + final_kwargs["activation"] = config.hidden_act + final_kwargs["add_cross_attention"] = config.add_cross_attention final_kwargs.update(**kwargs) - - return final_kwargs + return cls(**final_kwargs) diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index 861120deca2..2a0ffa092ce 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -1,229 +1,382 @@ -from typing import Optional, Dict, Union, List, Any import logging -import inspect +import os +from os import PathLike +from typing import TYPE_CHECKING, Optional, Dict, Union, List, Any, TypeVar, Type +import re +import warnings import torch +import torch.distributed as dist + +from allennlp.common.util import is_distributed, is_global_primary +from allennlp.nn.util import StateDictType, read_state_dict, load_state_dict_distributed + +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig -from allennlp.common import cached_transformers logger = logging.getLogger(__name__) +_T = TypeVar("_T", bound="TransformerModule") + + class TransformerModule(torch.nn.Module): """ Base class to help with generalized loading of pretrained weights. - `_huggingface_mapping` is an optional mapping for each class, that determines - any differences in the module names between the class modules and the huggingface model's - modules. + Subclasses should override `_from_config()` if you want to instantiate them with + `from_pretrained_module()`. + """ + + _pretrained_mapping: Dict[str, str] = {} + """ + An optional mapping for each class that determines any differences in the module + names between the class modules and the HuggingFace model's modules. + Keys correspond to HuggingFace submodule names, values correspond to submodules names of this module. + """ - `_relevant_module` is an optional str or list of str which contains the expected name of the module - in the huggingface pretrained model. It can be a list to account for different names in different + _pretrained_relevant_module: Optional[Union[str, List[str]]] = None + """ + An optional string or list of strings which contains the expected name of the module + in the HuggingFace pretrained model. It can be a list to account for different names in different models. The search is carried out in the order of the list. """ - _huggingface_mapping: Dict[str, str] = {} - _relevant_module: Optional[Union[str, List[str]]] = None + _pretrained_ignore: Optional[List[str]] = None + """ + An optional list of regular expressions that define which weights to ignore from a pretrained state_dict. + """ + + _pretrained_allow_missing: Optional[List[str]] = None + """ + An optional list of regular expressions that specifies which weights are allowed to be missing + from a pretrained state dictionary. 
+ """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + _tied_weights: Optional[Dict[str, List[str]]] = None + """ + A mapping that defines any weights that need to be tied. Keys and values are parameter names. + The values will be tied to the corresponding key. + """ @classmethod def _get_mapping( cls, - pretrained_module: Optional[torch.nn.Module] = None, - source: str = "huggingface", mapping: Optional[Dict[str, str]] = None, ): """ - Returns the mapping to be used, based on the optional `pretrained_module`. - If `pretrained_module` is not given, the default module-level mapping is returned. + Returns the mapping to be used, based on the optional `mapping` overrides + and the default module-level mapping. """ combined_mapping = {} - if "huggingface" == source: - combined_mapping.update(cls._huggingface_mapping) + combined_mapping.update(cls._pretrained_mapping) if mapping is not None: combined_mapping.update(mapping) return combined_mapping - @classmethod - def _get_mapped_submodules( - cls, - pretrained_module: torch.nn.Module, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - ): - """ - Subclasses overload this method, and provide appropriate name mapping based on the source. - """ - submodules = dict(pretrained_module.named_modules()) - combined_mapping = cls._get_mapping(pretrained_module, source, mapping) - for name, module in pretrained_module.named_modules(): - newname = name - for key, val in combined_mapping.items(): - newname = newname.replace(key, val) - submodules[newname] = submodules.pop(name) - return submodules - - def _construct_default_mapping( + def _get_mapped_state_dict( self, - pretrained_module: torch.nn.Module, - source: str = "huggingface", + state_dict: StateDictType, mapping: Optional[Dict[str, str]] = None, - ): + ) -> StateDictType: """ - Recursively constructs the default mapping of parameter names for loading pretrained module weights. - Keys are parameter names from this module, and values are corresponding parameter names in the - expected pretrained module, as per `source`. + Recursively map keys in a HuggingFace `state_dict` to the corresponding keys + for this module and all submodules. """ - combined_mapping = self._get_mapping(pretrained_module, source, mapping) - for name, module in self.named_modules(): - if name != "": - if hasattr(module, "_construct_default_mapping"): - # We handle collisions by giving priority to the outer module's mapping. - combined_mapping = dict( - list( - module._construct_default_mapping( - pretrained_module, source, combined_mapping - ).items() - ) - + list(combined_mapping.items()) - ) - return combined_mapping + return _get_mapped_state_dict(self, state_dict, mapping=mapping) - def _load_from_pretrained_module( - self, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - ignore_absent_parameters: Optional[List] = None, - ): + @classmethod + def _get_relevant_submodule_state( + cls, + state_dict: StateDictType, + relevant_module: Optional[Union[str, List[str]]] = None, + ) -> StateDictType: """ - Loads the weights of the `pretrained_module` into the instance. - Optionally, a `mapping` is specified for any differences in parameter names - between `pretrained_module` and the instance. + Returns the relevant part of the `state_dict`. 
""" - ignore_absent_parameters = ignore_absent_parameters or [] - combined_mapping = self._construct_default_mapping(pretrained_module, source, mapping) - if mapping is not None: - combined_mapping.update(mapping) + relevant_modules: Optional[List[str]] = None + if relevant_module: + relevant_modules = ( + [relevant_module] if isinstance(relevant_module, str) else relevant_module + ) + elif isinstance(cls._pretrained_relevant_module, str): + relevant_modules = [cls._pretrained_relevant_module] + elif isinstance(cls._pretrained_relevant_module, list): + relevant_modules = cls._pretrained_relevant_module - inverse_mapping = {val: key for key, val in combined_mapping.items()} - pretrained_parameters = dict(pretrained_module.named_parameters()) - for name, parameter in self.named_parameters(): - pretrained_name = name - for key, val in inverse_mapping.items(): - # so that we replace the names of submodules too. - # eg. module.key.anothermodule --> module.val.anothermodule - pretrained_name = pretrained_name.replace(key, val) - if not any( - [pretrained_name.startswith(paraname) for paraname in ignore_absent_parameters] - ): - if pretrained_name not in pretrained_parameters: - raise ValueError( - f"Couldn't find a matching parameter for {name}. Is this module " - "compatible with the pretrained module you're using?" - ) - parameter.data.copy_(pretrained_parameters[pretrained_name].data) + if relevant_modules: + found = False + for module_name in relevant_modules: + relevant_keys = set( + [key for key in state_dict.keys() if key.startswith(module_name + ".")] + ) + if relevant_keys: + # Only keep elements of state dict that correspond to the relevant module. + state_dict = { + key.replace(module_name + ".", "", 1): value + for key, value in state_dict.items() + if key in relevant_keys + } + found = True + break + + if not found: + warnings.warn( + f"{relevant_modules} was not found at top level of state_dict!", UserWarning + ) + + return state_dict @classmethod - def _get_input_arguments( + def _get_pretrained_state_dict( cls, - pretrained_module: torch.nn.Module, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ) -> Dict[str, Any]: + model_name: str, + weights_path: Optional[Union[str, PathLike]] = None, + relevant_module: Optional[Union[str, List[str]]] = None, + ignore: Optional[List[str]] = None, + ) -> StateDictType: """ - Constructs the arguments required for instantiating an object of this class, using - the values from `pretrained_module`. + Get a HuggingFace pretrained `state_dict` corresponding to this module. """ - return kwargs + if weights_path is None: + from transformers.file_utils import WEIGHTS_NAME + + # First see if we can find the weights locally. + if os.path.isdir(model_name): + local_weights_path = os.path.join(model_name, WEIGHTS_NAME) + if os.path.isfile(local_weights_path): + logger.info("Found weights at local path %s", local_weights_path) + weights_path = local_weights_path + + # If we haven't found locally, we assume model ID corresponds to a model + # on the HuggingFace Hub. + if weights_path is None: + from allennlp.common.file_utils import cached_path + + weights_path = cached_path(f"hf://{model_name}/{WEIGHTS_NAME}") + + # Now load the state dict. + logger.info("Reading state dict from %s", weights_path) + state_dict = read_state_dict( + weights_path, + ignore=ignore if ignore is not None else cls._pretrained_ignore, + strict=False, + ) + + # Keep just the relevant_module, remove everything else. 
+ state_dict = cls._get_relevant_submodule_state(state_dict, relevant_module=relevant_module) + + return state_dict @classmethod - def get_relevant_module( - cls, - pretrained_module: Union[str, torch.nn.Module], - relevant_module: Optional[Union[str, List[str]]] = None, - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, + def _from_config(cls: Type[_T], config: "PretrainedConfig", **kwargs) -> _T: + """ + Instantiate this module from a HuggingFace config. Subclasses should override + this method if you want to be able to instantiate them with `from_pretrained_module()`. + """ + raise NotImplementedError + + def tie_weights(self) -> None: + """ + Tie weights according to the `_tied_weights` class attribute. + + This should always be called after loading a state dictionary. It will be called + automatically within `from_pretrained_module()`. + """ + if self._tied_weights: + param_dict = dict(self.named_parameters()) + param_dict.update(dict(self.named_buffers())) + for anchor_name, free_names in self._tied_weights.items(): + for free_name in free_names: + param_dict[free_name] = param_dict[anchor_name] + + @classmethod + def from_pretrained_module( + cls: Type[_T], + model_name: str, + *, load_weights: bool = True, - ): + weights_path: Optional[Union[str, PathLike]] = None, + auto_config_kwargs: Optional[Dict[str, Any]] = None, + mapping: Optional[Dict[str, str]] = None, + relevant_module: Optional[Union[str, List[str]]] = None, + ignore: Optional[List[str]] = None, + allow_missing: Optional[List[str]] = None, + strict: bool = True, + **kwargs, + ) -> _T: """ - Returns the relevant underlying module given a model name/object. + Initialize this module from a corresponding model on HuggingFace. + + !!! Note + This method is only available for subclasses that implement `_from_config()`. + Otherwise a `NotImplementedError` will be raised. # Parameters - pretrained_module : `Union[str, torch.nn.Module]` - Name of the transformer model containing the layer, - or the actual layer (not the model object). - relevant_module : `Optional[Union[str, List[str]]]`, optional - Name of the desired module. Defaults to cls._relevant_module. - source : `str`, optional - Where the model came from. Default - huggingface. - mapping : `Dict[str, str]`, optional - Optional mapping that determines any differences in the module names - between the class modules and the input model's modules. - Default - cls._huggingface_mapping - load_weights : `bool`, optional - Whether or not to load the pretrained weights. - Default is `True`. - """ - if isinstance(pretrained_module, str): - pretrained_module = cached_transformers.get( - pretrained_module, False, load_weights=load_weights - ) + model_name : `str` + The model identifier or path. - relevant_module = relevant_module or cls._relevant_module + load_weights : `bool`, optional (default = `True`) + Whether to download and load the pretrained weights. If `False`, the + weights are left uninitialized. - if relevant_module is not None: - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - # If the relevant_module is not found, we assume that the pretrained_module - # is already the relevant module. 
- if isinstance(relevant_module, str): - relevant_module = [relevant_module] - found = False - for module in relevant_module: - if module in submodules: - pretrained_module = submodules[module] - found = True - break + weights_path : `Optional[Union[str, PathLike]]`, optional (default = `None`) + When `load_weights` is `True`, this can be set to override the weights file. + Otherwise the default weights from the pretrained model are used. - if not found: - logger.warning( - "{} was not found! The submodules are: {}".format( - relevant_module, submodules.keys() + auto_config_kwargs : `Optional[Dict[str, Any]]`, optional (default = `None`) + Optional key-word arguments to pass to `transformers.AutoConfig.from_pretrained()` + to load the pretrained model's configuration file. + + mapping : `Optional[Dict[str, str]]`, optional (default = `None`) + Optional mapping that determines any differences in the submodule names + between this module and the pretrained model from HuggingFace. + If not given, the class's default is used: `cls._pretrained_mapping`. + + relevant_module : `Optional[str]`, optional (default = `None`) + An optional submodule of the HuggingFace module to initialize weights from. + This is only relevant when `load_weights` is `True`. + If not given, the class's default is used: `cls._pretrained_relevant_module`. + + ignore : `Optional[List[str]]`, optional (default = `None`) + An optional list of regular expressions that define which weights to ignore + from a pretrained state_dict. + This is only relevant when `load_weights` is `True`. + If not specified, the class's default is used: `cls._pretrained_ignore`. + + allow_missing: `Optional[List[str]]`, optional (default = `None`) + An optional list of regular expressions that specifies which weights are allowed to be missing + from the pretrained state dictionary. + This is only relevant when `load_weights` is `True`. + If not specified, the class's default is used: `cls._pretrained_allow_missing`. + + strict : `bool`, optional (default = `True`) + Whether to load the `state_dict` in "strict" model. This only applies + when `load_weights` is `True`. + + **kwargs : `Any` + Key word arguments to pass to `cls.from_config()` when instantiating the module. + """ # noqa: E501 + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(model_name, **(auto_config_kwargs or {})) + model = cls._from_config(config, **kwargs) + + if load_weights: + state_dict: Optional[StateDictType] = None + if is_global_primary(): + # Load the pretrained HuggingFace state_dict. + pretrained_state_dict = cls._get_pretrained_state_dict( + model_name, + weights_path=weights_path, + relevant_module=relevant_module, + ignore=ignore, + ) + # Now map keys from the HuggingFace state_dict to the corresponding keys from + # this class. This is called recursively on each submodule of the current module. + state_dict = model._get_mapped_state_dict(pretrained_state_dict, mapping=mapping) + + missing_keys: List[str] + unexpected_keys: List[str] + error_msgs: List[str] = [] + if not is_distributed(): + assert state_dict is not None + logger.info("Loading state_dict into module") + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + else: + # We're in distributed training. `state_dict` is `None` for all process groups + # except the global primary. + # Syncronize here since non-primary process groups will have to wait for the primary + # to load the state_dict into memory. 
+ dist.barrier() + # Now load the state dict into the model. + logger.info("Loading state_dict into module (MEMORY_EFFICIENT strategy)") + missing_keys, unexpected_keys = load_state_dict_distributed( + model, state_dict, strict=False + ) + + # Exclude any keys in `missing_keys` that match with the `allow_missing` + # regular expressions. + if allow_missing is None: + allow_missing = cls._pretrained_allow_missing + if allow_missing: + missing_keys = [ + k for k in missing_keys if not any(re.match(p, k) for p in allow_missing) + ] + + # Allow missing keys in state_dict for params that are going to be tied. + for param_names in (model._tied_weights or {}).values(): + for param_name in param_names: + if param_name in missing_keys: + missing_keys.remove(param_name) + + if missing_keys: + error_msgs.append( + "Missing key(s) in state_dict: {}".format( + ", ".join(f'"{k}"' for k in missing_keys) + ) + ) + if unexpected_keys: + error_msgs.append( + "Unexpected key(s) in state_dict: {}".format( + ", ".join(f'"{k}"' for k in unexpected_keys) ) ) - return pretrained_module - @classmethod - def from_pretrained_module( - cls, - pretrained_module: Union[str, torch.nn.Module], - source: str = "huggingface", - mapping: Optional[Dict[str, str]] = None, - load_weights: bool = True, - **kwargs, - ): - """ - Creates and returns an instance of the class, by using the weights - (and the architecture, by default) of the `pretrained_module`. - Optionally, the architecture can be changed by providing arguments. - """ - accepted_args = inspect.getfullargspec(cls).args - accepted_args.remove("self") - for key in kwargs: - assert key in accepted_args, ( - "{} is not a valid argument for creating an instance of `{}`. " - "Accepted arguments are {}.".format(key, cls.__name__, accepted_args) - ) + if error_msgs and strict: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + cls.__name__, "\n\t".join(error_msgs) + ) + ) + + # If there were error messages but we're not loading in 'strict' mode, + # we just issue warnings from the logger. + for msg in error_msgs: + logger.warning(msg) - pretrained_module = cls.get_relevant_module( - pretrained_module, source=source, mapping=mapping, load_weights=load_weights + model.tie_weights() + + return model + + +def _get_mapped_state_dict( + module: torch.nn.Module, + state_dict: StateDictType, + mapping: Optional[Dict[str, str]] = None, +) -> StateDictType: + # First fix all top-level keys according to `combined_mapping`. + combined_mapping = module._get_mapping(mapping) if isinstance(module, TransformerModule) else {} + for hf_key, cls_key in sorted( + # Sort by most specific key first. + combined_mapping.items(), + key=lambda x: x[0].count("."), + reverse=True, + ): + relevant_keys = set( + [key for key in state_dict.keys() if (key == hf_key or key.startswith(hf_key + "."))] ) - final_kwargs = cls._get_input_arguments(pretrained_module, source, mapping) - final_kwargs.update(kwargs) - module = cls(**final_kwargs) - module._load_from_pretrained_module(pretrained_module, source, mapping) - return module + for key in relevant_keys: + new_key = key.replace(hf_key, cls_key, 1) + # We have to be careful not to overwrite an entry that we might have updated + # on a previous iteration of this loop due to having a more specific key. + if new_key not in state_dict: + state_dict[new_key] = state_dict.pop(key) + + # Now loop through the submodules, calling this function on each submodule. 
+ for name, submodule in module.named_children(): + # Pull-out the part of the state_dict corresponding to just this submodule. + relevant_keys = set([key for key in state_dict.keys() if key.startswith(name + ".")]) + module_state_dict = { + key.replace(name + ".", "", 1): state_dict.pop(key) for key in relevant_keys + } + # Recursively call this function from the submodule to map this part + # of the state_dict. + module_state_dict = _get_mapped_state_dict(submodule, module_state_dict) + # And then update the full state_dict. + for key, value in module_state_dict.items(): + state_dict[name + "." + key] = value + + return state_dict diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index 09fb1d2bc40..7bc4a7247d3 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -1,14 +1,17 @@ -from typing import Union, Optional, Dict +from typing import Union, Optional, TYPE_CHECKING import logging import torch from allennlp.common import FromParams - from allennlp.modules.util import replicate_layers from allennlp.modules.transformer.transformer_layer import TransformerLayer from allennlp.modules.transformer.transformer_module import TransformerModule +if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + + logger = logging.getLogger(__name__) @@ -38,8 +41,8 @@ class TransformerStack(TransformerModule, FromParams): This is helpful when using the `TransformerStack` as a decoder. """ - _huggingface_mapping = {"layer": "layers"} - _relevant_module = "encoder" + _pretrained_mapping = {"layer": "layers"} + _pretrained_relevant_module = ["encoder", "bert.encoder"] def __init__( self, @@ -86,6 +89,8 @@ def forward( output_hidden_states: bool = False, ): """ + # Parameters + hidden_states : `torch.Tensor` Shape `batch_size x seq_len x hidden_dim` attention_mask : `torch.BoolTensor`, optional @@ -129,67 +134,15 @@ def forward( ) @classmethod - def _get_input_arguments( - cls, - pretrained_module: torch.nn.Module, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - **kwargs, - ): - submodules = cls._get_mapped_submodules(pretrained_module, source, mapping) - + def _from_config(cls, config: "PretrainedConfig", **kwargs): final_kwargs = {} - - final_kwargs["num_hidden_layers"] = len(submodules["layers"]) - - final_kwargs["hidden_size"] = submodules["layers.0.attention.self.query"].in_features - final_kwargs["num_attention_heads"] = submodules[ - "layers.0.attention.self" - ].num_attention_heads - final_kwargs["attention_dropout"] = submodules["layers.0.attention.self.dropout"].p - final_kwargs["hidden_dropout"] = submodules["layers.0.attention.output.dropout"].p - final_kwargs["intermediate_size"] = submodules["layers.0.intermediate.dense"].out_features - - # We require the if block as `act_fn` is a function rather than a module, - # so `_get_mapped_submodules` does not automatically fix this. 
- if source == "huggingface": - final_kwargs["activation"] = getattr( - submodules["layers.0.intermediate"], "intermediate_act_fn" - ) - else: - final_kwargs["activation"] = getattr(submodules["layers.0.intermediate"], "act_fn") - - final_kwargs["add_cross_attention"] = "layers.0.cross_attention" in submodules - + final_kwargs["num_hidden_layers"] = config.num_hidden_layers + final_kwargs["hidden_size"] = config.hidden_size + final_kwargs["num_attention_heads"] = config.num_attention_heads + final_kwargs["add_cross_attention"] = config.add_cross_attention + final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob + final_kwargs["hidden_dropout"] = config.hidden_dropout_prob + final_kwargs["intermediate_size"] = config.intermediate_size + final_kwargs["activation"] = config.hidden_act final_kwargs.update(**kwargs) - - return final_kwargs - - @classmethod - def from_pretrained_module( # type: ignore - cls, - pretrained_module: Union[str, torch.nn.Module], - num_hidden_layers: Optional[Union[int, range]] = None, - source="huggingface", - mapping: Optional[Dict[str, str]] = None, - load_weights: bool = True, - **kwargs, - ): - final_kwargs = {} - if num_hidden_layers is not None: - if isinstance(num_hidden_layers, range): - if mapping is None: - mapping = {} - for num_layer, mapped in enumerate(num_hidden_layers): - mapping[str(mapped)] = str(num_layer) - final_kwargs["num_hidden_layers"] = len(num_hidden_layers) - else: - final_kwargs["num_hidden_layers"] = num_hidden_layers - - return super().from_pretrained_module( - pretrained_module, - source=source, - mapping=mapping, - load_weights=load_weights, - **final_kwargs, - ) + return cls(**final_kwargs) diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py index 67a623f98e3..d25239b27f3 100644 --- a/allennlp/nn/util.py +++ b/allennlp/nn/util.py @@ -4,11 +4,12 @@ import copy from collections import defaultdict, OrderedDict +from itertools import chain import json import logging from os import PathLike import re -from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, NamedTuple import math import numpy @@ -16,11 +17,18 @@ import torch.distributed as dist from allennlp.common.checks import ConfigurationError -from allennlp.common.util import int_to_device, is_distributed +from allennlp.common.util import int_to_device, is_distributed, is_global_primary logger = logging.getLogger(__name__) T = TypeVar("T") +StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] + +_MODULE_SHARDED_FLAG = "_is_sharded_allennlp" +""" +This flag is used to indicate when a module's parameters have been sharded across +distributed workers. +""" def move_to_device(obj, device: Union[torch.device, int]): @@ -926,7 +934,7 @@ def inner_device_mapping(storage: torch.Storage, location) -> torch.Storage: return inner_device_mapping -def load_state_dict( +def read_state_dict( path: Union[PathLike, str], strip_prefix: Optional[str] = None, ignore: Optional[List[str]] = None, @@ -934,7 +942,7 @@ def load_state_dict( cuda_device: int = -1, ) -> Dict[str, torch.Tensor]: """ - Load a PyTorch model state dictionary from a checkpoint at the given `path`. + Read a PyTorch model state dictionary from a checkpoint at the given `path`. 
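+    This only reads the state dictionary from the checkpoint file and returns it; it does not
+    load it into any module.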
# Parameters @@ -2110,6 +2118,17 @@ def tiny_value_of_dtype(dtype: torch.dtype): _V = TypeVar("_V", int, float, torch.Tensor) +def distributed_device() -> torch.device: + """ + Get the correct `torch.device` of the current process to use for distributed point-to-point communication. + """ + if not is_distributed(): + raise RuntimeError( + "'distributed_device()' can only be called within a distributed process group" + ) + return int_to_device(-1 if dist.get_backend() != "nccl" else torch.cuda.current_device()) + + def dist_reduce(value: _V, reduce_op, **kwargs) -> _V: """ Reduces the given `value` across all distributed worker nodes according the given @@ -2134,7 +2153,7 @@ def dist_reduce(value: _V, reduce_op, **kwargs) -> _V: """ if not is_distributed(): return value - device = int_to_device(-1 if dist.get_backend() != "nccl" else torch.cuda.current_device()) + device = distributed_device() value_tensor = torch.tensor(value, device=device, **kwargs) dist.all_reduce(value_tensor, op=reduce_op) @@ -2157,3 +2176,191 @@ def dist_reduce_sum(value: _V, **kwargs) -> _V: if not is_distributed(): return value return dist_reduce(value, dist.ReduceOp.SUM, **kwargs) + + +def _collect_state_dict( + module: torch.nn.Module, state_dict: Optional[StateDictType], recurse: bool = True +) -> Tuple[StateDictType, List[str], List[str]]: + """ + Collect a module's state dict across distributed processes. + + Returns the syncronized state dictionary, which will always be a valid state dict, + and then the missing and unexpected keys corresponding to the original `state_dict`. + Parameters that missing from the original `state_dict` will be populated from the + corresponding parameter in the primary processes' module's state dict. + + !!! Note + + `missing_keys` and `unexpected_keys` are only populated in the primary process. + """ + # This is the device we'll use for the broadcast operation. + dist_device = distributed_device() + # This is the device we'll put all tensors on in the returned state dict. + state_dict_device = ( + int_to_device(-1) if not state_dict else state_dict[list(state_dict.keys())[0]].device + ) + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + + # Gather current state dict and prepare to iterator over it. + # We iterate over this state dict instead of `state_dict` so we can be sure + # that the order is consistent across processes. + # We'll also update this state dict as we go and return it at the end. + if recurse: + current_state_dict = module.state_dict() + else: + # Only collect state of direct members, including both parameters and buffers. + current_state_dict = OrderedDict( + chain( + # Paramaters + ((n, p.data) for (n, p) in module.named_parameters(recurse=False)), + # Buffers + module.named_buffers(recurse=False), + ) + ) + + keys = list(current_state_dict.keys()) + + # Gather unexpected_keys. + if is_global_primary(): + assert state_dict is not None + module_keys = set(module.state_dict().keys()) + for key in state_dict: + if key not in module_keys: + unexpected_keys.append(key) + + for key in keys: + tensor = current_state_dict[key] + if is_global_primary(): + assert state_dict is not None + if key in state_dict: + # Update `tensor` to the value in `state_dict`. 
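+                # This way the checkpoint value, rather than the module's current weight,
+                # is what gets broadcast to the other workers below.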
+ tensor = state_dict[key] + else: + missing_keys.append(key) + tensor = tensor.to(dist_device) + dist.broadcast(tensor, 0) + current_state_dict[key] = tensor.to(state_dict_device) + + return current_state_dict, missing_keys, unexpected_keys + + +class _LoadStateDictResult(NamedTuple): + missing_keys: List[str] + unexpected_keys: List[str] + + +def load_state_dict_distributed( + module: torch.nn.Module, state_dict: Optional[StateDictType], strict: bool = True +) -> _LoadStateDictResult: + """ + Load a `state_dict` to the `module` within a distributed process. Only the global + primary process requires the `state_dict` to not be `None`. All other processes + will have the state tensors broadcasted to them one-by-one. + + If `strict` is `True`, then the keys of `state_dict` must exactly match the keys + returned by `module.state_dict()`. + + !!! Note + The returned `missing_keys` and `unexpected_keys` will only be accurate + in the primary process. + + # Returns + + `_LoadStateDictResult` + A `NamedTuple` with `missing_keys` and `unexpected_keys` fields, both of which + are lists of strings. + + # Raises + + `RuntimeError` + If `strict` is `True` and there are missing or unexpected keys. + + """ + if not is_distributed(): + return module.load_state_dict(state_dict, strict=strict) + + if is_global_primary(): + assert state_dict is not None + else: + assert state_dict is None + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + + submodules = dict(module.named_children()) + + def update_key_list(original, updates): + for key in updates: + if key not in original: + original.append(key) + + # If we've found a sharded module or there aren't any more submodules of the current module, + # we collect the state_dict and load it now instead of recursing further. + if getattr(module, _MODULE_SHARDED_FLAG, False) or not submodules: + # Collect. + state_dict, _missing_keys, _unexpected_keys = _collect_state_dict(module, state_dict) + assert state_dict is not None + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + # And load. + _missing_keys, _unexpected_keys = module.load_state_dict(state_dict, strict=False) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + else: + # We'll recursively call this function on each submodule, but first we need + # to collect any parameters that are direct members of this module. + direct_member_state_dict, _missing_keys, _unexpected_keys = _collect_state_dict( + module, state_dict, recurse=False + ) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + + # `_missing_keys` here will contain any keys corresponding to submodules, but + # we'll remove those below. + _missing_keys, _unexpected_keys = module.load_state_dict( + direct_member_state_dict, strict=False + ) + update_key_list(missing_keys, _missing_keys) + update_key_list(unexpected_keys, _unexpected_keys) + + # Okay, now for the recursive part. + for name, submodule in submodules.items(): + # Update `missing_keys` to remove keys corresponding to this submodule. + # If they are actually missing after this step, we add them back in below. 
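+            # (The recursive call reports keys relative to the submodule, so they are
+            # re-prefixed with `name` before being added back.)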
+ missing_keys = [k for k in missing_keys if not k.startswith(name + ".")] + submodule_state_dict: Optional[StateDictType] = None + if is_global_primary(): + assert state_dict is not None + submodule_state_dict = { + key.replace(name + ".", "", 1): value + for key, value in state_dict.items() + if key.startswith(name + ".") + } + _missing_keys, _unexpected_keys = load_state_dict_distributed( + submodule, submodule_state_dict, strict=False + ) + update_key_list(missing_keys, [f"{name}.{key}" for key in _missing_keys]) + update_key_list(unexpected_keys, [f"{name}.{key}" for key in _unexpected_keys]) + + if strict: + error_msgs: List[str] = [] + if missing_keys: + error_msgs.append( + "Missing key(s) in state_dict: {}".format(", ".join(f'"{k}"' for k in missing_keys)) + ) + if unexpected_keys: + error_msgs.append( + "Unexpected key(s) in state_dict: {}".format( + ", ".join(f'"{k}"' for k in unexpected_keys) + ) + ) + if error_msgs: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + module.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + + return _LoadStateDictResult(missing_keys, unexpected_keys) diff --git a/scripts/py2md.py b/scripts/py2md.py index 82a31565485..c8bc1ca1d43 100755 --- a/scripts/py2md.py +++ b/scripts/py2md.py @@ -279,6 +279,13 @@ class AllenNlpFilterProcessor(Struct): "__call__", "__iter__", "InfluenceInterpreter._calculate_influence_scores", + "TransformerModule._from_config", + "TransformerModule._pretrained_mapping", + "TransformerModule._pretrained_relevant_module", + "TransformerModule._pretrained_ignore", + "TransformerModule._pretrained_allow_missing", + "TransformerModule._distributed_loading_strategy", + "TransformerModule._tied_weights", } def process(self, graph, _resolver): diff --git a/tests/modules/transformer/activation_layer_test.py b/tests/modules/transformer/activation_layer_test.py index 8c1b7ebef26..2af0338a92e 100644 --- a/tests/modules/transformer/activation_layer_test.py +++ b/tests/modules/transformer/activation_layer_test.py @@ -1,32 +1,34 @@ -import copy import torch +import pytest from allennlp.common import Params from allennlp.modules.transformer import ActivationLayer -from allennlp.common.testing import AllenNlpTestCase -class TestActivationLayer(AllenNlpTestCase): - def setup_method(self): - super().setup_method() +@pytest.fixture +def params_dict(): + return { + "hidden_size": 5, + "intermediate_size": 3, + "activation": "relu", + } - self.params_dict = { - "hidden_size": 5, - "intermediate_size": 3, - "activation": "relu", - } - params = Params(copy.deepcopy(self.params_dict)) +@pytest.fixture +def params(params_dict): + return Params(params_dict) - self.activation_layer = ActivationLayer.from_params(params) - def test_can_construct_from_params(self): +@pytest.fixture +def activation_layer(params): + return ActivationLayer.from_params(params.duplicate()) - activation_layer = self.activation_layer - assert activation_layer.dense.in_features == self.params_dict["hidden_size"] - assert activation_layer.dense.out_features == self.params_dict["intermediate_size"] +def test_can_construct_from_params(activation_layer, params_dict): + activation_layer = activation_layer + assert activation_layer.dense.in_features == params_dict["hidden_size"] + assert activation_layer.dense.out_features == params_dict["intermediate_size"] - def test_forward_runs(self): - self.activation_layer.forward(torch.randn(7, 5)) +def test_forward_runs(activation_layer): + activation_layer.forward(torch.randn(7, 5)) diff --git 
a/tests/modules/transformer/bimodal_attention_test.py b/tests/modules/transformer/bimodal_attention_test.py index 40dc81f12de..270aefd23e7 100644 --- a/tests/modules/transformer/bimodal_attention_test.py +++ b/tests/modules/transformer/bimodal_attention_test.py @@ -1,55 +1,56 @@ -import copy import torch +import pytest from allennlp.common import Params from allennlp.modules.transformer import BiModalAttention -from allennlp.common.testing import AllenNlpTestCase - - -class TestBiModalAttention(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = { - "hidden_size1": 6, - "hidden_size2": 4, - "combined_hidden_size": 16, - "num_attention_heads": 2, - "dropout1": 0.1, - "dropout2": 0.2, - } - - params = Params(copy.deepcopy(self.params_dict)) - - self.biattention = BiModalAttention.from_params(params) - - def test_can_construct_from_params(self): - - biattention = self.biattention - - assert biattention.num_attention_heads == self.params_dict["num_attention_heads"] - assert biattention.attention_head_size == int( - self.params_dict["combined_hidden_size"] / self.params_dict["num_attention_heads"] - ) - assert ( - biattention.all_head_size - == self.params_dict["num_attention_heads"] * biattention.attention_head_size - ) - assert biattention.query1.in_features == self.params_dict["hidden_size1"] - assert biattention.key1.in_features == self.params_dict["hidden_size1"] - assert biattention.value1.in_features == self.params_dict["hidden_size1"] - assert biattention.dropout1.p == self.params_dict["dropout1"] - - assert biattention.query2.in_features == self.params_dict["hidden_size2"] - assert biattention.key2.in_features == self.params_dict["hidden_size2"] - assert biattention.value2.in_features == self.params_dict["hidden_size2"] - assert biattention.dropout2.p == self.params_dict["dropout2"] - - def test_forward_runs(self): - - self.biattention.forward( - torch.randn(2, 3, 6), - torch.randn(2, 3, 4), - torch.randint(0, 2, (2, 2, 3, 3)) == 1, # creating boolean tensors - torch.randint(0, 2, (2, 2, 3, 3)) == 1, - ) + + +@pytest.fixture +def params_dict(): + return { + "hidden_size1": 6, + "hidden_size2": 4, + "combined_hidden_size": 16, + "num_attention_heads": 2, + "dropout1": 0.1, + "dropout2": 0.2, + } + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def biattention(params): + return BiModalAttention.from_params(params.duplicate()) + + +def test_can_construct_from_params(biattention, params_dict): + assert biattention.num_attention_heads == params_dict["num_attention_heads"] + assert biattention.attention_head_size == int( + params_dict["combined_hidden_size"] / params_dict["num_attention_heads"] + ) + assert ( + biattention.all_head_size + == params_dict["num_attention_heads"] * biattention.attention_head_size + ) + assert biattention.query1.in_features == params_dict["hidden_size1"] + assert biattention.key1.in_features == params_dict["hidden_size1"] + assert biattention.value1.in_features == params_dict["hidden_size1"] + assert biattention.dropout1.p == params_dict["dropout1"] + + assert biattention.query2.in_features == params_dict["hidden_size2"] + assert biattention.key2.in_features == params_dict["hidden_size2"] + assert biattention.value2.in_features == params_dict["hidden_size2"] + assert biattention.dropout2.p == params_dict["dropout2"] + + +def test_forward_runs(biattention): + biattention( + torch.randn(2, 3, 6), + torch.randn(2, 3, 4), + torch.randint(0, 2, (2, 2, 3, 3)) == 1, # creating 
boolean tensors + torch.randint(0, 2, (2, 2, 3, 3)) == 1, + ) diff --git a/tests/modules/transformer/bimodal_encoder_test.py b/tests/modules/transformer/bimodal_encoder_test.py index b95af3bfa1f..39bd3b54e8c 100644 --- a/tests/modules/transformer/bimodal_encoder_test.py +++ b/tests/modules/transformer/bimodal_encoder_test.py @@ -1,95 +1,92 @@ -import copy import torch +from torch.testing import assert_allclose +from transformers import AutoModel +import pytest + from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters from allennlp.modules.transformer import BiModalEncoder -from allennlp.common.testing import AllenNlpTestCase - - -class TestBiModalEncoder(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = { - "num_hidden_layers1": 3, - "num_hidden_layers2": 3, - "hidden_size1": 12, - "hidden_size2": 12, - "combined_hidden_size": 12, - "intermediate_size1": 3, - "intermediate_size2": 3, - "num_attention_heads1": 4, - "num_attention_heads2": 6, - "combined_num_attention_heads": 2, - "attention_dropout1": 0.1, - "hidden_dropout1": 0.2, - "attention_dropout2": 0.1, - "hidden_dropout2": 0.2, - "activation": "relu", - "biattention_id1": [1, 2], - "biattention_id2": [1, 2], - "fixed_layer1": 1, - "fixed_layer2": 1, - } - - params = Params(copy.deepcopy(self.params_dict)) - - self.bimodal_encoder = BiModalEncoder.from_params(params) - - self.pretrained = cached_transformers.get("bert-base-uncased", False) - - def test_can_construct_from_params(self): - - modules = dict(self.bimodal_encoder.named_modules()) - assert len(modules["layers1"]) == self.params_dict["num_hidden_layers1"] - assert len(modules["layers2"]) == self.params_dict["num_hidden_layers2"] - - def test_forward_runs(self): - - embedding1 = torch.randn(16, 34, self.params_dict["hidden_size1"]) - embedding2 = torch.randn(16, 2, self.params_dict["hidden_size2"]) - attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 - attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 - - self.bimodal_encoder.forward(embedding1, embedding2, attn_mask1, attn_mask2) - - def test_loading_from_pretrained_weights(self): - pretrained_module = self.pretrained.encoder - required_kwargs = [ - "num_hidden_layers2", - "hidden_size2", - "combined_hidden_size", - "intermediate_size2", - "num_attention_heads2", - "combined_num_attention_heads", - "attention_dropout2", - "hidden_dropout2", - "biattention_id1", - "biattention_id2", - "fixed_layer1", - "fixed_layer2", - ] - kwargs = {key: self.params_dict[key] for key in required_kwargs} - module = BiModalEncoder.from_pretrained_module(pretrained_module, **kwargs) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters( - pretrained_module, - module, - ignore_missing=True, - mapping=mapping, - ) - - def test_default_parameters(self): - encoder = BiModalEncoder() - embedding1 = torch.randn(16, 34, 1024) - embedding2 = torch.randn(16, 2, 1024) - attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 - attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 - - encoder.forward(embedding1, embedding2, attn_mask1, attn_mask2) + + +@pytest.fixture +def params_dict(): + return { + "num_hidden_layers1": 3, + "num_hidden_layers2": 3, + "hidden_size1": 12, + "hidden_size2": 12, + "combined_hidden_size": 12, + "intermediate_size1": 3, + "intermediate_size2": 3, + "num_attention_heads1": 4, + 
"num_attention_heads2": 6, + "combined_num_attention_heads": 2, + "attention_dropout1": 0.1, + "hidden_dropout1": 0.2, + "attention_dropout2": 0.1, + "hidden_dropout2": 0.2, + "activation": "relu", + "biattention_id1": [1, 2], + "biattention_id2": [1, 2], + "fixed_layer1": 1, + "fixed_layer2": 1, + } + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def bimodal_encoder(params): + return BiModalEncoder.from_params(params.duplicate()) + + +def test_can_construct_from_params(bimodal_encoder, params_dict): + modules = dict(bimodal_encoder.named_modules()) + assert len(modules["layers1"]) == params_dict["num_hidden_layers1"] + assert len(modules["layers2"]) == params_dict["num_hidden_layers2"] + + +def test_forward_runs(bimodal_encoder, params_dict): + embedding1 = torch.randn(16, 34, params_dict["hidden_size1"]) + embedding2 = torch.randn(16, 2, params_dict["hidden_size2"]) + attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 + attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 + bimodal_encoder(embedding1, embedding2, attn_mask1, attn_mask2) + + +def test_loading_from_pretrained_weights(params_dict): + pretrained_module = AutoModel.from_pretrained("bert-base-cased").encoder + + required_kwargs = [ + "num_hidden_layers2", + "hidden_size2", + "combined_hidden_size", + "intermediate_size2", + "num_attention_heads2", + "combined_num_attention_heads", + "attention_dropout2", + "hidden_dropout2", + "biattention_id1", + "biattention_id2", + "fixed_layer1", + "fixed_layer2", + ] + kwargs = {key: params_dict[key] for key in required_kwargs} + + module = BiModalEncoder.from_pretrained_module("bert-base-cased", **kwargs) + assert_allclose( + module.layers1[0].intermediate.dense.weight.data, + pretrained_module.layer[0].intermediate.dense.weight.data, + ) + + +def test_default_parameters(): + encoder = BiModalEncoder() + embedding1 = torch.randn(16, 34, 1024) + embedding2 = torch.randn(16, 2, 1024) + attn_mask1 = torch.randint(0, 2, (16, 1, 1, 34)) == 1 + attn_mask2 = torch.randint(0, 2, (16, 1, 1, 2)) == 1 + + encoder(embedding1, embedding2, attn_mask1, attn_mask2) diff --git a/tests/modules/transformer/self_attention_test.py b/tests/modules/transformer/self_attention_test.py index e29ae44cf9e..7a3dcb81ec8 100644 --- a/tests/modules/transformer/self_attention_test.py +++ b/tests/modules/transformer/self_attention_test.py @@ -1,21 +1,13 @@ import copy + import torch import pytest +from transformers import AutoModel from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters, AllenNlpTestCase from allennlp.modules.transformer import SelfAttention from allennlp.nn.util import min_value_of_dtype -from transformers.models.bert.configuration_bert import BertConfig -from transformers.models.bert.modeling_bert import BertSelfAttention -from transformers.models.roberta.configuration_roberta import RobertaConfig -from transformers.models.roberta.modeling_roberta import RobertaSelfAttention -from transformers.models.electra.configuration_electra import ElectraConfig -from transformers.models.electra.modeling_electra import ElectraSelfAttention -from transformers.models.distilbert.configuration_distilbert import DistilBertConfig -from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention PARAMS_DICT = { "hidden_size": 6, @@ -24,145 +16,78 @@ } -def get_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) - 
params["attention_probs_dropout_prob"] = params.pop("dropout") +@pytest.fixture +def params_dict(): + return copy.deepcopy(PARAMS_DICT) - # bert, roberta, electra self attentions have the same code. - torch.manual_seed(1234) - hf_module = BertSelfAttention(BertConfig(**params)) - modules["bert"] = hf_module +@pytest.fixture +def params(params_dict): + return Params(params_dict) - torch.manual_seed(1234) - hf_module = RobertaSelfAttention(RobertaConfig(**params)) - modules["roberta"] = hf_module - torch.manual_seed(1234) - hf_module = ElectraSelfAttention(ElectraConfig(**params)) - modules["electra"] = hf_module +@pytest.fixture +def self_attention(params): + return SelfAttention.from_params(params.duplicate()) - torch.manual_seed(1234) - distilparams = copy.deepcopy(params_dict) - distilparams["n_heads"] = distilparams.pop("num_attention_heads") - distilparams["dim"] = distilparams.pop("hidden_size") - distilparams["attention_dropout"] = distilparams.pop("dropout") - hf_module = MultiHeadSelfAttention(DistilBertConfig(**distilparams)) - modules["distilbert"] = hf_module - return modules - - -class TestSelfAttention(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = {key: val for key, val in PARAMS_DICT.items()} +def test_can_construct_from_params(self_attention, params_dict): + assert self_attention.num_attention_heads == params_dict["num_attention_heads"] + assert self_attention.attention_head_size == int( + params_dict["hidden_size"] / params_dict["num_attention_heads"] + ) - params = Params(copy.deepcopy(self.params_dict)) + assert ( + self_attention.all_head_size + == params_dict["num_attention_heads"] * self_attention.attention_head_size + ) - self.self_attention = SelfAttention.from_params(params) + assert self_attention.query.in_features == params_dict["hidden_size"] + assert self_attention.key.in_features == params_dict["hidden_size"] + assert self_attention.value.in_features == params_dict["hidden_size"] - def test_can_construct_from_params(self): - assert self.self_attention.num_attention_heads == self.params_dict["num_attention_heads"] - assert self.self_attention.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] - ) + assert self_attention.dropout.p == params_dict["dropout"] - assert ( - self.self_attention.all_head_size - == self.params_dict["num_attention_heads"] * self.self_attention.attention_head_size - ) - assert self.self_attention.query.in_features == self.params_dict["hidden_size"] - assert self.self_attention.key.in_features == self.params_dict["hidden_size"] - assert self.self_attention.value.in_features == self.params_dict["hidden_size"] +@pytest.mark.parametrize( + "pretrained_name, relevant_module", + [ + ("bert-base-cased", "bert.encoder.layer.0.attention.self"), + ("google/electra-base-generator", "electra.encoder.layer.0.attention.self"), + ("distilbert-base-uncased", "distilbert.transformer.layer.0.attention"), + ], +) +def test_loading_from_pretrained_weights_using_model_name(pretrained_name, relevant_module): + torch.manual_seed(1234) + module = SelfAttention.from_pretrained_module(pretrained_name, relevant_module=relevant_module) - assert self.self_attention.dropout.p == self.params_dict["dropout"] + torch.manual_seed(1234) + pretrained_module = dict(AutoModel.from_pretrained(pretrained_name).named_modules())[ + # Module name will exclude the top-level part (e.g. 'bert.', 'electra.') for some reason. 
+ relevant_module[relevant_module.index(".") + 1 :] + ] - @pytest.mark.skip("Takes up too much memory") - @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) - def test_forward_against_huggingface_output(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + batch_size = 2 + seq_len = 3 + dim = module.query.in_features + hidden_states = torch.randn(batch_size, seq_len, dim) + attention_mask = torch.tensor([[1, 1, 0], [1, 0, 1]])[:, None, None, :] - torch.manual_seed(1234) - self_attention = SelfAttention.from_pretrained_module(hf_module) - - output = self_attention.forward(hidden_states, attention_mask=attention_mask) - if module_name == "distilbert": - hf_output = hf_module.forward( - hidden_states, hidden_states, hidden_states, mask=attention_mask - ) - else: - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) - - assert torch.allclose(output[0], hf_output[0]) - - @pytest.mark.skip("Takes up too much memory") - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "roberta-base", - "google/electra-base-generator", - "distilbert-base-uncased", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): + # setting to eval mode to avoid non-deterministic dropout. + module = module.eval() + pretrained_module = pretrained_module.eval() + torch.manual_seed(1234) + output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + if "distilbert" in pretrained_name: torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - encoder = pretrained.transformer - else: - encoder = pretrained.encoder - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(encoder.layer.modules()): - if i == 1: - break - - # Get the self attention layer. - if "distilbert" in pretrained_name: - pretrained_module = pretrained_module.attention - else: - pretrained_module = pretrained_module.attention.self - + hf_output = pretrained_module( + hidden_states, hidden_states, hidden_states, mask=attention_mask + )[0] + else: + # The attn_mask is processed outside the self attention module in HF bert models. + attention_mask = (~(attention_mask == 1)) * min_value_of_dtype(hidden_states.dtype) torch.manual_seed(1234) - module = SelfAttention.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 2 - seq_len = 3 - dim = module.query.in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, 1, 1, seq_len)) - - # setting to eval mode to avoid non-deterministic dropout. 
- module = module.eval() - pretrained_module = pretrained_module.eval() + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask)[0] - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - if "distilbert" in pretrained_name: - torch.manual_seed(1234) - hf_output = pretrained_module.forward( - hidden_states, hidden_states, hidden_states, mask=attention_mask - )[0] - else: - # The attn_mask is processed outside the self attention module in HF bert models. - attention_mask = (~(attention_mask == 1)) * min_value_of_dtype(hidden_states.dtype) - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask)[0] - - assert torch.allclose(output, hf_output) + assert torch.allclose(output, hf_output) diff --git a/tests/modules/transformer/toolkit_test.py b/tests/modules/transformer/toolkit_test.py index cd1bf60e9fd..ff59b9cf6b5 100644 --- a/tests/modules/transformer/toolkit_test.py +++ b/tests/modules/transformer/toolkit_test.py @@ -1,9 +1,10 @@ import torch +from torch.testing import assert_allclose from overrides import overrides +from transformers import AutoModel from transformers.models.albert.modeling_albert import AlbertEmbeddings from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters from allennlp.data.vocabulary import Vocabulary from allennlp.modules.token_embedders import Embedding, TokenEmbedder from allennlp.modules.transformer import TransformerStack, TransformerEmbeddings @@ -49,15 +50,19 @@ def forward(self, token_ids: torch.LongTensor): tiny.forward(torch.LongTensor([[0, 1, 2]])) def test_use_first_four_layers_of_pretrained(self): - pretrained = cached_transformers.get("bert-base-uncased", False) + pretrained = "bert-base-cased" class SmallTransformer(TokenEmbedder): def __init__(self): super().__init__() - self.embeddings = TransformerEmbeddings.from_pretrained_module(pretrained) - + self.embeddings = TransformerEmbeddings.from_pretrained_module( + pretrained, relevant_module="bert.embeddings" + ) self.transformer = TransformerStack.from_pretrained_module( - pretrained, num_hidden_layers=4 + pretrained, + num_hidden_layers=4, + relevant_module="bert.encoder", + strict=False, ) @overrides @@ -68,19 +73,27 @@ def forward(self, token_ids: torch.LongTensor): small = SmallTransformer() assert len(small.transformer.layers) == 4 - small.forward(torch.LongTensor([[0, 1, 2]])) + small(torch.LongTensor([[0, 1, 2]])) def test_use_selected_layers_of_bert_for_different_purposes(self): class MediumTransformer(torch.nn.Module): def __init__(self): super().__init__() - self.embeddings = TransformerEmbeddings.from_pretrained_module("bert-base-uncased") + self.embeddings = TransformerEmbeddings.from_pretrained_module( + "bert-base-cased", relevant_module="bert.embeddings" + ) self.separate_transformer = TransformerStack.from_pretrained_module( - "bert-base-uncased", num_hidden_layers=range(0, 8) + "bert-base-cased", + relevant_module="bert.encoder", + num_hidden_layers=8, + strict=False, ) self.combined_transformer = TransformerStack.from_pretrained_module( - "bert-base-uncased", - num_hidden_layers=range(8, 12), + "bert-base-cased", + relevant_module="bert.encoder", + num_hidden_layers=4, + mapping={f"layer.{l}": f"layers.{i}" for (i, l) in enumerate(range(8, 12))}, + strict=False, ) @overrides @@ -106,22 +119,31 @@ def forward( assert (len(medium.separate_transformer.layers)) == 8 assert 
(len(medium.combined_transformer.layers)) == 4 - pretrained = cached_transformers.get("bert-base-uncased", False) + pretrained = cached_transformers.get("bert-base-cased", False) pretrained_layers = dict(pretrained.encoder.layer.named_modules()) - medium_layers = dict(medium.combined_transformer.layers.named_modules()) + separate_layers = dict(medium.separate_transformer.layers.named_modules()) + assert_allclose( + separate_layers["0"].intermediate.dense.weight.data, + pretrained_layers["0"].intermediate.dense.weight.data, + ) - assert_equal_parameters( - medium_layers["0"], pretrained_layers["8"], TransformerStack._huggingface_mapping + combined_layers = dict(medium.combined_transformer.layers.named_modules()) + assert_allclose( + combined_layers["0"].intermediate.dense.weight.data, + pretrained_layers["8"].intermediate.dense.weight.data, ) - assert_equal_parameters( - medium_layers["1"], pretrained_layers["9"], TransformerStack._huggingface_mapping + assert_allclose( + combined_layers["1"].intermediate.dense.weight.data, + pretrained_layers["9"].intermediate.dense.weight.data, ) - assert_equal_parameters( - medium_layers["2"], pretrained_layers["10"], TransformerStack._huggingface_mapping + assert_allclose( + combined_layers["2"].intermediate.dense.weight.data, + pretrained_layers["10"].intermediate.dense.weight.data, ) - assert_equal_parameters( - medium_layers["3"], pretrained_layers["11"], TransformerStack._huggingface_mapping + assert_allclose( + combined_layers["3"].intermediate.dense.weight.data, + pretrained_layers["11"].intermediate.dense.weight.data, ) def test_combination_of_two_different_berts(self): @@ -130,8 +152,10 @@ def test_combination_of_two_different_berts(self): class AlmostRegularTransformer(TokenEmbedder): def __init__(self): super().__init__() - self.embeddings = TransformerEmbeddings.get_relevant_module("albert-base-v2") - self.transformer = TransformerStack.from_pretrained_module("bert-base-uncased") + self.embeddings = AutoModel.from_pretrained("albert-base-v2").embeddings + self.transformer = TransformerStack.from_pretrained_module( + "bert-base-cased", relevant_module="bert.encoder" + ) # We want to tune only the embeddings, because that's our experiment. 
self.transformer.requires_grad = False diff --git a/tests/modules/transformer/transformer_embeddings_test.py b/tests/modules/transformer/transformer_embeddings_test.py index d366f4732b4..d37eae8629b 100644 --- a/tests/modules/transformer/transformer_embeddings_test.py +++ b/tests/modules/transformer/transformer_embeddings_test.py @@ -1,23 +1,21 @@ -import pytest import copy + +import pytest import torch from torch.testing import assert_allclose - -from allennlp.common import Params, FromParams -from allennlp.common import cached_transformers - +from transformers import AutoModel from transformers.models.bert.configuration_bert import BertConfig from transformers.models.bert.modeling_bert import BertEmbeddings from transformers.models.albert.configuration_albert import AlbertConfig from transformers.models.albert.modeling_albert import AlbertEmbeddings -from allennlp.common.testing import assert_equal_parameters +from allennlp.common import Params, FromParams from allennlp.modules.transformer import ( TransformerEmbeddings, ImageFeatureEmbeddings, TransformerModule, ) -from allennlp.common.testing import AllenNlpTestCase + PARAMS_DICT = { "vocab_size": 20, @@ -29,9 +27,159 @@ } -def get_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) +@pytest.fixture +def params_dict(): + return copy.deepcopy(PARAMS_DICT) + + +@pytest.fixture +def params(params_dict): + return Params(params_dict) + + +@pytest.fixture +def transformer_embeddings(params): + return TransformerEmbeddings.from_params(params.duplicate()) + + +def test_can_construct_from_params(params_dict, transformer_embeddings): + embeddings = transformer_embeddings.embeddings + assert embeddings.word_embeddings.num_embeddings == params_dict["vocab_size"] + assert embeddings.word_embeddings.embedding_dim == params_dict["embedding_size"] + assert embeddings.word_embeddings.padding_idx == params_dict["pad_token_id"] + + assert embeddings.position_embeddings.num_embeddings == params_dict["max_position_embeddings"] + assert embeddings.position_embeddings.embedding_dim == params_dict["embedding_size"] + + assert embeddings.token_type_embeddings.num_embeddings == params_dict["type_vocab_size"] + assert embeddings.token_type_embeddings.embedding_dim == params_dict["embedding_size"] + + assert transformer_embeddings.layer_norm.normalized_shape[0] == params_dict["embedding_size"] + + assert transformer_embeddings.dropout.p == params_dict["dropout"] + + +def test_sanity(): + class TextEmbeddings(TransformerModule, FromParams): + def __init__( + self, + vocab_size: int, + hidden_size: int, + pad_token_id: int, + max_position_embeddings: int, + type_vocab_size: int, + dropout: float, + ): + super().__init__() + self.word_embeddings = torch.nn.Embedding( + vocab_size, hidden_size, padding_idx=pad_token_id + ) + self.position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, hidden_size) + + self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = torch.nn.Dropout(dropout) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + device = input_ids.device if input_ids is not None else inputs_embeds.device + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = 
position_ids.unsqueeze(0).expand(input_shape) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + torch.manual_seed(23) + text = TextEmbeddings(10, 5, 2, 3, 7, 0.0) + torch.manual_seed(23) + transformer = TransformerEmbeddings(10, 5, 2, 3, 7, 0.0) + + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + + text_output = text(input_ids, token_type_ids, position_ids) + transformer_output = transformer(input_ids, token_type_ids, position_ids) + + assert_allclose(text_output, transformer_output) + + +def test_forward_runs_with_inputs(transformer_embeddings): + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + transformer_embeddings( + input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids + ) + + +def test_output_size(params): + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + params["output_size"] = 7 + module = TransformerEmbeddings.from_params(params) + output = module(input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids) + + assert output.shape[-1] == 7 + + +def test_no_token_type_layer(params): + params["type_vocab_size"] = 0 + module = TransformerEmbeddings.from_params(params) + assert len(module.embeddings) == 2 + + +@pytest.mark.parametrize( + "pretrained_name", + [ + "bert-base-cased", + "epwalsh/bert-xsmall-dummy", + ], +) +def test_loading_from_pretrained_module(pretrained_name): + TransformerEmbeddings.from_pretrained_module(pretrained_name) + + +def test_loading_albert(): + """ + Albert is a special case because it includes a Linear layer in the encoder + that maps the embeddings to the encoder hidden size, but we include this linear + layer within our embedding layer. + """ + transformer_embedding = TransformerEmbeddings.from_pretrained_module( + "albert-base-v2", + ) + albert = AutoModel.from_pretrained("albert-base-v2") + assert_allclose( + transformer_embedding.embeddings.word_embeddings.weight.data, + albert.embeddings.word_embeddings.weight.data, + ) + assert_allclose( + transformer_embedding.linear_transform.weight.data, + albert.encoder.embedding_hidden_mapping_in.weight.data, + ) + + +def get_modules(): + params = copy.deepcopy(PARAMS_DICT) params["hidden_dropout_prob"] = params.pop("dropout") params["hidden_size"] = params.pop("embedding_size") @@ -39,270 +187,117 @@ def get_modules(params_dict): # bert, roberta, electra self attentions have the same code. 
torch.manual_seed(1234) - hf_module = BertEmbeddings(BertConfig(**params)) - modules["bert"] = hf_module + yield "bert", BertEmbeddings(BertConfig(**params)) - albertparams = copy.deepcopy(params_dict) + albertparams = copy.deepcopy(PARAMS_DICT) albertparams["hidden_dropout_prob"] = albertparams.pop("dropout") torch.manual_seed(1234) - hf_module = AlbertEmbeddings(AlbertConfig(**albertparams)) - modules["albert"] = hf_module - - return modules - - -class TestTransformerEmbeddings(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = {key: val for key, val in PARAMS_DICT.items()} - - params = Params(copy.deepcopy(self.params_dict)) - - self.transformer_embeddings = TransformerEmbeddings.from_params(params) - - def test_can_construct_from_params(self): - - transformer_embeddings = self.transformer_embeddings.embeddings - - assert ( - transformer_embeddings.word_embeddings.num_embeddings == self.params_dict["vocab_size"] - ) - assert ( - transformer_embeddings.word_embeddings.embedding_dim - == self.params_dict["embedding_size"] - ) - assert ( - transformer_embeddings.word_embeddings.padding_idx == self.params_dict["pad_token_id"] - ) - - assert ( - transformer_embeddings.position_embeddings.num_embeddings - == self.params_dict["max_position_embeddings"] - ) - assert ( - transformer_embeddings.position_embeddings.embedding_dim - == self.params_dict["embedding_size"] - ) - - assert ( - transformer_embeddings.token_type_embeddings.num_embeddings - == self.params_dict["type_vocab_size"] - ) - assert ( - transformer_embeddings.token_type_embeddings.embedding_dim - == self.params_dict["embedding_size"] - ) - - assert ( - self.transformer_embeddings.layer_norm.normalized_shape[0] - == self.params_dict["embedding_size"] - ) - - assert self.transformer_embeddings.dropout.p == self.params_dict["dropout"] - - def test_sanity(self): - class TextEmbeddings(TransformerModule, FromParams): - def __init__( - self, - vocab_size: int, - hidden_size: int, - pad_token_id: int, - max_position_embeddings: int, - type_vocab_size: int, - dropout: float, - ): - super().__init__() - self.word_embeddings = torch.nn.Embedding( - vocab_size, hidden_size, padding_idx=pad_token_id - ) - self.position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, hidden_size) - - self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) - self.dropout = torch.nn.Dropout(dropout) - - def forward( - self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None - ): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - device = input_ids.device if input_ids is not None else inputs_embeds.device - if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).expand(input_shape) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = inputs_embeds + position_embeddings + token_type_embeddings - embeddings = self.layer_norm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - torch.manual_seed(23) - text = TextEmbeddings(10, 5, 2, 
3, 7, 0.0) - torch.manual_seed(23) - transformer = TransformerEmbeddings(10, 5, 2, 3, 7, 0.0) - - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - - text_output = text.forward(input_ids, token_type_ids, position_ids) - transformer_output = transformer.forward(input_ids, token_type_ids, position_ids) - - assert_allclose(text_output, transformer_output) - - def test_forward_runs_with_inputs(self): - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - self.transformer_embeddings.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - def test_output_size(self): - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - params = copy.deepcopy(self.params_dict) - params["output_size"] = 7 - params = Params(params) - module = TransformerEmbeddings.from_params(params) - output = module.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - assert output.shape[-1] == 7 - - def test_no_token_type_layer(self): - params = copy.deepcopy(self.params_dict) - params["type_vocab_size"] = 0 - params = Params(params) - module = TransformerEmbeddings.from_params(params) - - assert len(module.embeddings) == 2 - - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "albert-base-v2", - ], + yield "albert", AlbertEmbeddings(AlbertConfig(**albertparams)) + + +@pytest.mark.parametrize("module_name, hf_module", get_modules()) +def test_forward_against_huggingface_output(transformer_embeddings, module_name, hf_module): + input_ids = torch.tensor([[1, 2]]) + token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) + position_ids = torch.tensor([[0, 1]]) + + state_dict = transformer_embeddings._get_mapped_state_dict(hf_module.state_dict()) + if "position_ids" in state_dict: + del state_dict["position_ids"] + transformer_embeddings.load_state_dict(state_dict) + + torch.manual_seed(1234) + transformer_embeddings = ( + transformer_embeddings.eval() + ) # setting to eval mode to avoid non-deterministic dropout. + output = transformer_embeddings( + input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - pretrained_module = cached_transformers.get(pretrained_name, False).embeddings - module = TransformerEmbeddings.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - missing = assert_equal_parameters(pretrained_module, module, mapping=mapping) - assert len(missing) == 0 - - @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) - def test_forward_against_huggingface_output(self, module_name, hf_module): - input_ids = torch.tensor([[1, 2]]) - token_type_ids = torch.tensor([[1, 0]], dtype=torch.long) - position_ids = torch.tensor([[0, 1]]) - - torch.manual_seed(1234) - embeddings = TransformerEmbeddings.from_pretrained_module(hf_module) - - torch.manual_seed(1234) - embeddings = embeddings.eval() # setting to eval mode to avoid non-deterministic dropout. 
- output = embeddings.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - torch.manual_seed(1234) - hf_module = hf_module.eval() # setting to eval mode to avoid non-deterministic dropout. - hf_output = hf_module.forward( - input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids - ) - - assert torch.allclose(output, hf_output) - - -class TestImageFeatureEmbeddings(AllenNlpTestCase): - def setup_method(self): - super().setup_method() - - self.params_dict = {"feature_size": 3, "embedding_size": 5, "dropout": 0.1} - - params = Params(copy.deepcopy(self.params_dict)) - - self.img_embeddings = ImageFeatureEmbeddings.from_params(params) - - def test_can_construct_from_params(self): - assert ( - self.img_embeddings.embeddings.image_embeddings.in_features - == self.params_dict["feature_size"] - ) - assert ( - self.img_embeddings.embeddings.image_embeddings.out_features - == self.params_dict["embedding_size"] - ) - assert ( - self.img_embeddings.embeddings.location_embeddings.out_features - == self.params_dict["embedding_size"] - ) - assert self.img_embeddings.dropout.p == self.params_dict["dropout"] - - def test_forward_runs_with_inputs(self): - batch_size = 2 - feature_dim = self.params_dict["feature_size"] - image_feature = torch.randn(batch_size, feature_dim) - image_location = torch.randn(batch_size, 4) - self.img_embeddings.forward(image_feature, image_location) - - def test_sanity(self): - class OldImageFeatureEmbeddings(TransformerModule, FromParams): - """Construct the embeddings from image, spatial location (omit now) and - token_type embeddings. - """ - - def __init__(self, feature_size: int, embedding_size: int, dropout: float = 0.0): - super().__init__() - - self.image_embeddings = torch.nn.Linear(feature_size, embedding_size) - self.image_location_embeddings = torch.nn.Linear(4, embedding_size, bias=False) - self.layer_norm = torch.nn.LayerNorm(embedding_size, eps=1e-12) - self.dropout = torch.nn.Dropout(dropout) - - def forward(self, image_feature: torch.Tensor, image_location: torch.Tensor): - img_embeddings = self.image_embeddings(image_feature) - loc_embeddings = self.image_location_embeddings(image_location) - embeddings = self.layer_norm(img_embeddings + loc_embeddings) - embeddings = self.dropout(embeddings) - - return embeddings - - torch.manual_seed(23) - old = OldImageFeatureEmbeddings(**self.params_dict) - torch.manual_seed(23) - now = ImageFeatureEmbeddings(**self.params_dict) - - batch_size = 2 - - image_feature = torch.randn(batch_size, self.params_dict["feature_size"]) - image_location = torch.randn(batch_size, 4) - - torch.manual_seed(23) - old_output = old.forward(image_feature, image_location) - torch.manual_seed(23) - now_output = now.forward(image_feature, image_location) - - assert_allclose(old_output, now_output) + + torch.manual_seed(1234) + hf_module = hf_module.eval() # setting to eval mode to avoid non-deterministic dropout. 
+ hf_output = hf_module( + input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids + ) + + assert torch.allclose(output, hf_output) + + +@pytest.fixture +def image_params_dict(): + return {"feature_size": 3, "embedding_size": 5, "dropout": 0.1} + + +@pytest.fixture +def image_params(image_params_dict): + return Params(image_params_dict) + + +@pytest.fixture +def image_embeddings(image_params): + return ImageFeatureEmbeddings.from_params(image_params.duplicate()) + + +def test_can_construct_image_embeddings_from_params(image_embeddings, image_params_dict): + assert ( + image_embeddings.embeddings.image_embeddings.in_features + == image_params_dict["feature_size"] + ) + assert ( + image_embeddings.embeddings.image_embeddings.out_features + == image_params_dict["embedding_size"] + ) + assert ( + image_embeddings.embeddings.location_embeddings.out_features + == image_params_dict["embedding_size"] + ) + assert image_embeddings.dropout.p == image_params_dict["dropout"] + + +def test_image_embedding_forward_runs_with_inputs(image_embeddings, image_params_dict): + batch_size = 2 + feature_dim = image_params_dict["feature_size"] + image_feature = torch.randn(batch_size, feature_dim) + image_location = torch.randn(batch_size, 4) + image_embeddings(image_feature, image_location) + + +def test_image_embeddings_sanity(image_params_dict): + class OldImageFeatureEmbeddings(TransformerModule, FromParams): + """Construct the embeddings from image, spatial location (omit now) and + token_type embeddings. + """ + + def __init__(self, feature_size: int, embedding_size: int, dropout: float = 0.0): + super().__init__() + + self.image_embeddings = torch.nn.Linear(feature_size, embedding_size) + self.image_location_embeddings = torch.nn.Linear(4, embedding_size, bias=False) + self.layer_norm = torch.nn.LayerNorm(embedding_size, eps=1e-12) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, image_feature: torch.Tensor, image_location: torch.Tensor): + img_embeddings = self.image_embeddings(image_feature) + loc_embeddings = self.image_location_embeddings(image_location) + embeddings = self.layer_norm(img_embeddings + loc_embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + torch.manual_seed(23) + old = OldImageFeatureEmbeddings(**image_params_dict) + torch.manual_seed(23) + now = ImageFeatureEmbeddings(**image_params_dict) + + batch_size = 2 + + image_feature = torch.randn(batch_size, image_params_dict["feature_size"]) + image_location = torch.randn(batch_size, 4) + + torch.manual_seed(23) + old_output = old(image_feature, image_location) + torch.manual_seed(23) + now_output = now(image_feature, image_location) + + assert_allclose(old_output, now_output) diff --git a/tests/modules/transformer/transformer_layer_test.py b/tests/modules/transformer/transformer_layer_test.py index 1ecf183eace..4c1e141a5a8 100644 --- a/tests/modules/transformer/transformer_layer_test.py +++ b/tests/modules/transformer/transformer_layer_test.py @@ -1,13 +1,7 @@ import copy + import torch import pytest - -from allennlp.common import Params -from allennlp.common import cached_transformers -from allennlp.common.testing import assert_equal_parameters -from allennlp.modules.transformer import AttentionLayer, TransformerLayer -from allennlp.common.testing import AllenNlpTestCase - from transformers.models.bert.configuration_bert import BertConfig from transformers.models.bert.modeling_bert import BertAttention, BertLayer from transformers.models.roberta.configuration_roberta import 
RobertaConfig @@ -15,6 +9,14 @@ from transformers.models.electra.configuration_electra import ElectraConfig from transformers.models.electra.modeling_electra import ElectraAttention, ElectraLayer +from allennlp.common import Params, cached_transformers +from allennlp.common.testing import run_distributed_test +from allennlp.modules.transformer import ( + AttentionLayer, + TransformerLayer, +) + + ATTENTION_PARAMS_DICT = { "hidden_size": 6, "num_attention_heads": 2, @@ -23,141 +25,113 @@ } -def get_attention_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) +@pytest.fixture +def attention_params(): + return Params(copy.deepcopy(ATTENTION_PARAMS_DICT)) + + +def test_attention(attention_params): + attention_layer = AttentionLayer.from_params(attention_params.duplicate()).eval() + + assert attention_layer.self.num_attention_heads == attention_params["num_attention_heads"] + assert attention_layer.self.attention_head_size == int( + attention_params["hidden_size"] / attention_params["num_attention_heads"] + ) + assert ( + attention_layer.self.all_head_size + == attention_params["num_attention_heads"] * attention_layer.self.attention_head_size + ) + assert attention_layer.self.query.in_features == attention_params["hidden_size"] + assert attention_layer.self.key.in_features == attention_params["hidden_size"] + assert attention_layer.self.value.in_features == attention_params["hidden_size"] + assert attention_layer.self.dropout.p == attention_params["attention_dropout"] + + assert attention_layer.output.dense.in_features == attention_params["hidden_size"] + assert attention_layer.output.dense.out_features == attention_params["hidden_size"] + assert attention_layer.output.layer_norm.normalized_shape[0] == attention_params["hidden_size"] + assert attention_layer.output.dropout.p == attention_params["hidden_dropout"] + + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + attention_layer(torch.randn(2, 3, 6), attention_mask=attention_mask) + + +def get_attention_modules(): + params = copy.deepcopy(ATTENTION_PARAMS_DICT) params["attention_probs_dropout_prob"] = params.pop("attention_dropout") params["hidden_dropout_prob"] = params.pop("hidden_dropout") torch.manual_seed(1234) - hf_module = BertAttention(BertConfig(**params)) - modules["bert"] = hf_module + yield "bert", BertAttention(BertConfig(**params)).eval() torch.manual_seed(1234) - hf_module = RobertaAttention(RobertaConfig(**params)) - modules["roberta"] = hf_module + yield "roberta", RobertaAttention(RobertaConfig(**params)).eval() torch.manual_seed(1234) - hf_module = ElectraAttention(ElectraConfig(**params)) - modules["electra"] = hf_module + yield "electra", ElectraAttention(ElectraConfig(**params)).eval() - return modules +@pytest.mark.parametrize("module_name, hf_module", get_attention_modules()) +def test_attention_matches_huggingface(attention_params, module_name, hf_module): + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) -class TestAttentionLayer(AllenNlpTestCase): - def setup_method(self): - super().setup_method() + attention = AttentionLayer.from_params(attention_params).eval() + state_dict = attention._get_mapped_state_dict(hf_module.state_dict()) + attention.load_state_dict(state_dict) - self.params_dict = { - "hidden_size": 6, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "hidden_dropout": 0.2, - } + torch.manual_seed(1234) + output = attention(hidden_states, attention_mask=attention_mask) + # We do this because bert, roberta, electra process 
the attention_mask at the model level. + attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - params = Params(copy.deepcopy(self.params_dict)) + torch.manual_seed(1234) + hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - self.attention_layer = AttentionLayer.from_params(params) + assert torch.allclose(output[0], hf_output[0]) - def test_can_construct_from_params(self): - attention_layer = self.attention_layer +@pytest.mark.parametrize( + "pretrained_name, relevant_top_level_module", + [ + ("bert-base-cased", "bert"), + ("epwalsh/bert-xsmall-dummy", None), + ], +) +def test_attention_from_pretrained(pretrained_name, relevant_top_level_module): + torch.manual_seed(1234) + pretrained = cached_transformers.get(pretrained_name, False).eval() - assert attention_layer.self.num_attention_heads == self.params_dict["num_attention_heads"] - assert attention_layer.self.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] - ) - assert ( - attention_layer.self.all_head_size - == self.params_dict["num_attention_heads"] * attention_layer.self.attention_head_size - ) - assert attention_layer.self.query.in_features == self.params_dict["hidden_size"] - assert attention_layer.self.key.in_features == self.params_dict["hidden_size"] - assert attention_layer.self.value.in_features == self.params_dict["hidden_size"] - assert attention_layer.self.dropout.p == self.params_dict["attention_dropout"] - - assert attention_layer.output.dense.in_features == self.params_dict["hidden_size"] - assert attention_layer.output.dense.out_features == self.params_dict["hidden_size"] - assert ( - attention_layer.output.layer_norm.normalized_shape[0] == self.params_dict["hidden_size"] - ) - assert attention_layer.output.dropout.p == self.params_dict["hidden_dropout"] + if "distilbert" in pretrained_name: + encoder = pretrained.transformer + else: + encoder = pretrained.encoder + # Hacky way to get a bert layer. + pretrained_module = list(encoder.layer.modules())[1].attention - def test_forward_runs(self): - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - self.attention_layer.forward(torch.randn(2, 3, 6), attention_mask=attention_mask) + torch.manual_seed(1234) + module = AttentionLayer.from_pretrained_module( + pretrained_name, + relevant_module=None + if relevant_top_level_module is None + else f"{relevant_top_level_module}.encoder.layer.0.attention", + ).eval() + + batch_size = 2 + seq_length = 15 + hidden_size = module.self.query.in_features + + hidden_states = torch.randn(batch_size, seq_length, hidden_size) + attention_mask = torch.randint(0, 2, (batch_size, seq_length)) + attention_mask_hf = attention_mask[:, None, None, :] + attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 - @pytest.mark.parametrize( - "module_name, hf_module", get_attention_modules(ATTENTION_PARAMS_DICT).items() - ) - def test_forward_against_huggingface_outputs(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - - attention = AttentionLayer.from_pretrained_module(hf_module) - - torch.manual_seed(1234) - output = attention.forward(hidden_states, attention_mask=attention_mask) - # We do this because bert, roberta, electra process the attention_mask at the model level. 
- attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - torch.manual_seed(1234) - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) - - assert torch.allclose(output[0], hf_output[0]) - - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "roberta-base", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - - torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - encoder = pretrained.transformer - else: - encoder = pretrained.encoder - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(encoder.layer.modules()): - if i == 1: - break - - pretrained_module = pretrained_module.attention - - torch.manual_seed(1234) - module = AttentionLayer.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 2 - seq_len = 768 - dim = module.self.query.in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, seq_len)) - mask_reshp = (batch_size, 1, 1, dim) - attention_mask_hf = (attention_mask == 0).view(mask_reshp).expand( - batch_size, 12, seq_len, seq_len - ) * -10e5 - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() - - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask_hf)[0] - - assert torch.allclose(output, hf_output, atol=1e-04) + torch.manual_seed(1234) + output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + + torch.manual_seed(1234) + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] + + assert torch.allclose(output, hf_output, atol=1e-04) LAYER_PARAMS_DICT = { @@ -170,213 +144,158 @@ def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name) } -def get_layer_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) - params["attention_probs_dropout_prob"] = params.pop("attention_dropout") - params["hidden_dropout_prob"] = params.pop("hidden_dropout") +@pytest.fixture +def layer_params(): + return Params(copy.deepcopy(LAYER_PARAMS_DICT)) - # bert, roberta, electra, layoutlm self attentions have the same code. 
- torch.manual_seed(1234) - hf_module = BertLayer(BertConfig(**params)) - modules["bert"] = hf_module +def test_layer(layer_params): + transformer_layer = TransformerLayer.from_params(layer_params.duplicate()).eval() - torch.manual_seed(1234) - hf_module = RobertaLayer(RobertaConfig(**params)) - modules["roberta"] = hf_module + assert ( + transformer_layer.attention.self.num_attention_heads == layer_params["num_attention_heads"] + ) + assert transformer_layer.attention.self.attention_head_size == int( + layer_params["hidden_size"] / layer_params["num_attention_heads"] + ) + assert ( + transformer_layer.attention.self.all_head_size + == layer_params["num_attention_heads"] + * transformer_layer.attention.self.attention_head_size + ) + assert transformer_layer.attention.self.query.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.self.key.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.self.value.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.self.dropout.p == layer_params["attention_dropout"] + + assert transformer_layer.attention.output.dense.in_features == layer_params["hidden_size"] + assert transformer_layer.attention.output.dense.out_features == layer_params["hidden_size"] + assert ( + transformer_layer.attention.output.layer_norm.normalized_shape[0] + == layer_params["hidden_size"] + ) + assert transformer_layer.attention.output.dropout.p == layer_params["hidden_dropout"] - torch.manual_seed(1234) - hf_module = ElectraLayer(ElectraConfig(**params)) - modules["electra"] = hf_module + assert transformer_layer.intermediate.dense.in_features == layer_params["hidden_size"] + assert transformer_layer.intermediate.dense.out_features == layer_params["intermediate_size"] - return modules + assert transformer_layer.output.dense.in_features == layer_params["intermediate_size"] + assert transformer_layer.output.dense.out_features == layer_params["hidden_size"] + assert transformer_layer.output.layer_norm.normalized_shape[0] == layer_params["hidden_size"] -class TestTransformerLayer(AllenNlpTestCase): - def setup_method(self): - super().setup_method() + assert transformer_layer.output.dropout.p == layer_params["hidden_dropout"] - self.params_dict = { - "hidden_size": 6, - "intermediate_size": 3, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "hidden_dropout": 0.2, - "activation": "relu", - } + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + transformer_layer(torch.randn(2, 3, 6), attention_mask=attention_mask) - params = Params(copy.deepcopy(self.params_dict)) + with pytest.raises(AssertionError): + transformer_layer( + torch.randn(2, 3, 6), + attention_mask=attention_mask, + encoder_hidden_states=torch.randn(2, 3, 6), + ) - self.transformer_layer = TransformerLayer.from_params(params) - self.pretrained_name = "bert-base-uncased" - self.pretrained = cached_transformers.get(self.pretrained_name, False) +def test_layer_with_cross_attention(layer_params): + layer_params["add_cross_attention"] = True - def test_can_construct_from_params(self): + transformer_layer = TransformerLayer.from_params(layer_params).eval() + assert hasattr(transformer_layer, "cross_attention") - transformer_layer = self.transformer_layer + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + transformer_layer( + torch.randn(2, 3, 6), + attention_mask=attention_mask, + encoder_hidden_states=torch.randn(2, 3, 6), + ) - assert ( - transformer_layer.attention.self.num_attention_heads - == 
self.params_dict["num_attention_heads"] - ) - assert transformer_layer.attention.self.attention_head_size == int( - self.params_dict["hidden_size"] / self.params_dict["num_attention_heads"] - ) - assert ( - transformer_layer.attention.self.all_head_size - == self.params_dict["num_attention_heads"] - * transformer_layer.attention.self.attention_head_size - ) - assert transformer_layer.attention.self.query.in_features == self.params_dict["hidden_size"] - assert transformer_layer.attention.self.key.in_features == self.params_dict["hidden_size"] - assert transformer_layer.attention.self.value.in_features == self.params_dict["hidden_size"] - assert transformer_layer.attention.self.dropout.p == self.params_dict["attention_dropout"] - assert ( - transformer_layer.attention.output.dense.in_features == self.params_dict["hidden_size"] - ) - assert ( - transformer_layer.attention.output.dense.out_features == self.params_dict["hidden_size"] - ) - assert ( - transformer_layer.attention.output.layer_norm.normalized_shape[0] - == self.params_dict["hidden_size"] - ) - assert transformer_layer.attention.output.dropout.p == self.params_dict["hidden_dropout"] +def get_layer_modules(): + params = copy.deepcopy(LAYER_PARAMS_DICT) + params["attention_probs_dropout_prob"] = params.pop("attention_dropout") + params["hidden_dropout_prob"] = params.pop("hidden_dropout") + params["hidden_act"] = params.pop("activation") - assert transformer_layer.intermediate.dense.in_features == self.params_dict["hidden_size"] - assert ( - transformer_layer.intermediate.dense.out_features - == self.params_dict["intermediate_size"] - ) + torch.manual_seed(1234) + yield "bert", BertLayer(BertConfig(**params)).eval() - assert transformer_layer.output.dense.in_features == self.params_dict["intermediate_size"] - assert transformer_layer.output.dense.out_features == self.params_dict["hidden_size"] + torch.manual_seed(1234) + yield "roberta", RobertaLayer(RobertaConfig(**params)).eval() - assert ( - transformer_layer.output.layer_norm.normalized_shape[0] - == self.params_dict["hidden_size"] - ) + torch.manual_seed(1234) + yield "electra", ElectraLayer(ElectraConfig(**params)).eval() - assert transformer_layer.output.dropout.p == self.params_dict["hidden_dropout"] - def test_forward_runs(self): - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - self.transformer_layer.forward(torch.randn(2, 3, 6), attention_mask=attention_mask) +@pytest.mark.parametrize("module_name, hf_module", get_layer_modules()) +def test_layer_matches_huggingface(layer_params, module_name, hf_module): + layer = TransformerLayer.from_params(layer_params).eval() + state_dict = layer._get_mapped_state_dict(hf_module.state_dict()) + layer.load_state_dict(state_dict) - with pytest.raises(AssertionError): - self.transformer_layer.forward( - torch.randn(2, 3, 6), - attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), - ) + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - def test_cross_attention(self): - params = copy.deepcopy(self.params_dict) - params["add_cross_attention"] = True + torch.manual_seed(1234) + output = layer(hidden_states, attention_mask=attention_mask) + # We do this because bert, roberta, electra process the attention_mask at the model level. 
+ attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 + torch.manual_seed(1234) + hf_output = hf_module(hidden_states, attention_mask=attention_mask_hf) - params = Params(params) + assert torch.allclose(output[0], hf_output[0]) - transformer_layer = TransformerLayer.from_params(params) - assert hasattr(transformer_layer, "cross_attention") - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - transformer_layer.forward( - torch.randn(2, 3, 6), - attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), - ) +@pytest.mark.parametrize( + "pretrained_name, relevant_top_level_module", + [ + ("bert-base-cased", "bert"), + ("epwalsh/bert-xsmall-dummy", None), + ], +) +def test_layer_from_pretrained(pretrained_name, relevant_top_level_module): + torch.manual_seed(1234) + pretrained = cached_transformers.get(pretrained_name, False).eval() - transformer_layer_new = TransformerLayer.from_pretrained_module( - transformer_layer, source="allennlp" - ) + if "distilbert" in pretrained_name: + encoder = pretrained.transformer + else: + encoder = pretrained.encoder + # Hacky way to get a bert layer. + pretrained_module = list(encoder.layer.modules())[1] + + torch.manual_seed(1234) + module = TransformerLayer.from_pretrained_module( + pretrained_name, + relevant_module=None + if relevant_top_level_module is None + else f"{relevant_top_level_module}.encoder.layer.0", + ).eval() + + batch_size = 2 + seq_length = 15 + hidden_size = module.attention.self.query.in_features + + hidden_states = torch.randn(batch_size, seq_length, hidden_size) + attention_mask = torch.randint(0, 2, (batch_size, seq_length)) + attention_mask_hf = attention_mask[:, None, None, :] + attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 + + torch.manual_seed(1234) + output = module(hidden_states, attention_mask=attention_mask.squeeze())[0] + + torch.manual_seed(1234) + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf)[0] - assert hasattr(transformer_layer_new, "cross_attention") - - def test_loading_from_pretrained_weights(self): - - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(self.pretrained.encoder.layer.modules()): - if i == 1: - break - - module = TransformerLayer.from_pretrained_module(pretrained_module) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - @pytest.mark.parametrize("module_name, hf_module", get_layer_modules(LAYER_PARAMS_DICT).items()) - def test_forward_against_huggingface_outputs(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - - layer = TransformerLayer.from_pretrained_module(hf_module) - - torch.manual_seed(1234) - output = layer.forward(hidden_states, attention_mask=attention_mask) - # We do this because bert, roberta, electra process the attention_mask at the model level. 
- attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - torch.manual_seed(1234) - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) - - assert torch.allclose(output[0], hf_output[0]) - - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - "roberta-base", - ], + assert torch.allclose(output, hf_output, atol=1e-04) + + +def _load_pretrained(global_rank, world_size, gpu_id): + TransformerLayer.from_pretrained_module( + "epwalsh/bert-xsmall-dummy", ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - - torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - encoder = pretrained.transformer - else: - encoder = pretrained.encoder - # Hacky way to get a bert layer. - for i, pretrained_module in enumerate(encoder.layer.modules()): - if i == 1: - break - - pretrained_module = pretrained_module - - torch.manual_seed(1234) - module = TransformerLayer.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 2 - seq_len = 768 - dim = module.attention.self.query.in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, seq_len)) - mask_reshp = (batch_size, 1, 1, dim) - attention_mask_hf = (attention_mask == 0).view(mask_reshp).expand( - batch_size, 12, seq_len, seq_len - ) * -10e5 - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() - - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask_hf)[0] - - assert torch.allclose(output, hf_output, atol=1e-04) + + +@pytest.mark.parametrize("test_func", [_load_pretrained]) +def test_distributed(test_func): + run_distributed_test([-1, -1], func=test_func, start_method="spawn") diff --git a/tests/modules/transformer/transformer_module_test.py b/tests/modules/transformer/transformer_module_test.py index d5002f215ea..4018229c41d 100644 --- a/tests/modules/transformer/transformer_module_test.py +++ b/tests/modules/transformer/transformer_module_test.py @@ -1,74 +1,89 @@ import torch +from torch.nn import Parameter -from allennlp.common.testing import assert_equal_parameters +from allennlp.common.testing import assert_equal_parameters, assert_allclose from allennlp.modules.transformer import TransformerModule from allennlp.common.testing import AllenNlpTestCase class TestTransformerModule(AllenNlpTestCase): - def test_can_load_pretrained_weights(self): + def test_get_mapped_state_dict(self): class InternalOld(torch.nn.Module): def __init__(self, inp, out): super().__init__() self.ff = torch.nn.Linear(inp, out) + self.p = Parameter(torch.randn(out, out)) + self.register_buffer("b", torch.randn(inp, inp)) def forward(self, x): - x = self.ff(x) + x = self.ff(x).matmul(self.p) return x class InternalNew(TransformerModule): + _pretrained_mapping = {"ff": "linear", "p": "param", "b": "buffer"} + def __init__(self, inp, out): super().__init__() self.linear = torch.nn.Linear(inp, out) - - def _construct_default_mapping(self, pretrained_module, source, mapping): - # return 
{"linear": "ff"} - return {"ff": "linear"} + self.param = Parameter(torch.randn(out, out)) + self.register_buffer("buffer", torch.randn(inp, inp)) def forward(self, x): - x = self.linear(x) + x = self.linear(x).matmul(self.param) return x class ExternalOld(torch.nn.Module): def __init__(self, inp, out): super().__init__() self.internal = InternalOld(inp, out) + self.p = Parameter(torch.randn(out, out)) def forward(self, x): - x = self.internal(x) + x = self.internal(x).matmul(self.p) return x - class External(TransformerModule): - # _huggingface_mapping = {"internal_layer": "internal"} - _huggingface_mapping = {"internal": "internal_layer"} + class ExternalNew(TransformerModule): + _pretrained_mapping = {"internal": "internal_layer", "p": "param"} def __init__(self, inp, out): super().__init__() self.internal_layer = InternalNew(inp, out) + self.param = Parameter(torch.randn(out, out)) def forward(self, x): - x = self.internal_layer(x) + x = self.internal_layer(x).matmul(self.param) return x - iold = InternalOld(3, 5) - x = torch.randn(4, 3) - iold.forward(x) - inew = InternalNew(3, 5) - inew._load_from_pretrained_module(iold) - mapping = { - val: key - for key, val in inew._construct_default_mapping(iold, "huggingface", {}).items() - } - assert_equal_parameters(iold, inew, mapping=mapping) - eold = ExternalOld(3, 5) + state_dict_old = eold.state_dict() + + enew = ExternalNew(3, 5) + state_dict_new = enew._get_mapped_state_dict(state_dict_old) + assert set(state_dict_new.keys()) == set( + [ + "internal_layer.linear.weight", + "internal_layer.linear.bias", + "internal_layer.param", + "internal_layer.buffer", + "param", + ] + ) + + enew.load_state_dict(state_dict_new) + x = torch.randn(4, 3) - eold.forward(x) - - enew = External(3, 5) - enew._load_from_pretrained_module(eold) - mapping = { - val: key - for key, val in enew._construct_default_mapping(eold, "huggingface", {}).items() - } - assert_equal_parameters(eold, enew, mapping=mapping) + out_old = eold(x) + out_new = enew(x) + assert_allclose(out_old, out_new) + + assert_equal_parameters( + eold, + enew, + mapping={ + "internal_layer.linear.weight": "internal.ff.weight", + "internal_layer.linear.bias": "internal.ff.bias", + "internal_layer.param": "internal.p", + "internal_layer.buffer": "internal.b", + "param": "p", + }, + ) diff --git a/tests/modules/transformer/transformer_stack_test.py b/tests/modules/transformer/transformer_stack_test.py index 0481a407937..cf42f6c0f6d 100644 --- a/tests/modules/transformer/transformer_stack_test.py +++ b/tests/modules/transformer/transformer_stack_test.py @@ -1,20 +1,12 @@ import copy + import torch import pytest from allennlp.common import Params from allennlp.common import cached_transformers - -from allennlp.common.testing import assert_equal_parameters from allennlp.modules.transformer import TransformerStack, TransformerLayer -from allennlp.common.testing import AllenNlpTestCase -from transformers.models.bert.configuration_bert import BertConfig -from transformers.models.bert.modeling_bert import BertEncoder -from transformers.models.roberta.configuration_roberta import RobertaConfig -from transformers.models.roberta.modeling_roberta import RobertaEncoder -from transformers.models.electra.configuration_electra import ElectraConfig -from transformers.models.electra.modeling_electra import ElectraEncoder PARAMS_DICT = { "num_hidden_layers": 3, @@ -26,208 +18,93 @@ "activation": "relu", } - -def get_modules(params_dict): - modules = {} - params = copy.deepcopy(params_dict) - 
params["attention_probs_dropout_prob"] = params.pop("attention_dropout") - params["hidden_dropout_prob"] = params.pop("hidden_dropout") - - torch.manual_seed(1234) - hf_module = BertEncoder(BertConfig(**params)) - modules["bert"] = hf_module - - torch.manual_seed(1234) - hf_module = RobertaEncoder(RobertaConfig(**params)) - modules["roberta"] = hf_module - - torch.manual_seed(1234) - hf_module = ElectraEncoder(ElectraConfig(**params)) - modules["electra"] = hf_module - - return modules +SEED = 1234 -class TestTransformerStack(AllenNlpTestCase): - def setup_method(self): - super().setup_method() +@pytest.fixture +def params(): + return Params(copy.deepcopy(PARAMS_DICT)) - self.params_dict = { - "num_hidden_layers": 3, - "hidden_size": 6, - "intermediate_size": 3, - "num_attention_heads": 2, - "attention_dropout": 0.1, - "hidden_dropout": 0.2, - "activation": "relu", - } - params = Params(copy.deepcopy(self.params_dict)) +def test_transformer_stack_from_params(params): + torch.manual_seed(SEED) + transformer_stack = TransformerStack.from_params(params) - self.transformer_stack = TransformerStack.from_params(params) + # Make sure we have the right number of modules. + modules = dict(transformer_stack.named_modules()) + assert len(modules["layers"]) == PARAMS_DICT["num_hidden_layers"] - self.pretrained_name = "bert-base-uncased" + hidden_states = torch.randn(2, 3, 6) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - self.pretrained = cached_transformers.get(self.pretrained_name, False) + # Make sure forward pass can run. + torch.manual_seed(SEED) + output = transformer_stack.forward(hidden_states, attention_mask=attention_mask) - def test_can_construct_from_params(self): - - modules = dict(self.transformer_stack.named_modules()) - assert len(modules["layers"]) == self.params_dict["num_hidden_layers"] - - def test_forward_runs(self): - self.transformer_stack.forward(torch.randn(2, 3, 6), attention_mask=torch.randn(2, 3)) - - with pytest.raises(AssertionError): - self.transformer_stack.forward( - torch.randn(2, 3, 6), - attention_mask=torch.randn(2, 3), - encoder_hidden_states=torch.randn(2, 3, 6), - ) - - def test_layer_same_as_params(self): - params = copy.deepcopy(self.params_dict) - num_hidden_layers = params.pop("num_hidden_layers") - # params = Params(params) - - torch.manual_seed(1234) - transformer_layer = TransformerLayer(**params) - transformer_stack_from_layer = TransformerStack(num_hidden_layers, transformer_layer) - torch.manual_seed(1234) - transformer_stack_from_params = TransformerStack(num_hidden_layers, **params) + # Make sure we get the same results when instantiating from a single layer. + torch.manual_seed(SEED) + layer_params = copy.deepcopy(PARAMS_DICT) + num_hidden_layers = layer_params.pop("num_hidden_layers") + transformer_layer = TransformerLayer(**layer_params) # type: ignore[arg-type] + transformer_stack_from_layer = TransformerStack( + num_hidden_layers, transformer_layer # type: ignore[arg-type] + ) - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + torch.manual_seed(SEED) + from_layer_output = transformer_stack_from_layer.forward( + hidden_states, attention_mask=attention_mask + ) - transformer_stack_from_layer.eval() - transformer_stack_from_params.eval() + assert torch.allclose(from_layer_output[0], output[0]) - torch.manual_seed(1234) - layer_output = transformer_stack_from_layer.forward( - hidden_states, attention_mask=attention_mask + # Make sure forward pass raises with bad input. 
+ with pytest.raises(AssertionError): + transformer_stack.forward( + torch.randn(2, 3, 6), + attention_mask=torch.randn(2, 3), + encoder_hidden_states=torch.randn(2, 3, 6), ) - torch.manual_seed(1234) - params_output = transformer_stack_from_params.forward( - hidden_states, attention_mask=attention_mask - ) - assert torch.allclose(layer_output[0], params_output[0]) +def test_transformer_stack_with_cross_attention(params): + params["add_cross_attention"] = True - def test_cross_attention(self): - params = copy.deepcopy(self.params_dict) - params["add_cross_attention"] = True + transformer_stack = TransformerStack.from_params(params).eval() + modules = dict(transformer_stack.named_modules()) - params = Params(params) + assert hasattr(modules["layers.0"], "cross_attention") - transformer_stack = TransformerStack.from_params(params) - modules = dict(transformer_stack.named_modules()) + attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + transformer_stack.forward( + torch.randn(2, 3, 6), + attention_mask=attention_mask, + encoder_hidden_states=torch.randn(2, 3, 6), + ) - assert hasattr(modules["layers.0"], "cross_attention") - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) - transformer_stack.forward( - torch.randn(2, 3, 6), - attention_mask=attention_mask, - encoder_hidden_states=torch.randn(2, 3, 6), - ) +@pytest.mark.parametrize("pretrained_model_name", ["epwalsh/bert-xsmall-dummy", "bert-base-cased"]) +def test_loading_from_pretrained(pretrained_model_name): + transformer_stack = TransformerStack.from_pretrained_module(pretrained_model_name).eval() + pretrained_module = cached_transformers.get(pretrained_model_name, True).encoder.eval() - transformer_stack_new = TransformerStack.from_pretrained_module( - transformer_stack, source="allennlp" - ) + batch_size = 2 + seq_length = 15 + hidden_size = transformer_stack.layers[0]._hidden_size - new_modules = dict(transformer_stack_new.named_modules()) - assert hasattr(new_modules["layers.0"], "cross_attention") - - def test_loading_from_pretrained_weights(self): - pretrained_module = self.pretrained.encoder - module = TransformerStack.from_pretrained_module(pretrained_module) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping) - - def test_loading_partial_pretrained_weights(self): - - kwargs = TransformerStack._get_input_arguments(self.pretrained.encoder) - # The pretrained module has 12 bert layers, while the instance will have only 3. 
- kwargs["num_hidden_layers"] = 3 - transformer_stack = TransformerStack(**kwargs) - transformer_stack._load_from_pretrained_module(self.pretrained.encoder) - mapping = { - val: key - for key, val in transformer_stack._construct_default_mapping( - self.pretrained.encoder, "huggingface", {} - ).items() - } - assert_equal_parameters( - self.pretrained.encoder, - transformer_stack, - mapping, - ) + hidden_states = torch.randn(batch_size, seq_length, hidden_size) + attention_mask = torch.randint(0, 2, (batch_size, seq_length)) + attention_mask_hf = attention_mask[:, None, None, :] + attention_mask_hf = (1.0 - attention_mask_hf) * -10e5 - @pytest.mark.skip("Takes up too much memory") - @pytest.mark.parametrize("module_name, hf_module", get_modules(PARAMS_DICT).items()) - def test_forward_against_huggingface_outputs(self, module_name, hf_module): - hidden_states = torch.randn(2, 3, 6) - attention_mask = torch.tensor([[0, 1, 0], [1, 1, 0]]) + torch.manual_seed(SEED) + output = transformer_stack(hidden_states, attention_mask=attention_mask) - stack = TransformerStack.from_pretrained_module(hf_module) + torch.manual_seed(SEED) + hf_output = pretrained_module(hidden_states, attention_mask=attention_mask_hf) - torch.manual_seed(1234) - output = stack.forward(hidden_states, attention_mask=attention_mask) - # We do this because bert, roberta, electra process the attention_mask at the model level. - attention_mask_hf = (attention_mask == 0).view((2, 1, 1, 3)).expand(2, 2, 3, 3) * -10e5 - torch.manual_seed(1234) - hf_output = hf_module.forward(hidden_states, attention_mask=attention_mask_hf) + assert torch.allclose(output[0], hf_output[0]) - assert torch.allclose(output[0], hf_output[0]) - @pytest.mark.parametrize( - "pretrained_name", - [ - "bert-base-uncased", - ], - ) - def test_loading_from_pretrained_weights_using_model_name(self, pretrained_name): - - torch.manual_seed(1234) - pretrained = cached_transformers.get(pretrained_name, False) - - if "distilbert" in pretrained_name: - pretrained_module = pretrained.transformer - else: - pretrained_module = pretrained.encoder - - torch.manual_seed(1234) - module = TransformerStack.from_pretrained_module(pretrained_name) - mapping = { - val: key - for key, val in module._construct_default_mapping( - pretrained_module, "huggingface", {} - ).items() - } - assert_equal_parameters(pretrained_module, module, mapping=mapping) - - batch_size = 1 - seq_len = 768 - dim = dict(module.named_modules())["layers.0.attention.self.query"].in_features - hidden_states = torch.randn(batch_size, seq_len, dim) - attention_mask = torch.randint(0, 2, (batch_size, seq_len)) - mask_reshp = (batch_size, 1, 1, dim) - attention_mask_hf = (attention_mask == 0).view(mask_reshp) - attention_mask_hf = attention_mask_hf.expand(batch_size, 12, seq_len, seq_len) * -10e5 - - # setting to eval mode to avoid non-deterministic dropout. - module = module.eval() - pretrained_module = pretrained_module.eval() - - torch.manual_seed(1234) - output = module.forward(hidden_states, attention_mask=attention_mask.squeeze())[0] - torch.manual_seed(1234) - hf_output = pretrained_module.forward(hidden_states, attention_mask=attention_mask_hf)[0] - - assert torch.allclose(output, hf_output) +def test_loading_partial_pretrained_weights(): + # The pretrained module has 12 bert layers, while the instance will have only 3. 
+ TransformerStack.from_pretrained_module("bert-base-cased", num_hidden_layers=3, strict=False) diff --git a/tests/nn/util_test.py b/tests/nn/util_test.py index 7ca660ed04d..73a9952a11f 100644 --- a/tests/nn/util_test.py +++ b/tests/nn/util_test.py @@ -9,7 +9,7 @@ from flaky import flaky from allennlp.common.checks import ConfigurationError -from allennlp.common.testing import AllenNlpTestCase +from allennlp.common.testing import AllenNlpTestCase, run_distributed_test from allennlp.common.util import sanitize from allennlp.data import Token, Vocabulary from allennlp.data.fields import TextField @@ -1730,8 +1730,6 @@ def test_dist_reduce_sum(self): ret_value = util.dist_reduce_sum(value) assert (ret_value == value).all().item() - from allennlp.common.testing.distributed_test import run_distributed_test - func_kwargs = {"value": [torch.Tensor([1, 2, 3]), torch.Tensor([4, 5, 6])]} desired_values = torch.Tensor([5, 7, 9]) @@ -1761,3 +1759,79 @@ def global_distributed_func( output = function(**kwargs) assert (output == desired_values).all().item() + + +class DistributedFixtureModel(torch.nn.Module): + """ + Fake model for testing `load_state_dict_distributed()`. + """ + + def __init__(self): + super().__init__() + self.direct_param = torch.nn.Parameter(torch.randn(3, 5)) + self.register_buffer("direct_buffer", torch.randn(2, 2)) + self.custom_submodule = DistributedFixtureSubmodule() + self.custom_sharded_submodule = DistributedFixtureSubmodule(sharded=True) + self.linear_submodule = torch.nn.Linear(3, 5) + + def forward(self, x): + # This doesn't matter, we're not going to actually use it. + pass + + +class DistributedFixtureSubmodule(torch.nn.Module): + def __init__(self, sharded: bool = False): + super().__init__() + self.direct_param = torch.nn.Parameter(torch.randn(3, 5)) + self.register_buffer("direct_buffer", torch.randn(2, 2)) + self.linear_submodule = torch.nn.Linear(3, 5) + if sharded: + setattr(self, util._MODULE_SHARDED_FLAG, True) + + def forward(self, x): + # This doesn't matter, we're not going to actually use it. 
+ pass + + +def _dist_load_ok(global_rank, world_size, gpu_id): + model = DistributedFixtureModel() + state_dict = None if global_rank != 0 else model.state_dict() + missing_keys, unexpected_keys = util.load_state_dict_distributed(model, state_dict) + assert not missing_keys + assert not unexpected_keys + + +def _dist_load_with_errors(global_rank, world_size, gpu_id): + model = DistributedFixtureModel() + state_dict = None if global_rank != 0 else model.state_dict() + _missing_keys = [ + "direct_buffer", + "custom_submodule.linear_submodule.bias", + "custom_submodule.direct_param", + "custom_sharded_submodule.linear_submodule.bias", + "custom_sharded_submodule.direct_buffer", + ] + _unexpected_keys = [ + "not_a_parameter", + "custom_submodule.not_a_parameter", + "custom_submodule.linear.not_a_parameter", + "custom_sharded_submodule.not_a_parameter", + "custom_sharded_submodule.linear.not_a_parameter", + "not_even_submodule.not_a_parameter", + ] + if state_dict is not None: + for key in _missing_keys: + del state_dict[key] + for key in _unexpected_keys: + state_dict[key] = torch.randn(2, 2) + missing_keys, unexpected_keys = util.load_state_dict_distributed( + model, state_dict, strict=False + ) + if global_rank == 0: + assert set(missing_keys) == set(_missing_keys) + assert set(unexpected_keys) == set(_unexpected_keys) + + +@pytest.mark.parametrize("test_func", [_dist_load_ok, _dist_load_with_errors]) +def test_load_state_dict_distributed(test_func): + run_distributed_test([-1, -1], func=test_func)
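For reference, a minimal sketch of the `from_pretrained_module` calls these tests exercise (the model IDs and the `relevant_module` path are taken directly from the tests above; variable names are illustrative):

    from allennlp.modules.transformer import TransformerLayer, TransformerStack

    # Build a whole encoder stack from a pretrained model ID.
    stack = TransformerStack.from_pretrained_module("bert-base-cased").eval()

    # Build a single layer by pointing `relevant_module` at a submodule path.
    layer = TransformerLayer.from_pretrained_module(
        "bert-base-cased", relevant_module="bert.encoder.layer.0"
    ).eval()

    # Build a truncated stack; `strict=False` ignores the weights of the dropped layers.
    small_stack = TransformerStack.from_pretrained_module(
        "bert-base-cased", num_hidden_layers=3, strict=False
    )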