Skip to content

Commit

Permalink
Add support for positional args in pipeline.get_config() (#2478)
Browse files Browse the repository at this point in the history
* add support for positional args in pipeline.get_config()

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
tstadel and github-actions[bot] authored May 2, 2022
1 parent 7d6b3fe commit 509944f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
13 changes: 8 additions & 5 deletions haystack/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ def wrapper_exportable_to_yaml(self, *args, **kwargs):
# Call the actuall __init__ function with all the arguments
init_func(self, *args, **kwargs)

# Warn for unnamed input params - should be rare
if args:
logger.warning(
"Unnamed __init__ parameters will not be saved to YAML if Pipeline.save_to_yaml() is called!"
)
# Create the configuration dictionary if it doesn't exist yet
if not self._component_config:
self._component_config = {"params": {}, "type": type(self).__name__}
Expand All @@ -46,6 +41,14 @@ def wrapper_exportable_to_yaml(self, *args, **kwargs):
for k, v in kwargs.items():
self._component_config["params"][k] = v

# Store unnamed input parameters in self._component_config too by inferring their names
sig = inspect.signature(init_func)
parameter_names = list(sig.parameters.keys())
# we can be sure that the first one is always "self"
arg_names = parameter_names[1 : 1 + len(args)]
for arg, arg_name in zip(args, arg_names):
self._component_config["params"][arg_name] = arg

return wrapper_exportable_to_yaml


Expand Down
12 changes: 4 additions & 8 deletions test/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,21 +386,17 @@ def __init__(self, param: int):
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}


def test_get_config_custom_node_with_positional_params(caplog):
def test_get_config_custom_node_with_positional_params():
class CustomNode(MockNode):
def __init__(self, param: int = 1):
super().__init__()
self.param = param

pipeline = Pipeline()
with caplog.at_level(logging.WARNING):
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
assert (
"Unnamed __init__ parameters will not be saved to YAML "
"if Pipeline.save_to_yaml() is called" in caplog.text
)
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])

assert len(pipeline.get_config()["components"]) == 1
assert pipeline.get_config()["components"][0]["params"] == {}
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}


def test_generate_code_simple_pipeline():
Expand Down

0 comments on commit 509944f

Please sign in to comment.