Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

embeddings: fix attention mask for special Transformer architectures #2485

Merged
merged 1 commit into from
Oct 29, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions flair/embeddings/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,14 +1011,17 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
return_tensors='pt',
)

model_kwargs = {}
input_ids = batch_encoding['input_ids'].to(flair.device)
attention_mask = batch_encoding['attention_mask'].to(flair.device)

# Models such as FNet do not have an attention_mask
if 'attention_mask' in batch_encoding:
model_kwargs['attention_mask'] = batch_encoding['attention_mask'].to(flair.device)

# determine which sentence was split into how many parts
sentence_parts_lengths = torch.ones(len(tokenized_sentences), dtype=torch.int) if not self.allow_long_sentences \
else torch.unique(batch_encoding['overflow_to_sample_mapping'], return_counts=True, sorted=True)[1].tolist()

model_kwargs = {}
# set language IDs for XLM-style transformers
if self.use_lang_emb:
model_kwargs["langs"] = torch.zeros_like(input_ids, dtype=input_ids.dtype)
Expand All @@ -1029,7 +1032,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
model_kwargs["langs"][s_id][:sequence_length] = lang_id

# put encoded batch through transformer model to get all hidden states of all encoder layers
hidden_states = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)[-1]
hidden_states = self.model(input_ids, **model_kwargs)[-1]
# make the tuple a tensor; makes working with it easier.
hidden_states = torch.stack(hidden_states)

Expand Down