From c861fdb2ce85c6f63d31bf61c208fc50d97f9d97 Mon Sep 17 00:00:00 2001
From: tstadel <60758086+tstadel@users.noreply.github.com>
Date: Tue, 11 Jan 2022 17:54:43 +0100
Subject: [PATCH] Enable batch mode for SAS cross encoders (#1987)

* enable batch mode for sas cross encoders

* fix mypy

* add comment on top_1 values
---
 haystack/nodes/evaluator/evaluator.py | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/haystack/nodes/evaluator/evaluator.py b/haystack/nodes/evaluator/evaluator.py
index 5604cb9e1d..21533607c0 100644
--- a/haystack/nodes/evaluator/evaluator.py
+++ b/haystack/nodes/evaluator/evaluator.py
@@ -388,24 +388,31 @@ def semantic_answer_similarity(predictions: List[List[str]],
     # Compute similarities
     top_1_sas = []
     top_k_sas = []
+    lengths: List[Tuple[int, int]] = []
 
     # Based on Modelstring we can load either Bi-Encoders or Cross Encoders.
     # Similarity computation changes for both approaches
     if cross_encoder_used:
-        model = CrossEncoder(sas_model_name_or_path)
-        for preds, labels in zip (predictions,gold_labels):
-            # TODO add efficient batch mode: put all texts and labels into grid and extract scores afterwards
-            grid = []
+        model = CrossEncoder(sas_model_name_or_path)
+        grid = []
+        for preds, labels in zip(predictions, gold_labels):
             for p in preds:
                 for l in labels:
                     grid.append((p,l))
-            scores = model.predict(grid)
-            top_1_sas.append(np.max(scores[:len(labels)]))
-            top_k_sas.append(np.max(scores))
+            lengths.append((len(preds), len(labels)))
+        scores = model.predict(grid)
+
+        current_position = 0
+        for len_p, len_l in lengths:
+            scores_window = scores[current_position:current_position + len_p * len_l]
+            # Each prediction contributes len_l scores (one per label), laid out prediction-major,
+            # so taking the first len_l entries restricts the window to the top-1 prediction.
+            top_1_sas.append(np.max(scores_window[:len_l]))
+            top_k_sas.append(np.max(scores_window))
+            current_position += len_p * len_l
     else:
         # For Bi-encoders we can flatten predictions and labels into one list
         model = SentenceTransformer(sas_model_name_or_path)
-        lengths: List[Tuple[int,int]] = []
         all_texts: List[str] = []
         for p, l in zip(predictions, gold_labels):  # type: ignore
             # TODO potentially exclude (near) exact matches from computations
@@ -417,7 +424,7 @@ def semantic_answer_similarity(predictions: List[List[str]],
 
         # then select which embeddings will be used for similarity computations
         current_position = 0
-        for i, (len_p, len_l) in enumerate(lengths):
+        for len_p, len_l in lengths:
             pred_embeddings = embeddings[current_position:current_position + len_p, :]
             current_position += len_p
             label_embeddings = embeddings[current_position:current_position + len_l, :]
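
Note: the batched cross-encoder flow introduced above can be reproduced outside Haystack with the following minimal sketch. It assumes sentence-transformers is installed; the model name and the toy predictions/gold_labels are illustrative, not part of the patch.

from typing import List, Tuple

import numpy as np
from sentence_transformers import CrossEncoder

# Toy inputs: one list of predicted answers and one list of gold answers per query.
predictions = [["berlin", "the capital of germany"], ["42"]]
gold_labels = [["Berlin"], ["forty-two", "42"]]

model = CrossEncoder("cross-encoder/stsb-roberta-large")  # illustrative model choice

# Put every (prediction, label) pair into one flat grid so that a single
# predict() call can batch all forward passes.
grid: List[Tuple[str, str]] = []
lengths: List[Tuple[int, int]] = []
for preds, labels in zip(predictions, gold_labels):
    for p in preds:
        for l in labels:
            grid.append((p, l))
    lengths.append((len(preds), len(labels)))

scores = model.predict(grid)

# Slice the flat score array back per query: each query owns a window of
# len_p * len_l scores, laid out prediction-major.
top_1_sas, top_k_sas = [], []
position = 0
for len_p, len_l in lengths:
    window = scores[position:position + len_p * len_l]
    top_1_sas.append(np.max(window[:len_l]))  # first prediction vs. all labels
    top_k_sas.append(np.max(window))          # best score over all predictions
    position += len_p * len_l

The key point of the patch is that model.predict() is now called once on the whole grid instead of once per query, which lets the cross encoder batch its forward passes.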
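
For comparison, the bi-encoder branch touched by the second hunk batches the same way, only over embeddings instead of pair scores. The sketch below fills in the cosine-similarity step, which lies outside the visible hunk, with plain NumPy; the model name is again only illustrative.

import numpy as np
from sentence_transformers import SentenceTransformer

predictions = [["berlin", "the capital of germany"], ["42"]]
gold_labels = [["Berlin"], ["forty-two", "42"]]

model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# Flatten all predictions and labels into one list for a single encode() call,
# remembering how many of each belong to every query.
lengths, all_texts = [], []
for preds, labels in zip(predictions, gold_labels):
    all_texts.extend(preds)
    all_texts.extend(labels)
    lengths.append((len(preds), len(labels)))

embeddings = model.encode(all_texts)

top_1_sas, top_k_sas = [], []
position = 0
for len_p, len_l in lengths:
    pred_embeddings = embeddings[position:position + len_p, :]
    position += len_p
    label_embeddings = embeddings[position:position + len_l, :]
    position += len_l
    # Cosine similarity of every prediction against every label, shape (len_p, len_l).
    norms_p = np.linalg.norm(pred_embeddings, axis=1, keepdims=True)
    norms_l = np.linalg.norm(label_embeddings, axis=1, keepdims=True)
    sims = (pred_embeddings @ label_embeddings.T) / (norms_p * norms_l.T)
    top_1_sas.append(np.max(sims[0, :]))  # row 0 is the top-1 prediction
    top_k_sas.append(np.max(sims))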