Skip to content

Commit

Permalink
Prevent losing names of utilized components when loaded from config (#2525)
Browse files Browse the repository at this point in the history

* Prevent losing names of utilized components when loaded from config

* Update Documentation & Code Style

* update test

* fix failing tests

* Update Documentation & Code Style

* fix even more tests

* Update Documentation & Code Style

* incorporate review feedback

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
tstadel and github-actions[bot] authored May 18, 2022
1 parent 110b9c2 commit f6e3a63
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 14 deletions.
12 changes: 12 additions & 0 deletions docs/_src/api/api/pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@ def root_node() -> Optional[str]

Returns the root node of the pipeline's graph.

<a id="base.Pipeline.components"></a>

#### Pipeline.components

```python
@property
def components() -> Dict[str, BaseComponent]
```

Returns all components used by this pipeline.
Note that this also includes components that are only utilized by other components and are not used directly as pipeline nodes.

<a id="base.Pipeline.to_code"></a>

#### Pipeline.to\_code
Expand Down
4 changes: 3 additions & 1 deletion haystack/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,17 @@ def get_subclass(cls, component_type: str) -> Type[BaseComponent]:
return subclass

@classmethod
def _create_instance(cls, component_type: str, component_params: Dict[str, Any]):
def _create_instance(cls, component_type: str, component_params: Dict[str, Any], name: Optional[str] = None):
"""
Returns an instance of the given subclass of BaseComponent.
:param component_type: name of the component class to load.
:param component_params: parameters to pass to the __init__() for the component.
:param name: name of the component instance
"""
subclass = cls.get_subclass(component_type)
instance = subclass(**component_params)
instance.name = name
return instance

@abstractmethod
Expand Down
46 changes: 36 additions & 10 deletions haystack/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,31 @@ def root_node(self) -> Optional[str]:

@property
def components(self) -> Dict[str, BaseComponent]:
return {
name: attributes["component"]
for name, attributes in self.graph.nodes.items()
if not isinstance(attributes["component"], RootNode)
}
"""
Returns all components used by this pipeline.
Note that this also includes such components that are being utilized by other components only and are not being used as a pipeline node directly.
"""
all_components = self._find_all_components()
return {component.name: component for component in all_components if component.name is not None}

def _find_all_components(self, seed_components: List[BaseComponent] = None) -> Set[BaseComponent]:
"""
Finds all components given the provided seed components.
Components are found by traversing the provided seed components and their utilized components.
If seed_components is None, the node components (except the root node) of the pipeline will be used as seed components.
"""
if seed_components is None:
seed_components = [
attributes["component"]
for attributes in self.graph.nodes.values()
if not isinstance(attributes["component"], RootNode)
]

distinct_components = set(seed_components)
for component in seed_components:
sub_components = self._find_all_components(component.utilized_components)
distinct_components.update(sub_components)
return distinct_components

def to_code(
self, pipeline_variable_name: str = "pipeline", generate_imports: bool = True, add_comment: bool = False
Expand Down Expand Up @@ -373,11 +393,15 @@ def add_node(self, component: BaseComponent, name: str, inputs: List[str]):
)
self.graph = _init_pipeline_graph(root_node_name=candidate_roots[0])

component_definitions = get_component_definitions(pipeline_config=self.get_config())

# Check for duplicates before adding the definition
if name in component_definitions.keys():
# Check for duplicate names before adding the component
# Note that the very same component must be addable multiple times:
# E.g. for indexing pipelines it's common to add a retriever first and a document store afterwards.
# The document store is already being used by the retriever however.
# Thus the very same document store will be added twice, first as a subcomponent of the retriever and second as a first level node.
if name in self.components.keys() and self.components[name] != component:
raise PipelineConfigError(f"A node named '{name}' is already in the pipeline. Choose another name.")

component_definitions = get_component_definitions(pipeline_config=self.get_config())
component_definitions[name] = component._component_config

# Name any nested component before adding them
Expand Down Expand Up @@ -1411,7 +1435,9 @@ def _load_or_get_component(cls, name: str, definitions: dict, components: dict):
value
] # substitute reference (string) with the component object.

component_instance = BaseComponent._create_instance(component_type, component_params)
component_instance = BaseComponent._create_instance(
component_type=component_type, component_params=component_params, name=name
)
components[name] = component_instance
return component_instance

