forked from deepset-ai/haystack
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix JoinAnswer/JoinNode (deepset-ai#2612)
* fix join nodes * Update Documentation & Code Style * fix unused import * change arg order * Update Documentation & Code Style * fix kwargs check * add warning when there is only one input node * Update Documentation & Code Style * fix type hint * fix wrong import order * Update Documentation & Code Style * undo kwargs * add accidentally deleted newline# * fix type hint * fix type hint Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Loading branch information
1 parent
2a93b9a
commit 5e434a5
Showing
5 changed files
with
87 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from abc import abstractmethod | ||
from typing import Optional, List, Tuple, Dict, Union, Any | ||
import warnings | ||
|
||
from haystack import MultiLabel, Document, Answer | ||
from haystack.nodes.base import BaseComponent | ||
|
||
|
||
class JoinNode(BaseComponent): | ||
def run( # type: ignore | ||
self, | ||
inputs: Optional[List[dict]] = None, | ||
query: Optional[str] = None, | ||
file_paths: Optional[List[str]] = None, | ||
labels: Optional[MultiLabel] = None, | ||
documents: Optional[List[Document]] = None, | ||
meta: Optional[dict] = None, | ||
answers: Optional[List[Answer]] = None, | ||
top_k_join: Optional[int] = None, | ||
) -> Tuple[Dict, str]: | ||
if inputs: | ||
return self.run_accumulated(inputs, top_k_join=top_k_join) | ||
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.") | ||
return self.run_accumulated( | ||
inputs=[ | ||
{ | ||
"query": query, | ||
"file_paths": file_paths, | ||
"labels": labels, | ||
"documents": documents, | ||
"meta": meta, | ||
"answers": answers, | ||
} | ||
], | ||
top_k_join=top_k_join, | ||
) | ||
|
||
@abstractmethod | ||
def run_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: | ||
pass | ||
|
||
def run_batch( # type: ignore | ||
self, | ||
inputs: Optional[List[dict]] = None, | ||
queries: Optional[Union[str, List[str]]] = None, | ||
file_paths: Optional[List[str]] = None, | ||
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None, | ||
documents: Optional[Union[List[Document], List[List[Document]]]] = None, | ||
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, | ||
params: Optional[dict] = None, | ||
debug: Optional[bool] = None, | ||
answers: Optional[List[Answer]] = None, | ||
top_k_join: Optional[int] = None, | ||
) -> Tuple[Dict, str]: | ||
if inputs: | ||
return self.run_batch_accumulated(inputs=inputs, top_k_join=top_k_join) | ||
warnings.warn("You are using a JoinNode with only one input. This is usually equivalent to a no-op.") | ||
return self.run_batch_accumulated( | ||
inputs=[ | ||
{ | ||
"queries": queries, | ||
"file_paths": file_paths, | ||
"labels": labels, | ||
"documents": documents, | ||
"meta": meta, | ||
"params": params, | ||
"debug": debug, | ||
"answers": answers, | ||
} | ||
], | ||
top_k_join=top_k_join | ||
) | ||
|
||
@abstractmethod | ||
def run_batch_accumulated(self, inputs: List[dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters