Skip to content

Commit

Permalink
Merge pull request #3608 from flairNLP/add_compability_to_torch_2.6
Browse files Browse the repository at this point in the history
add compatibility to torch 2.6
  • Loading branch information
alanakbik authored Feb 1, 2025
2 parents 005ec45 + 087b74e commit e00e0ff
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 7 deletions.
3 changes: 1 addition & 2 deletions flair/embeddings/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,10 +691,9 @@ def _add_embeddings_internal(self, sentences: list[Sentence]):

lengths: list[int] = [len(sentence.tokens) for sentence in sentences]
padding_length: int = max(max(lengths), self.min_sequence_length)

pre_allocated_zero_tensor = torch.zeros(
self.embeddings.embedding_length * padding_length,
dtype=self.convs[0].weight.dtype,
dtype=cast(torch.nn.Conv1d, self.convs[0]).weight.dtype,
device=flair.device,
)

Expand Down
2 changes: 1 addition & 1 deletion flair/embeddings/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,7 +1466,7 @@ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
word = token.text if self.field is None else token.get_label(self.field).value

if word.strip() == "":
ids = [self.spm.vocab_size(), self.embedder.spm.vocab_size()]
ids = [self.spm.vocab_size(), self.spm.vocab_size()]
else:
if self.do_preproc:
word = self._preprocess(word)
Expand Down
2 changes: 1 addition & 1 deletion flair/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,4 @@ def load_torch_state(model_file: str) -> dict[str, typing.Any]:
# to load models on some Mac/Windows setups
# see /~https://github.com/zalandoresearch/flair/issues/351
f = load_big_file(model_file)
return torch.load(f, map_location="cpu")
return torch.load(f, map_location="cpu", weights_only=False)
4 changes: 2 additions & 2 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,8 +862,8 @@ def push_to_hub(
self.save(local_model_path)

# Determine if model card already exists
info = model_info(repo_id, use_auth_token=token)
write_readme = all(f.rfilename != "README.md" for f in info.siblings)
info = model_info(repo_id, token=token)
write_readme = info.siblings is None or all(f.rfilename != "README.md" for f in info.siblings)

# Generate and save model card
if write_readme:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ scikit-learn>=1.0.2
segtok>=1.5.11
sqlitedict>=2.0.0
tabulate>=0.8.10
torch>=1.5.0,!=1.8
torch>=1.13.1
tqdm>=4.63.0
transformer-smaller-training-vocab>=0.2.3
transformers[sentencepiece]>=4.25.0,<5.0.0
Expand Down

0 comments on commit e00e0ff

Please sign in to comment.