Expand Down
2 changes: 1 addition & 1 deletion haystack/pipelines/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,6 @@ def load_from_pipeline_config(pipeline_config: dict, component_name: str):
component_params[key] = _RayDeploymentWrapper.load_from_pipeline_config(pipeline_config, value)

component_instance = BaseComponent._create_instance(
component_type=component_config["type"], component_params=component_params
component_type=component_config["type"], component_params=component_params, name=component_name
)
return component_instance
12 changes: 10 additions & 2 deletions haystack/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def generate_code(
component_definitions=component_definitions,
component_variable_names=component_variable_names,
dependency_graph=component_dependency_graph,
pipeline_definition=pipeline_definition,
)
pipeline_code = _generate_pipeline_code(
pipeline_definition=pipeline_definition,
Expand Down Expand Up @@ -109,7 +110,10 @@ def _generate_pipeline_code(


def _generate_components_code(
component_definitions: Dict[str, Any], component_variable_names: Dict[str, str], dependency_graph: DiGraph
component_definitions: Dict[str, Any],
component_variable_names: Dict[str, str],
dependency_graph: DiGraph,
pipeline_definition: Dict[str, Any],
) -> str:
code = ""
declarations = {}
Expand All @@ -121,7 +125,11 @@ def _generate_components_code(
for key, value in definition.get("params", {}).items()
}
init_args = ", ".join(f"{key}={value}" for key, value in param_value_dict.items())
declarations[name] = f"{variable_name} = {class_name}({init_args})"
declaration = f"{variable_name} = {class_name}({init_args})"
# set name of subcomponents explicitly if it's not the default name as it won't be set via Pipeline.add_node()
if name != class_name and name not in (node["name"] for node in pipeline_definition["nodes"]):
declaration = f'{declaration}\n{variable_name}.name = "{name}"'
declarations[name] = declaration

ordered_components = nx.topological_sort(dependency_graph)
ordered_declarations = [declarations[component] for component in ordered_components]
Expand Down
4 changes: 4 additions & 0 deletions test/pipelines/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def test_generate_code_imports():
"from haystack.pipelines import Pipeline\n"
"\n"
"document_store = ElasticsearchDocumentStore()\n"
'document_store.name = "DocumentStore"\n'
"retri = BM25Retriever(document_store=document_store)\n"
"retri_2 = TfidfRetriever(document_store=document_store)\n"
"\n"
Expand Down Expand Up @@ -497,6 +498,7 @@ def test_generate_code_imports_no_pipeline_cls():
"from haystack.nodes import BM25Retriever\n"
"\n"
"document_store = ElasticsearchDocumentStore()\n"
'document_store.name = "DocumentStore"\n'
"retri = BM25Retriever(document_store=document_store)\n"
"\n"
"p = Pipeline()\n"
Expand Down Expand Up @@ -524,6 +526,7 @@ def test_generate_code_comment():
"from haystack.pipelines import Pipeline\n"
"\n"
"document_store = ElasticsearchDocumentStore()\n"
'document_store.name = "DocumentStore"\n'
"retri = BM25Retriever(document_store=document_store)\n"
"\n"
"p = Pipeline()\n"
Expand Down Expand Up @@ -717,6 +720,7 @@ def test_load_from_deepset_cloud_query():
assert isinstance(retriever, BM25Retriever)
assert isinstance(document_store, DeepsetCloudDocumentStore)
assert document_store == query_pipeline.get_document_store()
assert document_store.name == "DocumentStore"

prediction = query_pipeline.run(query="man on horse", params={})

Expand Down
3 changes: 3 additions & 0 deletions test/pipelines/test_pipeline_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,9 @@ def __init__(self, other_node: OtherNode):
)
pipeline = Pipeline.load_from_yaml(path=tmp_path / "tmp_config.yml")
assert isinstance(pipeline.get_node("custom_node"), CustomNode)
assert isinstance(pipeline.get_node("custom_node").other_node, OtherNode)
assert pipeline.get_node("custom_node").name == "custom_node"
assert pipeline.get_node("custom_node").other_node.name == "other_node"


def test_load_yaml_custom_component_with_helper_class_in_init(tmp_path):
Expand Down

0 comments on commit f6e3a63

Please sign in to comment.