Add "no answer" aggregation to Transformersreader (#259)
* Add no answer aggregation

* Change to covariant type annotation

* Remove n_best_per_passage from transformersreader
Timoeller authored Aug 6, 2020
1 parent 89dcfed commit d9e8b52
Showing 4 changed files with 74 additions and 46 deletions.
31 changes: 30 additions & 1 deletion haystack/reader/base.py
@@ -1,5 +1,7 @@
+import numpy as np
+from scipy.special import expit
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Sequence

 from haystack.database.base import Document

@@ -9,3 +11,30 @@ class BaseReader(ABC):
     @abstractmethod
     def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
         pass
+
+    @staticmethod
+    def _calc_no_answer(no_ans_gaps: Sequence[float], best_score_answer: float):
+        # "no answer" scores and positive answer scores are difficult to compare, because
+        # + a positive answer score is related to one specific document
+        # - a "no answer" score is related to all input documents
+        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
+        # the most significant difference between scores.
+        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
+        # no_ans_gaps is a list of this most significant difference, one entry per document
+        no_ans_gaps = np.array(no_ans_gaps)
+        max_no_ans_gap = np.max(no_ans_gaps)
+        # all passages have "no answer" as their top score
+        if np.sum(no_ans_gaps < 0) == len(no_ans_gaps):  # type: ignore
+            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so this increases the best positive score
+        else:  # case: at least one passage predicts an answer (positive no_ans_gap)
+            no_ans_score = best_score_answer - max_no_ans_gap
+
+        no_ans_prediction = {"answer": None,
+                             "score": no_ans_score,
+                             "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
+                             "context": None,
+                             "offset_start": 0,
+                             "offset_end": 0,
+                             "document_id": None,
+                             "meta": None}
+        return no_ans_prediction, max_no_ans_gap
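To see the aggregation in action, here is a minimal standalone sketch of the same scoring rule; the gap values and the best answer score are made-up numbers for illustration:

import numpy as np
from scipy.special import expit

# Made-up per-document gaps: positive where a document favors a real answer
# over "no answer", negative where "no answer" wins.
no_ans_gaps = [-2.5, -1.0, 3.0]
best_score_answer = 7.0  # illustrative best positive answer score

# Both branches of _calc_no_answer evaluate the same expression:
max_no_ans_gap = float(np.max(np.array(no_ans_gaps)))
no_ans_score = best_score_answer - max_no_ans_gap
probability = float(expit(no_ans_score / 8))  # pseudo probability, as in the diff

print(no_ans_score, round(probability, 2))  # 4.0 0.62

If every gap were negative (all documents preferring "no answer"), subtracting the negative maximum would push no_ans_score above best_score_answer, so the aggregated "no answer" entry would outrank every positive answer after sorting.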
28 changes: 1 addition & 27 deletions haystack/reader/farm.py
@@ -40,7 +40,7 @@ def __init__(
         context_window_size: int = 150,
         batch_size: int = 50,
         use_gpu: bool = True,
-        no_ans_boost: Optional[int] = None,
+        no_ans_boost: Optional[float] = None,
         top_k_per_candidate: int = 3,
         top_k_per_sample: int = 1,
         num_processes: Optional[int] = None,
@@ -446,32 +446,6 @@ def _check_no_answer(c: QACandidate):
         return False


-    @staticmethod
-    def _calc_no_answer(no_ans_gaps: List[float], best_score_answer: float):
-        # "no answer" scores and positive answer scores are difficult to compare, because
-        # + a positive answer score is related to one specific document
-        # - a "no answer" score is related to all input documents
-        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
-        # the most significant difference between scores.
-        # Most significant difference: a model switching from predicting an answer to "no answer" (or vice versa).
-        # no_ans_gap coming from FARM means how much no_ans_boost would have to change to switch predictions
-        no_ans_gaps = np.array(no_ans_gaps)
-        max_no_ans_gap = np.max(no_ans_gaps)
-        # all passages have "no answer" as their top score
-        if np.sum(no_ans_gaps < 0) == len(no_ans_gaps):  # type: ignore
-            no_ans_score = best_score_answer - max_no_ans_gap  # max_no_ans_gap is negative, so this increases the best positive score
-        else:  # case: at least one passage predicts an answer (positive no_ans_gap)
-            no_ans_score = best_score_answer - max_no_ans_gap
-
-        no_ans_prediction = {"answer": None,
-                             "score": no_ans_score,
-                             "probability": float(expit(np.asarray(no_ans_score) / 8)),  # just a pseudo prob for now
-                             "context": None,
-                             "offset_start": 0,
-                             "offset_end": 0,
-                             "document_id": None}
-        return no_ans_prediction, max_no_ans_gap
-
     def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
         documents = []
         for text in texts:
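Since no_ans_boost is now annotated as Optional[float], fractional boosts type-check cleanly. A minimal usage sketch, assuming a public SQuAD2 model; parameter names outside this hunk (e.g. model_name_or_path) are assumptions about FARMReader's signature at the time:

from haystack.reader.farm import FARMReader

# A negative boost biases predictions away from "no answer",
# a positive one toward it; floats such as -0.5 are now valid.
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2",
                    no_ans_boost=-0.5,
                    top_k_per_candidate=3,
                    use_gpu=False)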
59 changes: 42 additions & 17 deletions haystack/reader/transformers.py
@@ -22,7 +22,7 @@ def __init__(
         tokenizer: str = "distilbert-base-uncased",
         context_window_size: int = 30,
         use_gpu: int = 0,
-        n_best_per_passage: int = 2,
+        top_k_per_candidate: int = 4,
         no_answer: bool = True
     ):
         """
@@ -40,14 +40,17 @@
         The context usually helps users to understand if the answer really makes sense.
         :param use_gpu: < 0 -> use cpu
                         >= 0 -> ordinal of the gpu to use
-        :param n_best_per_passage: num of best answers to take into account for each passage
+        :param top_k_per_candidate: How many answers to extract for each candidate doc that is coming from the retriever (might be a long text).
+                                    Note: - This is not the number of "final answers" you will receive
+                                            (see `top_k` in TransformersReader.predict() or Finder.get_answers() for that)
+                                          - Can include no_answer in the sorted list of predictions
         :param no_answer: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
                           False -> otherwise
         """
         self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
         self.context_window_size = context_window_size
-        self.n_best_per_passage = n_best_per_passage
+        self.top_k_per_candidate = top_k_per_candidate
         self.no_answer = no_answer

         # TODO context_window_size behaviour different from behavior in FARMReader
@@ -80,27 +83,49 @@ def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
"""
# get top-answers for each candidate passage
answers = []
no_ans_gaps = []
best_overall_score = 0
for doc in documents:
query = {"context": doc.text, "question": question}
predictions = self.model(query, topk=self.n_best_per_passage,handle_impossible_answer=self.no_answer)
predictions = self.model(query, topk=self.top_k_per_candidate, handle_impossible_answer=self.no_answer)
# for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
if type(predictions) == dict:
predictions = [predictions]
# assemble and format all answers
for pred in predictions:
context_start = max(0, pred["start"] - self.context_window_size)
context_end = min(len(doc.text), pred["end"] + self.context_window_size)
answers.append({
"answer": pred["answer"],
"context": doc.text[context_start:context_end],
"offset_start": pred["start"],
"offset_end": pred["end"],
"probability": pred["score"],
"score": None,
"document_id": doc.id,
"meta": doc.meta
})

best_doc_score = 0
# because we cannot ensure a "no answer" prediction coming back from transformers we initialize it here with 0
no_ans_doc_score = 0
# TODO add no answer bias on haystack side after getting "no answer" scores from transformers
for pred in predictions:
if pred["answer"]:
if pred["score"] > best_doc_score:
best_doc_score = pred["score"]
context_start = max(0, pred["start"] - self.context_window_size)
context_end = min(len(doc.text), pred["end"] + self.context_window_size)
answers.append({
"answer": pred["answer"],
"context": doc.text[context_start:context_end],
"offset_start": pred["start"],
"offset_end": pred["end"],
"probability": pred["score"],
"score": None,
"document_id": doc.id,
"meta": doc.meta
})
else:
no_ans_doc_score = pred["score"]

if best_doc_score > best_overall_score:
best_overall_score = best_doc_score

no_ans_gaps.append(no_ans_doc_score - best_doc_score)

# Calculate the score for predicting "no answer", relative to our best positive answer score
no_ans_prediction, max_no_ans_gap = self._calc_no_answer(no_ans_gaps, best_overall_score)

if self.no_answer:
answers.append(no_ans_prediction)
# sort answers by their `probability` and select top-k
answers = sorted(
answers, key=lambda k: k["probability"], reverse=True
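Taken together, predict() now records the best positive score overall, collects one no-answer gap per document, and appends the aggregated no-answer candidate before sorting. A usage sketch, assuming a SQuAD2-style model, that Document accepts id and text keyword arguments, and that predict() returns a dict with an "answers" list (all assumptions based on the surrounding code):

from haystack.database.base import Document
from haystack.reader.transformers import TransformersReader

reader = TransformersReader(model="deepset/roberta-base-squad2",
                            tokenizer="deepset/roberta-base-squad2",
                            use_gpu=-1,  # CPU
                            top_k_per_candidate=4,
                            no_answer=True)

docs = [Document(id="1", text="Berlin is the capital of Germany.")]  # hypothetical constructor call
result = reader.predict(question="What is the capital of France?", documents=docs, top_k=3)

# With no_answer=True, one entry may be the aggregated prediction
# {"answer": None, ...} produced by BaseReader._calc_no_answer.
for answer in result["answers"]:
    print(answer["answer"], answer["probability"])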
2 changes: 1 addition & 1 deletion test/conftest.py
@@ -85,7 +85,7 @@ def no_answer_reader(request):
     if request.param == "transformers":
         return TransformersReader(model="deepset/roberta-base-squad2",
                                   tokenizer="deepset/roberta-base-squad2",
-                                  use_gpu=-1, n_best_per_passage=5)
+                                  use_gpu=-1, top_k_per_candidate=5)


 @pytest.fixture()
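For completeness, a sketch of how this fixture might be exercised via pytest's indirect parametrization; the test name, document, and assertion are illustrative:

import pytest
from haystack.database.base import Document

@pytest.mark.parametrize("no_answer_reader", ["transformers"], indirect=True)
def test_no_answer_prediction(no_answer_reader):
    docs = [Document(id="1", text="The capital of Germany is Berlin.")]  # hypothetical constructor call
    prediction = no_answer_reader.predict(question="What is the capital of Mars?",
                                          documents=docs, top_k=5)
    # The aggregated "no answer" entry should appear among the candidates
    assert any(answer["answer"] is None for answer in prediction["answers"])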
