Skip to content

Commit

Permalink
Add Tapas reader with scores (#1997)
Browse files Browse the repository at this point in the history
* Add Tapas reader with scores

* Adapt possible answer spans

* Add latest docstring and tutorial changes

* Remove unused imports

* Adapt scoring

* Add latest docstring and tutorial changes

* Fix mypy

* Infer model architecture from config

* Adapt answer score calculation

* Add latest docstring and tutorial changes

* Fix mypy

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
bogdankostic and github-actions[bot] authored Jan 31, 2022
1 parent ee6b8d0 commit bbb65a1
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 50 deletions.
16 changes: 14 additions & 2 deletions docs/_src/api/api/reader.md
Original file line number Diff line number Diff line change
Expand Up @@ -630,26 +630,38 @@ answer = prediction["answers"][0].answer # "10 june 1996"
#### \_\_init\_\_

```python
| __init__(model_name_or_path: str = "google/tapas-base-finetuned-wtq", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, max_seq_len: int = 256)
| __init__(model_name_or_path: str = "google/tapas-base-finetuned-wtq", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, top_k_per_candidate: int = 3, return_no_answer: bool = False, max_seq_len: int = 256)
```

Load a TableQA model from Transformers.
Available models include:

- ``'google/tapas-base-finetuned-wtq`'``
- ``'google/tapas-base-finetuned-wikisql-supervised``'
- ``'deepset/tapas-large-nq-hn-reader'``
- ``'deepset/tapas-large-nq-reader'``

See https://huggingface.co/models?pipeline_tag=table-question-answering
for full list of available TableQA models.

The nq-reader models are able to provide confidence scores, but cannot handle questions that need aggregation
over multiple cells. The returned answers are sorted first by a general table score and then by answer span
scores.
All the other models can handle aggregation questions, but don't provide reasonable confidence scores.

**Arguments**:

- `model_name_or_path`: Directory of a saved model or the name of a public model e.g.
See https://huggingface.co/models?pipeline_tag=table-question-answering for full list of available models.
- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name,
or commit hash.
- `tokenizer`: Name of the tokenizer (usually the same as model)
- `use_gpu`: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
- `top_k`: The maximum number of answers to return
- `top_k_per_candidate`: How many answers to extract for each candidate table that is coming from
the retriever.
- `return_no_answer`: Whether to include no_answer predictions in the results.
(Only applicable with nq-reader models.)
- `max_seq_len`: Max sequence length of one input table for the model. If the number of tokens of
query + table exceed max_seq_len, the table will be truncated by removing rows until the
input size fits the model.
Expand Down
252 changes: 204 additions & 48 deletions haystack/nodes/reader/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import numpy as np
import pandas as pd
from quantulum3 import parser
from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, BatchEncoding
from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, \
BatchEncoding, TapasModel, TapasConfig
from transformers.models.tapas.modeling_tapas import TapasPreTrainedModel

from haystack.schema import Document, Answer, Span
from haystack.nodes.reader.base import BaseReader
Expand Down Expand Up @@ -49,6 +51,8 @@ def __init__(
tokenizer: Optional[str] = None,
use_gpu: bool = True,
top_k: int = 10,
top_k_per_candidate: int = 3,
return_no_answer: bool = False,
max_seq_len: int = 256,
):
"""
Expand All @@ -57,31 +61,54 @@ def __init__(
- ``'google/tapas-base-finetuned-wtq`'``
- ``'google/tapas-base-finetuned-wikisql-supervised``'
- ``'deepset/tapas-large-nq-hn-reader'``
- ``'deepset/tapas-large-nq-reader'``
See https://huggingface.co/models?pipeline_tag=table-question-answering
for full list of available TableQA models.
The nq-reader models are able to provide confidence scores, but cannot handle questions that need aggregation
over multiple cells. The returned answers are sorted first by a general table score and then by answer span
scores.
All the other models can handle aggregation questions, but don't provide reasonable confidence scores.
:param model_name_or_path: Directory of a saved model or the name of a public model e.g.
See https://huggingface.co/models?pipeline_tag=table-question-answering for full list of available models.
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name,
or commit hash.
:param tokenizer: Name of the tokenizer (usually the same as model)
:param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
:param top_k: The maximum number of answers to return
:param top_k_per_candidate: How many answers to extract for each candidate table that is coming from
the retriever.
:param return_no_answer: Whether to include no_answer predictions in the results.
(Only applicable with nq-reader models.)
:param max_seq_len: Max sequence length of one input table for the model. If the number of tokens of
query + table exceed max_seq_len, the table will be truncated by removing rows until the
input size fits the model.
"""
# Save init parameters to enable export of component config as YAML
self.set_config(model_name_or_path=model_name_or_path, model_version=model_version, tokenizer=tokenizer,
use_gpu=use_gpu, top_k=top_k, top_k_per_candidate=top_k_per_candidate,
return_no_answer=return_no_answer, max_seq_len=max_seq_len)

self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
self.model = TapasForQuestionAnswering.from_pretrained(model_name_or_path, revision=model_version)
config = TapasConfig.from_pretrained(model_name_or_path)
if config.architectures[0] == "TapasForScoredQA":
self.model = self.TapasForScoredQA.from_pretrained(model_name_or_path, revision=model_version)
else:
self.model = TapasForQuestionAnswering.from_pretrained(model_name_or_path, revision=model_version)
self.model.to(str(self.devices[0]))

if tokenizer is None:
self.tokenizer = TapasTokenizer.from_pretrained(model_name_or_path)
else:
self.tokenizer = TapasTokenizer.from_pretrained(tokenizer)

self.top_k = top_k
self.top_k_per_candidate = top_k_per_candidate
self.max_seq_len = max_seq_len
self.return_no_answers = False
self.return_no_answer = return_no_answer

def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
"""
Expand All @@ -102,6 +129,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
top_k = self.top_k

answers = []
no_answer_score = 1.0
for document in documents:
if document.content_type != "table":
logger.warning(f"Skipping document with id {document.id} in TableReader, as it is not of type table.")
Expand All @@ -115,68 +143,175 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
return_tensors="pt",
truncation=True)
inputs.to(self.devices[0])
# Forward query and table through model and convert logits to predictions
outputs = self.model(**inputs)
inputs.to("cpu")
if self.model.config.num_aggregation_labels > 0:
aggregation_logits = outputs.logits_aggregation.cpu().detach()
else:
aggregation_logits = None

predicted_output = self.tokenizer.convert_logits_to_predictions(
inputs,
outputs.logits.cpu().detach(),
aggregation_logits
)
if len(predicted_output) == 1:
predicted_answer_coordinates = predicted_output[0]
else:
predicted_answer_coordinates, predicted_aggregation_indices = predicted_output
if isinstance(self.model, TapasForQuestionAnswering):
current_answer = self._predict_tapas_for_qa(inputs, document)
answers.append(current_answer)
elif isinstance(self.model, self.TapasForScoredQA):
current_answers, current_no_answer_score = self._predict_tapas_for_scored_qa(inputs, document)
answers.extend(current_answers)
if current_no_answer_score < no_answer_score:
no_answer_score = current_no_answer_score

if self.return_no_answer and isinstance(self.model, self.TapasForScoredQA):
answers.append(Answer(
answer="",
type="extractive",
score=no_answer_score,
context=None,
offsets_in_context=[Span(start=0, end=0)],
offsets_in_document=[Span(start=0, end=0)],
document_id=None,
meta=None
))
answers = sorted(answers, reverse=True)
answers = answers[:top_k]

# Get cell values
current_answer_coordinates = predicted_answer_coordinates[0]
current_answer_cells = []
for coordinate in current_answer_coordinates:
current_answer_cells.append(table.iat[coordinate])
results = {"query": query,
"answers": answers}

# Get aggregation operator
if self.model.config.aggregation_labels is not None:
current_aggregation_operator = self.model.config.aggregation_labels[predicted_aggregation_indices[0]]
else:
current_aggregation_operator = "NONE"

# Calculate answer score
current_score = self._calculate_answer_score(outputs.logits.cpu().detach(), inputs, current_answer_coordinates)
return results

if current_aggregation_operator == "NONE":
answer_str = ", ".join(current_answer_cells)
else:
answer_str = self._aggregate_answers(current_aggregation_operator, current_answer_cells)
def _predict_tapas_for_qa(self, inputs: BatchEncoding, document: Document) -> Answer:
table: pd.DataFrame = document.content

# Forward query and table through model and convert logits to predictions
outputs = self.model(**inputs)
inputs.to("cpu")
if self.model.config.num_aggregation_labels > 0:
aggregation_logits = outputs.logits_aggregation.cpu().detach()
else:
aggregation_logits = None

predicted_output = self.tokenizer.convert_logits_to_predictions(
inputs,
outputs.logits.cpu().detach(),
aggregation_logits
)
if len(predicted_output) == 1:
predicted_answer_coordinates = predicted_output[0]
else:
predicted_answer_coordinates, predicted_aggregation_indices = predicted_output

# Get cell values
current_answer_coordinates = predicted_answer_coordinates[0]
current_answer_cells = []
for coordinate in current_answer_coordinates:
current_answer_cells.append(table.iat[coordinate])

# Get aggregation operator
if self.model.config.aggregation_labels is not None:
current_aggregation_operator = self.model.config.aggregation_labels[predicted_aggregation_indices[0]]
else:
current_aggregation_operator = "NONE"

# Calculate answer score
current_score = self._calculate_answer_score(outputs.logits.cpu().detach(), inputs, current_answer_coordinates)

answer_offsets = self._calculate_answer_offsets(current_answer_coordinates, table)
if current_aggregation_operator == "NONE":
answer_str = ", ".join(current_answer_cells)
else:
answer_str = self._aggregate_answers(current_aggregation_operator, current_answer_cells)

answer_offsets = self._calculate_answer_offsets(current_answer_coordinates, document.content)

answer = Answer(
answer=answer_str,
type="extractive",
score=current_score,
context=document.content,
offsets_in_document=answer_offsets,
offsets_in_context=answer_offsets,
document_id=document.id,
meta={"aggregation_operator": current_aggregation_operator,
"answer_cells": current_answer_cells}
)

return answer

def _predict_tapas_for_scored_qa(self, inputs: BatchEncoding, document: Document) -> Tuple[List[Answer], float]:
table: pd.DataFrame = document.content

# Forward pass through model
outputs = self.model.tapas(**inputs)

# Get general table score
table_score = self.model.classifier(outputs.pooler_output)
table_score_softmax = torch.nn.functional.softmax(table_score, dim=1)
table_relevancy_prob = table_score_softmax[0][1].item()

# Get possible answer spans
token_types = [
"segment_ids",
"column_ids",
"row_ids",
"prev_labels",
"column_ranks",
"inv_column_ranks",
"numeric_relations",
]
row_ids: List[int] = inputs.token_type_ids[:, :, token_types.index("row_ids")].tolist()[0]
column_ids: List[int] = inputs.token_type_ids[:, :, token_types.index("column_ids")].tolist()[0]

possible_answer_spans: List[Tuple[int, int, int, int]] = [] # List of tuples: (row_idx, col_idx, start_token, end_token)
current_start_idx = -1
current_column_id = -1
for idx, (row_id, column_id) in enumerate(zip(row_ids, column_ids)):
if row_id == 0 or column_id == 0:
continue
# Beginning of new cell
if column_id != current_column_id:
if current_start_idx != -1:
possible_answer_spans.append(
(row_ids[current_start_idx]-1, column_ids[current_start_idx]-1, current_start_idx, idx-1)
)
current_start_idx = idx
current_column_id = column_id
possible_answer_spans.append(
(row_ids[current_start_idx]-1, column_ids[current_start_idx]-1, current_start_idx, len(row_ids)-1)
)

# Concat logits of start token and end token of possible answer spans
sequence_output = outputs.last_hidden_state
concatenated_logits = []
for possible_span in possible_answer_spans:
start_token_logits = sequence_output[0, possible_span[2], :]
end_token_logits = sequence_output[0, possible_span[3], :]
concatenated_logits.append(torch.cat((start_token_logits, end_token_logits)))
concatenated_logit_tensors = torch.unsqueeze(torch.stack(concatenated_logits), dim=0)

# Calculate score for each possible span
span_logits = torch.einsum("bsj,j->bs", concatenated_logit_tensors, self.model.span_output_weights) \
+ self.model.span_output_bias
span_logits_softmax = torch.nn.functional.softmax(span_logits, dim=1)

top_k_answer_spans = torch.topk(span_logits[0], min(self.top_k_per_candidate, len(possible_answer_spans)))

answers = []
for answer_span_idx in top_k_answer_spans.indices:
current_answer_span = possible_answer_spans[answer_span_idx]
answer_str = table.iat[current_answer_span[:2]]
answer_offsets = self._calculate_answer_offsets([current_answer_span[:2]], document.content)
# As the general table score is more important for the final score, it is double weighted.
current_score = ((2 * table_relevancy_prob) + span_logits_softmax[0, answer_span_idx].item()) / 3

answers.append(
Answer(
answer=answer_str,
type="extractive",
score=current_score,
context=table,
context=document.content,
offsets_in_document=answer_offsets,
offsets_in_context=answer_offsets,
document_id=document.id,
meta={"aggregation_operator": current_aggregation_operator,
"answer_cells": current_answer_cells}
meta={"aggregation_operator": "NONE",
"answer_cells": table.iat[current_answer_span[:2]]}
)
)

# Sort answers by score and select top-k answers
answers = sorted(answers, reverse=True)
answers = answers[:top_k]

results = {"query": query,
"answers": answers}
no_answer_score = 1 - table_relevancy_prob

return results
return answers, no_answer_score

def _calculate_answer_score(self, logits: torch.Tensor, inputs: BatchEncoding,
answer_coordinates: List[Tuple[int, int]]) -> float:
Expand Down Expand Up @@ -253,6 +388,27 @@ def _calculate_answer_offsets(answer_coordinates: List[Tuple[int, int]], table:
def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
raise NotImplementedError("Batch prediction not yet available in TableReader.")

class TapasForScoredQA(TapasPreTrainedModel):

def __init__(self, config):
super().__init__(config)

# base model
self.tapas = TapasModel(config)

# dropout (only used when training)
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

# answer selection head
self.span_output_weights = torch.nn.Parameter(torch.zeros(2 * config.hidden_size))
self.span_output_bias = torch.nn.Parameter(torch.zeros([]))

# table scoring head
self.classifier = torch.nn.Linear(config.hidden_size, 2)

# Initialize weights
self.init_weights()


class RCIReader(BaseReader):
"""
Expand Down

0 comments on commit bbb65a1

Please sign in to comment.