Skip to content

Commit

Permalink
Dense Passage Retriever (Inference) (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
tholor authored Jun 30, 2020
1 parent 27b8c98 commit 07ecfb6
Show file tree
Hide file tree
Showing 24 changed files with 992 additions and 216 deletions.
11 changes: 8 additions & 3 deletions haystack/database/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import abstractmethod
from abc import abstractmethod, ABC
from typing import Any, Optional, Dict, List

from pydantic import BaseModel, Field
Expand All @@ -19,10 +19,11 @@ class Document(BaseModel):
tags: Optional[Dict[str, Any]] = Field(None, description="Tags that allow filtering of the data")


class BaseDocumentStore:
class BaseDocumentStore(ABC):
"""
Base class for implementing Document Stores.
"""
index: Optional[str]

@abstractmethod
def write_documents(self, documents: List[dict]):
Expand All @@ -45,5 +46,9 @@ def get_document_count(self) -> int:
pass

@abstractmethod
def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
def query_by_embedding(self,
                       query_emb: List[float],
                       filters: Optional[dict] = None,
                       top_k: int = 10,
                       index: Optional[str] = None) -> List[Document]:
    """
    Find the documents most similar to the supplied query embedding.

    Abstract: concrete DocumentStores must implement this.

    :param query_emb: Embedding vector of the query.
    :param filters: Optional filters to narrow down the candidate documents.
                    NOTE(review): filter semantics are defined by each subclass
                    (e.g. ES builds a terms filter per key) — confirm per store.
    :param top_k: Maximum number of documents to return.
    :param index: Index to search in; presumably falls back to the store's
                  default index when None — verify in each implementation.
    :return: Up to top_k Document objects.
    """
    pass
56 changes: 46 additions & 10 deletions haystack/database/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
name_field: str = "name",
external_source_id_field: str = "external_source_id",
embedding_field: Optional[str] = None,
embedding_dim: Optional[str] = None,
embedding_dim: Optional[int] = None,
custom_mapping: Optional[dict] = None,
excluded_meta_data: Optional[list] = None,
faq_question_field: Optional[str] = None,
Expand Down Expand Up @@ -127,7 +127,6 @@ def get_document_count(self) -> int:
def get_all_documents(self) -> List[Document]:
result = scan(self.client, query={"query": {"match_all": {}}}, index=self.index)
documents = [self._convert_es_hit_to_document(hit) for hit in result]

return documents

def query(
Expand Down Expand Up @@ -182,7 +181,14 @@ def query(
documents = [self._convert_es_hit_to_document(hit) for hit in result]
return documents

def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
def query_by_embedding(self,
query_emb: List[float],
filters: Optional[dict] = None,
top_k: int = 10,
index: Optional[str] = None) -> List[Document]:
if index is None:
index = self.index

if not self.embedding_field:
raise RuntimeError("Please specify arg `embedding_field` in ElasticsearchDocumentStore()")
else:
Expand All @@ -202,18 +208,21 @@ def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_
}
} # type: Dict[str,Any]

if candidate_doc_ids:
body["query"]["script_score"]["query"] = {
"bool": {
"should": [{"match_all": {}}],
"filter": [{"terms": {"_id": candidate_doc_ids}}]
}}
if filters:
filter_clause = []
for key, values in filters.items():
filter_clause.append(
{
"terms": {key: values}
}
)
body["query"]["bool"]["filter"] = filter_clause

if self.excluded_meta_data:
body["_source"] = {"excludes": self.excluded_meta_data}

logger.debug(f"Retriever query: {body}")
result = self.client.search(index=self.index, body=body)["hits"]["hits"]
result = self.client.search(index=index, body=body)["hits"]["hits"]

documents = [self._convert_es_hit_to_document(hit, score_adjustment=-1) for hit in result]
return documents
Expand All @@ -233,6 +242,33 @@ def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> D
)
return document

