Skip to content

Commit

Permalink
Fix JoinAnswer/JoinNode (deepset-ai#2612)
Browse files Browse the repository at this point in the history
* fix join nodes

* Update Documentation & Code Style

* fix unused import

* change arg order

* Update Documentation & Code Style

* fix kwargs check

* add warning when there is only one input node

* Update Documentation & Code Style

* fix type hint

* fix wrong import order

* Update Documentation & Code Style

* undo kwargs

* add accidentally deleted newline#

* fix type hint

* fix type hint

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and andrch-FS committed Jul 26, 2022
1 parent 2a93b9a commit 5e434a5
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docs/_src/api/api/other.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ This ensures that your output is in a compatible format.
## JoinDocuments

```python
class JoinDocuments(BaseComponent)
class JoinDocuments(JoinNode)
```

A node to join documents outputted by multiple retriever nodes.
Expand Down Expand Up @@ -61,7 +61,7 @@ to each retriever score. This param is not compatible with the `concatenate` joi
## JoinAnswers

```python
class JoinAnswers(BaseComponent)
class JoinAnswers(JoinNode)
```

A node to join `Answer`s produced by multiple `Reader` nodes.
Expand Down
1 change: 1 addition & 0 deletions haystack/nodes/other/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from haystack.nodes.other.join_docs import JoinDocuments
from haystack.nodes.other.route_documents import RouteDocuments
from haystack.nodes.other.join_answers import JoinAnswers
from haystack.nodes.other.join import JoinNode
76 changes: 76 additions & 0 deletions haystack/nodes/other/join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from abc import abstractmethod
from typing import Optional, List, Tuple, Dict, Union, Any
import warnings

from haystack import MultiLabel, Document, Answer
from haystack.nodes.base import BaseComponent


class JoinNode(BaseComponent):
def run( # type: ignore
self,
inputs: Optional[List[dict]] = None,
query: Optional[str] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[MultiLabel] = None,
documents: Optional[List[Document]] = None,
meta: Optional[dict] = None,
answers: Optional[List[Answer]] = None,
top_k_join: Optional[int] = None,
) -> Tuple[Dict, str]:
if inputs:
return self.run_accumulated(inputs, top_k_join=top_k_join)
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
return self.run_accumulated(
inputs=[
{
"query": query,
"file_paths": file_paths,
"labels": labels,
"documents": documents,
"meta": meta,
"answers": answers,
}
],
top_k_join=top_k_join,
)

@abstractmethod
def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
pass

def run_batch( # type: ignore
self,
inputs: Optional[List[dict]] = None,
queries: Optional[Union[str, List[str]]] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
params: Optional[dict] = None,
debug: Optional[bool] = None,
answers: Optional[List[Answer]] = None,
top_k_join: Optional[int] = None,
) -> Tuple[Dict, str]:
if inputs:
return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join)
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.")
return self.run_batch_accumulated(
inputs=[
{
"queries": queries,
"file_paths": file_paths,
"labels": labels,
"documents": documents,
"meta": meta,
"params": params,
"debug": debug,
"answers": answers,
}
],
top_k_join=top_k_join
)

@abstractmethod
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]:
pass
8 changes: 4 additions & 4 deletions haystack/nodes/other/join_answers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Optional, List, Dict, Tuple

from haystack.schema import Answer
from haystack.nodes.base import BaseComponent
from haystack.nodes.other.join import JoinNode


class JoinAnswers(BaseComponent):
class JoinAnswers(JoinNode):
"""
A node to join `Answer`s produced by multiple `Reader` nodes.
"""
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
self.top_k_join = top_k_join
self.sort_by_score = sort_by_score

def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
def run_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
reader_results = [inp["answers"] for inp in inputs]

if not top_k_join:
Expand All @@ -61,7 +61,7 @@ def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dic
else:
raise ValueError(f"Invalid join_mode: {self.join_mode}")

def run_batch(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
def run_batch_accumulated(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
output_ans = []
incoming_edges = [inp["answers"] for inp in inputs]
# At each idx, we find predicted answers for the same query from different Readers
Expand Down
8 changes: 4 additions & 4 deletions haystack/nodes/other/join_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from typing import Optional, List

from haystack.schema import Document
from haystack.nodes.base import BaseComponent
from haystack.nodes.other.join import JoinNode


class JoinDocuments(BaseComponent):
class JoinDocuments(JoinNode):
"""
A node to join documents outputted by multiple retriever nodes.
Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
self.top_k_join = top_k_join

def run(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
results = [inp["documents"] for inp in inputs]
document_map = {doc.id: doc for result in results for doc in result}

Expand Down Expand Up @@ -77,7 +77,7 @@ def run(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ig

return output, "output_1"

def run_batch(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None): # type: ignore
# Join single document lists
if isinstance(inputs[0]["documents"][0], Document):
return self.run(inputs=inputs, top_k_join=top_k_join)
Expand Down

0 comments on commit 5e434a5

Please sign in to comment.