From c7e44fc2c52d67e56ba4b36fcea588bd23d633f2 Mon Sep 17 00:00:00 2001
From: Stefan Schweter
Date: Wed, 20 Oct 2021 20:48:03 +0200
Subject: [PATCH] embeddings: fix attention mask for special Transformer architectures

---
 flair/embeddings/token.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py
index 22fe29198a..4bf77453d8 100644
--- a/flair/embeddings/token.py
+++ b/flair/embeddings/token.py
@@ -1011,14 +1011,17 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
                 return_tensors='pt',
             )
 
+            model_kwargs = {}
             input_ids = batch_encoding['input_ids'].to(flair.device)
-            attention_mask = batch_encoding['attention_mask'].to(flair.device)
+
+            # Models such as FNet do not have an attention_mask
+            if 'attention_mask' in batch_encoding:
+                model_kwargs['attention_mask'] = batch_encoding['attention_mask'].to(flair.device)
 
             # determine which sentence was split into how many parts
             sentence_parts_lengths = torch.ones(len(tokenized_sentences), dtype=torch.int) if not self.allow_long_sentences \
                 else torch.unique(batch_encoding['overflow_to_sample_mapping'], return_counts=True, sorted=True)[1].tolist()
 
-            model_kwargs = {}
             # set language IDs for XLM-style transformers
             if self.use_lang_emb:
                 model_kwargs["langs"] = torch.zeros_like(input_ids, dtype=input_ids.dtype)
@@ -1029,7 +1032,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
                         model_kwargs["langs"][s_id][:sequence_length] = lang_id
 
             # put encoded batch through transformer model to get all hidden states of all encoder layers
-            hidden_states = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)[-1]
+            hidden_states = self.model(input_ids, **model_kwargs)[-1]
 
             # make the tuple a tensor; makes working with it easier.
             hidden_states = torch.stack(hidden_states)
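
Background (not part of the patch): a minimal standalone sketch of why the conditional lookup is needed, assuming the Hugging Face transformers AutoTokenizer/AutoModel API; the "google/fnet-base" checkpoint name is only an illustrative assumption. FNet replaces self-attention with Fourier mixing, so its tokenizer emits no 'attention_mask' key and its model forward accepts none. Building model_kwargs conditionally, as the patch does, keeps a single code path for both masked and mask-free architectures.

import torch
from transformers import AutoModel, AutoTokenizer

# illustrative checkpoint; FNet-style tokenizers produce no 'attention_mask'
tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = AutoModel.from_pretrained("google/fnet-base", output_hidden_states=True)

batch_encoding = tokenizer(["A sample sentence ."], return_tensors="pt")

model_kwargs = {}
# only forward the mask if the tokenizer actually produced one
if "attention_mask" in batch_encoding:
    model_kwargs["attention_mask"] = batch_encoding["attention_mask"]

with torch.no_grad():
    outputs = model(batch_encoding["input_ids"], **model_kwargs)

# stack all encoder layers: (num_layers + 1, batch, seq_len, hidden_dim)
hidden_states = torch.stack(outputs.hidden_states)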