def update_embeddings(self, retriever):
    """
    Update the embeddings in the document store using the encoding model
    specified in the retriever.

    This can be useful if you want to add or change the embeddings for your
    documents (e.g. after changing the retriever config).

    :param retriever: Retriever instance; must expose an ``embed_passages(texts)``
                      method that returns one embedding per passage.
    :return: None
    """

    # Re-embed every document currently in the store.
    docs = self.get_all_documents()
    passages = [d.text for d in docs]
    logger.info(f"Updating embeddings for {len(passages)} docs ...")
    embeddings = retriever.embed_passages(passages)

    # The zip() below silently truncates on length mismatch, so fail loudly first.
    assert len(docs) == len(embeddings)

    # Build one partial-update action per document for the ES bulk helper.
    # NOTE(review): assumes self.embedding_field is configured and emb is a
    # numpy array (tolist() call) — confirm against retriever implementations.
    doc_updates = []
    for doc, emb in zip(docs, embeddings):
        update = {"_op_type": "update",
                  "_index": self.index,
                  "_id": doc.id,
                  "doc": {self.embedding_field: emb.tolist()},
                  }
        doc_updates.append(update)

    # Generous timeout: a full-corpus embedding update can be a large bulk request.
    bulk(self.client, doc_updates, request_timeout=300)

def add_eval_data(self, filename: str, doc_index: str = "eval_document", label_index: str = "feedback"):
"""
Adds a SQuAD-formatted file to the DocumentStore in order to be able to perform evaluation on it.
Expand Down
13 changes: 12 additions & 1 deletion haystack/database/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(self, embedding_field: Optional[str] = None):
self.docs = {} # type: Dict[str, Any]
self.doc_tags = {} # type: Dict[str, Any]
self.embedding_field = embedding_field
self.index = None

def write_documents(self, documents: List[dict]):
import hashlib
Expand Down Expand Up @@ -64,10 +65,20 @@ def _convert_memory_hit_to_document(self, hit: Tuple[Any, Any], doc_id: Optional
)
return document

def query_by_embedding(self, query_emb: List[float], top_k: int = 10, candidate_doc_ids: Optional[List[str]] = None) -> List[Document]:
def query_by_embedding(self,
query_emb: List[float],
filters: Optional[dict] = None,
top_k: int = 10,
index: Optional[str] = None) -> List[Document]:

from numpy import dot
from numpy.linalg import norm

if filters:
raise NotImplementedError("Setting `filters` is currently not supported in "
"InMemoryDocumentStore.query_by_embedding(). Please remove filters or "
"use a different DocumentStore (e.g. ElasticsearchDocumentStore).")

if self.embedding_field is None:
return []

Expand Down
10 changes: 10 additions & 0 deletions haystack/database/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,13 @@ def _convert_sql_row_to_document(self, row) -> DocumentSchema:
tags=row.tags
)
return document

def query_by_embedding(self,
                       query_emb: List[float],
                       filters: Optional[dict] = None,
                       top_k: int = 10,
                       index: Optional[str] = None) -> List[DocumentSchema]:
    """
    Embedding-based retrieval is not available for the SQL-backed store.

    Always raises NotImplementedError; switch to a DocumentStore with vector
    support (e.g. ElasticsearchDocumentStore) or use a non-embedding retriever.

    :raises NotImplementedError: unconditionally.
    """
    raise NotImplementedError("SQLDocumentStore is currently not supporting embedding queries. "
                              "Change the query type (e.g. by choosing a different retriever) "
                              "or change the DocumentStore (e.g. to ElasticsearchDocumentStore)")
18 changes: 3 additions & 15 deletions haystack/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,11 @@ def get_answers_via_similar_questions(self, question: str, top_k_retriever: int

results = {"question": question, "answers": []} # type: Dict[str, Any]

# 1) Optional: reduce the search space via document tags
if filters:
logging.info(f"Apply filters: {filters}")
candidate_doc_ids = self.retriever.document_store.get_document_ids_by_tags(filters) # type: ignore
logger.info(f"Got candidate IDs due to filters: {candidate_doc_ids}")

if len(candidate_doc_ids) == 0:
# We didn't find any doc matching the filters
return results
# 1) Apply retriever to match similar questions via cosine similarity of embeddings
documents = self.retriever.retrieve(question, top_k=top_k_retriever, filters=filters)

else:
candidate_doc_ids = None # type: ignore

# 2) Apply retriever to match similar questions via cosine similarity of embeddings
documents = self.retriever.retrieve(question, top_k=top_k_retriever, candidate_doc_ids=candidate_doc_ids) # type: ignore

# 3) Format response
# 2) Format response
for doc in documents:
#TODO proper calibratation of pseudo probabilities
cur_answer = {"question": doc.question, "answer": doc.text, "context": doc.text, # type: ignore
Expand Down
1 change: 1 addition & 0 deletions haystack/retriever/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


class BaseRetriever(ABC):

@abstractmethod
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
pass
Loading

0 comments on commit 07ecfb6

Please sign in to comment.