Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in loading code from yaml #2705

Merged
merged 5 commits into from
Jun 24, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 11 additions & 18 deletions haystack/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
def exportable_to_yaml(init_func):
"""
Decorator that saves the init parameters of a node that later can
be used with exporting YAML configuration of a Pipeline.
be used with exporting YAML configuration of a Pipeline. We ensure
that only params passed to the __init__ function of the implementation
are saved, ignoring calls to the ancestors.
"""

@wraps(init_func)
Expand All @@ -32,12 +34,13 @@ def wrapper_exportable_to_yaml(self, *args, **kwargs):
if not self._component_config:
self._component_config = {"params": {}, "type": type(self).__name__}

# Make sure it runs only on the __init__of the implementations, not in superclasses
# NOTE: we use '.endswith' because inner classes's __qualname__ will include the parent class'
# name, like: ParentClass.InnerClass.__init__.
# Inner classes are heavily used in tests.
if init_func.__qualname__.endswith(f"{self.__class__.__name__}.{init_func.__name__}"):

# NOTE: inner classes constructor's __qualname__ will include the outer class' name,
# e.g. "OuterClass.InnerClass.__init__". We then take only the last two parts of the
# fully qualified name, in the previous example that would be "InnerClass.__init__"
name_components = init_func.__qualname__.split(".")
# Reconstruct the inner class' __qualname__ and compare with the __qualname__ of the implementation class.
# If the number of components is wrong, let the IndexError bubble up, there's nothing we can do anyways.
if f"{name_components[-2]}.{name_components[-1]}" == f"{self.__class__.__name__}.{init_func.__name__}":
# Store all the input parameters in self._component_config
args_as_kwargs = args_to_kwargs(args, init_func)
params = {**args_as_kwargs, **kwargs}
Expand Down Expand Up @@ -95,7 +98,7 @@ def type(self) -> str:
return self._component_config["type"]

def get_params(self, return_defaults: bool = False) -> Dict[str, Any]:
component_signature = self._get_signature()
component_signature = dict(inspect.signature(self.__class__).parameters)
params: Dict[str, Any] = {}
for key, value in self._component_config["params"].items():
if value != component_signature[key].default or return_defaults:
Expand Down Expand Up @@ -241,16 +244,6 @@ def _dispatch_run_general(self, run_method: Callable, **kwargs):
output["params"] = params
return output, stream

@classmethod
def _get_signature(cls) -> Dict[str, inspect.Parameter]:
component_classes = inspect.getmro(cls)
component_signature: Dict[str, inspect.Parameter] = {
param_key: parameter
for class_ in component_classes
for param_key, parameter in inspect.signature(class_).parameters.items()
}
return component_signature


class RootNode(BaseComponent):
"""
Expand Down