Merge pull request #2975 from flairNLP/profiler-optimizations
Small speed optimizations
alanakbik authored Oct 30, 2022
2 parents d099d60 + 7d0482e commit 56ac673
Showing 2 changed files with 23 additions and 19 deletions.
31 changes: 16 additions & 15 deletions flair/data.py
@@ -280,12 +280,19 @@ def set_embedding(self, name: str, vector: torch.Tensor):
         self._embeddings[name] = vector
 
     def get_embedding(self, names: Optional[List[str]] = None) -> torch.Tensor:
-        embeddings = self.get_each_embedding(names)
+        # if one embedding name, directly return it
+        if names and len(names) == 1:
+            if names[0] in self._embeddings:
+                return self._embeddings[names[0]].to(flair.device)
+            else:
+                return torch.tensor([], device=flair.device)
 
+        # if multiple embedding names, concatenate them
+        embeddings = self.get_each_embedding(names)
         if embeddings:
             return torch.cat(embeddings, dim=0)
-
-        return torch.tensor([], device=flair.device)
+        else:
+            return torch.tensor([], device=flair.device)
 
     def get_each_embedding(self, embedding_names: Optional[List[str]] = None) -> List[torch.Tensor]:
         embeddings = []
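A minimal usage sketch of the single-name fast path above (the embedding name "demo" and the tensor values are made up for illustration; it assumes the set_embedding/get_embedding API shown in this hunk):

    import torch
    from flair.data import Sentence

    sentence = Sentence("hello world")
    token = sentence[0]
    token.set_embedding("demo", torch.zeros(4))

    # with exactly one requested name, the stored tensor is returned directly,
    # skipping get_each_embedding() and the torch.cat call
    single = token.get_embedding(["demo"])

    # with several names (or none), the concatenation path is used as before
    stacked = token.get_embedding()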
@@ -882,7 +889,7 @@ def to_tagged_string(self, main_label=None) -> str:
 
     @property
     def text(self):
-        return "".join([t.text + t.whitespace_after * " " for t in self.tokens])
+        return self.to_original_text()
 
     def to_tokenized_string(self) -> str:
 
@@ -932,17 +939,11 @@ def infer_space_after(self):
         return self
 
     def to_original_text(self) -> str:
-        str = ""
-        pos = 0
-        for t in self.tokens:
-            while t.start_pos > pos:
-                str += " "
-                pos += 1
-
-            str += t.text
-            pos += len(t.text)
-
-        return str
+        # if sentence has no tokens, return empty string
+        if len(self) == 0:
+            return ""
+        # otherwise, return concatenation of tokens with the correct offsets
+        return self[0].start_pos * " " + "".join([t.text + t.whitespace_after * " " for t in self.tokens]).strip()
 
     def to_dict(self, tag_type: str = None):
         labels = []
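A short usage sketch of the Sentence changes in this file (the example string is illustrative): Sentence.text now delegates to to_original_text(), which rebuilds the original string from the first token's offset and each token's whitespace_after count instead of walking character positions one by one, and the new early return covers the no-token case.

    from flair.data import Sentence

    sentence = Sentence("The grass is green .")

    # both calls reconstruct the untokenized input from token offsets
    # and whitespace_after counts
    print(sentence.to_original_text())  # "The grass is green ."
    print(sentence.text)                # same string, via the new delegation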
11 changes: 7 additions & 4 deletions flair/embeddings/token.py
@@ -77,6 +77,8 @@ def __init__(self, embeddings: List[TokenEmbeddings]):
             self.add_module(f"list_embedding_{str(i)}", embedding)
 
         self.name: str = "Stack"
+        self.__names = [name for embedding in self.embeddings for name in embedding.get_names()]
+
         self.static_embeddings: bool = True
 
         self.__embedding_type: str = embeddings[0].embedding_type
@@ -115,10 +117,11 @@ def get_names(self) -> List[str]:
         """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of
         this embedding. But in some cases, the embedding is made up by different embeddings (StackedEmbedding).
         Then, the list contains the names of all embeddings in the stack."""
-        names = []
-        for embedding in self.embeddings:
-            names.extend(embedding.get_names())
-        return names
+        # make compatible with serialized models
+        if "__names" not in self.__dict__:
+            self.__names = [name for embedding in self.embeddings for name in embedding.get_names()]
+
+        return self.__names
 
     def get_named_embeddings_dict(self) -> Dict:
 
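A hedged usage sketch of the StackedEmbeddings change (the embedding choices are illustrative and download vectors on first use): the flattened name list is now built once in __init__ and reused, while the __dict__ check recomputes it for stacks loaded from models serialized before the attribute existed.

    from flair.data import Sentence
    from flair.embeddings import StackedEmbeddings, WordEmbeddings

    stack = StackedEmbeddings([WordEmbeddings("glove"), WordEmbeddings("extvec")])

    # get_names() is called for every embedded sentence; reusing the cached
    # list avoids rebuilding it on each call
    print(stack.get_names())

    sentence = Sentence("The grass is green .")
    stack.embed(sentence)
    print(sentence[0].get_embedding(stack.get_names()).shape)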
