Enable batch mode for SAS cross encoders (#1987)
* enable batch mode for sas cross encoders

* fix mypy

* comment on top_1 values added
tstadel authored Jan 11, 2022
1 parent 9c3d9b4 commit c861fdb
Showing 1 changed file with 16 additions and 9 deletions.
haystack/nodes/evaluator/evaluator.py (25 changes: 16 additions & 9 deletions)
@@ -388,24 +388,31 @@ def semantic_answer_similarity(predictions: List[List[str]],
     # Compute similarities
     top_1_sas = []
     top_k_sas = []
+    lengths: List[Tuple[int,int]] = []
 
     # Based on Modelstring we can load either Bi-Encoders or Cross Encoders.
     # Similarity computation changes for both approaches
     if cross_encoder_used:
-        model = CrossEncoder(sas_model_name_or_path)
-        for preds, labels in zip (predictions,gold_labels):
-            # TODO add efficient batch mode: put all texts and labels into grid and extract scores afterwards
-            grid = []
+        model = CrossEncoder(sas_model_name_or_path)
+        grid = []
+        for preds, labels in zip (predictions,gold_labels):
             for p in preds:
                 for l in labels:
                     grid.append((p,l))
-            scores = model.predict(grid)
-            top_1_sas.append(np.max(scores[:len(labels)]))
-            top_k_sas.append(np.max(scores))
+            lengths.append((len(preds), len(labels)))
+        scores = model.predict(grid)
+
+        current_position = 0
+        for len_p, len_l in lengths:
+            scores_window = scores[current_position:current_position+len_p*len_l]
+            # Per predicted doc there are len_l entries comparing it to all len_l labels.
+            # So to only consider the first doc we have to take the first len_l entries
+            top_1_sas.append(np.max(scores_window[:len_l]))
+            top_k_sas.append(np.max(scores_window))
+            current_position += len_p*len_l
     else:
         # For Bi-encoders we can flatten predictions and labels into one list
         model = SentenceTransformer(sas_model_name_or_path)
-        lengths: List[Tuple[int,int]] = []
         all_texts: List[str] = []
         for p, l in zip(predictions, gold_labels):  # type: ignore
            # TODO potentially exclude (near) exact matches from computations
@@ -417,7 +424,7 @@ def semantic_answer_similarity(predictions: List[List[str]],
 
         # then select which embeddings will be used for similarity computations
         current_position = 0
-        for i, (len_p, len_l) in enumerate(lengths):
+        for len_p, len_l in lengths:
             pred_embeddings = embeddings[current_position:current_position + len_p, :]
             current_position += len_p
             label_embeddings = embeddings[current_position:current_position + len_l, :]
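For readers who want the batching idea outside the diff, here is a minimal, self-contained sketch. The score_pairs helper and the example predictions and gold labels are made up for illustration; in Haystack the grid of (prediction, label) pairs is scored by CrossEncoder.predict from sentence-transformers, as shown in the hunk above.

    # Minimal sketch of the batched SAS scoring introduced by this commit.
    # score_pairs is a toy stand-in for CrossEncoder.predict (illustration only).
    from typing import List, Tuple

    import numpy as np


    def score_pairs(pairs: List[Tuple[str, str]]) -> np.ndarray:
        # Dummy similarity: token overlap between the two texts.
        def sim(a: str, b: str) -> float:
            ta, tb = set(a.lower().split()), set(b.lower().split())
            return len(ta & tb) / max(len(ta | tb), 1)

        return np.array([sim(p, l) for p, l in pairs])


    predictions = [["Berlin", "Paris is the capital"], ["42"]]
    gold_labels = [["Paris", "the capital is Paris"], ["42", "forty-two"]]

    # 1) Flatten all (prediction, label) pairs into one grid and record the
    #    per-query shape so the flat score array can be sliced apart again.
    grid: List[Tuple[str, str]] = []
    lengths: List[Tuple[int, int]] = []
    for preds, labels in zip(predictions, gold_labels):
        for p in preds:
            for l in labels:
                grid.append((p, l))
        lengths.append((len(preds), len(labels)))

    # 2) One batched call instead of one call per query.
    scores = score_pairs(grid)

    # 3) Slice the flat scores back into per-query windows of size len_p * len_l.
    #    The first len_l entries of each window belong to the top-1 prediction.
    top_1_sas, top_k_sas = [], []
    pos = 0
    for len_p, len_l in lengths:
        window = scores[pos:pos + len_p * len_l]
        top_1_sas.append(float(np.max(window[:len_l])))
        top_k_sas.append(float(np.max(window)))
        pos += len_p * len_l

    print(top_1_sas, top_k_sas)

Recording lengths is what lets a single flat score array be split back into per-query windows of size len_p * len_l, so one predict call over the whole evaluation set replaces one call per query, which is generally cheaper because the model can batch tokenization and inference.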
