diff --git a/pyiron_workflow/composite.py b/pyiron_workflow/composite.py index 283bd1f4..3c280097 100644 --- a/pyiron_workflow/composite.py +++ b/pyiron_workflow/composite.py @@ -104,6 +104,7 @@ def __init__( label: str, *args, parent: Optional[Composite] = None, + run_after_init: bool = False, strict_naming: bool = True, inputs_map: Optional[dict | bidict] = None, outputs_map: Optional[dict | bidict] = None, diff --git a/pyiron_workflow/function.py b/pyiron_workflow/function.py index e4a36f92..1270f285 100644 --- a/pyiron_workflow/function.py +++ b/pyiron_workflow/function.py @@ -296,6 +296,7 @@ def __init__( *args, label: Optional[str] = None, parent: Optional[Composite] = None, + run_after_init: bool = False, output_labels: Optional[str | list[str] | tuple[str]] = None, **kwargs, ): @@ -569,24 +570,6 @@ class SingleValue(Function, HasChannel): `some_node.input.some_channel = my_svn_instance`. """ - def __init__( - self, - node_function: callable, - *args, - label: Optional[str] = None, - parent: Optional[Workflow] = None, - output_labels: Optional[str | list[str] | tuple[str]] = None, - **kwargs, - ): - super().__init__( - node_function, - *args, - label=label, - parent=parent, - output_labels=output_labels, - **kwargs, - ) - def _get_output_labels(self, output_labels: str | list[str] | tuple[str] | None): output_labels = super()._get_output_labels(output_labels) if len(output_labels) > 1: diff --git a/pyiron_workflow/macro.py b/pyiron_workflow/macro.py index ab0636c6..f245f722 100644 --- a/pyiron_workflow/macro.py +++ b/pyiron_workflow/macro.py @@ -175,6 +175,7 @@ def __init__( graph_creator: callable[[Macro], None], label: Optional[str] = None, parent: Optional[Composite] = None, + run_after_init: bool = False, strict_naming: bool = True, inputs_map: Optional[dict | bidict] = None, outputs_map: Optional[dict | bidict] = None, diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 4a955753..338746c2 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -22,7 +22,7 @@ get_nodes_in_data_tree, set_run_connections_according_to_linear_dag, ) -from pyiron_workflow.util import SeabornColors +from pyiron_workflow.util import AbstractHasPost, SeabornColors if TYPE_CHECKING: from pathlib import Path @@ -66,7 +66,7 @@ def wrapped_method(node: Node, *args, **kwargs): # rather node:Node return wrapped_method -class Node(HasToDict, ABC): +class Node(HasToDict, ABC, metaclass=AbstractHasPost): """ Nodes are elements of a computational graph. They have inputs and outputs to interface with the wider world, and perform some @@ -194,6 +194,7 @@ def __init__( label: str, *args, parent: Optional[Composite] = None, + run_after_init: bool = False, **kwargs, ): """ @@ -203,6 +204,8 @@ def __init__( Args: label (str): A name for this node. *args: Arguments passed on with `super`. + parent: (Composite|None): The composite node that owns this as a child. + run_after_init (bool): Whether to run at the end of initialization. **kwargs: Keyword arguments passed on with `super`. """ super().__init__(*args, **kwargs) @@ -220,6 +223,10 @@ def __init__( # (or create) an executor process without ever trying to pickle a `_thread.lock` self.future: None | Future = None + def __post__(self, *args, run_after_init: bool = False, **kwargs): + if run_after_init: + self.run() + @property @abstractmethod def inputs(self) -> Inputs: diff --git a/pyiron_workflow/util.py b/pyiron_workflow/util.py index b7ab487e..9202a177 100644 --- a/pyiron_workflow/util.py +++ b/pyiron_workflow/util.py @@ -1,3 +1,5 @@ +from abc import ABCMeta + from pyiron_base import state logger = state.logger @@ -43,3 +45,24 @@ class SeabornColors: cyan = "#17becf" white = "#ffffff" black = "#000000" + + +class HasPost(type): + """ + A metaclass for adding a `__post__` method which has a compatible signature with + `__init__` (and indeed receives all its input), but is guaranteed to be called + only _after_ `__init__` is totally finished. + + Based on @jsbueno's reply in [this discussion](https://discuss.python.org/t/add-a-post-method-equivalent-to-the-new-method-but-called-after-init/5449/11) + """ + + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + if post := getattr(cls, "__post__", False): + post(instance, *args, **kwargs) + return instance + + +class AbstractHasPost(HasPost, ABCMeta): + # Just for resolving metaclass conflic for ABC classes that have post + pass diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 789f57c9..8c148765 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -182,6 +182,7 @@ def __init__( self, label: str, *nodes: Node, + run_after_init: bool = False, strict_naming: bool = True, inputs_map: Optional[dict | bidict] = None, outputs_map: Optional[dict | bidict] = None, diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py index daedcab6..ca6ec163 100644 --- a/tests/unit/test_node.py +++ b/tests/unit/test_node.py @@ -3,7 +3,7 @@ from sys import version_info import unittest -from pyiron_workflow.channels import InputData, OutputData +from pyiron_workflow.channels import InputData, OutputData, NotData from pyiron_workflow.files import DirectoryObject from pyiron_workflow.io import Inputs, Outputs from pyiron_workflow.node import Node @@ -16,10 +16,12 @@ def add_one(x): class ANode(Node): """To de-abstract the class""" - def __init__(self, label): + def __init__(self, label, run_after_init=False, x=None): super().__init__(label=label) self._inputs = Inputs(InputData("x", self, type_hint=int)) self._outputs = Outputs(OutputData("y", self, type_hint=int)) + if x is not None: + self.inputs.x = x @property def inputs(self) -> Inputs: @@ -48,15 +50,9 @@ def to_dict(self): @unittest.skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+") class TestNode(unittest.TestCase): def setUp(self): - n1 = ANode("start") - n2 = ANode("middle") - n3 = ANode("end") - n1.inputs.x = 0 - n2.inputs.x = n1.outputs.y - n3.inputs.x = n2.outputs.y - self.n1 = n1 - self.n2 = n2 - self.n3 = n3 + self.n1 = ANode("start", x=0) + self.n2 = ANode("middle", x=self.n1.outputs.y) + self.n3 = ANode("end", x=self.n2.outputs.y) def test_set_input_values(self): n = ANode("some_node") @@ -334,4 +330,14 @@ def test_draw(self): # No matter what happens in the tests, clean up after yourself self.n1.working_directory.delete() - + def test_run_after_init(self): + self.assertIs( + self.n1.outputs.y.value, + NotData, + msg="By default, nodes should not be getting run until asked" + ) + self.assertEqual( + 1, + ANode("right_away", run_after_init=True, x=0).outputs.y.value, + msg="With run_after_init, the node should run right away" + ) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 57c2abd9..9a5251c0 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -14,3 +14,35 @@ def test_dot_dict(self): self.assertEqual("towel", dd["bar"], msg="Dot assignment should be equivalent.") self.assertListEqual(dd.to_list(), [42, "towel"]) + + def test_has_post_metaclass(self): + class Foo(metaclass=util.HasPost): + def __init__(self, x=0): + self.x = x + self.y = x + self.z = x + self.x += 1 + + @property + def data(self): + return self.x, self.y, self.z + + class Bar(Foo): + def __init__(self, x=0, extra=1): + super().__init__(x) + + def __post__(self, *args, extra=1, **kwargs): + self.z = self.x + extra + + self.assertTupleEqual( + (1, 0, 0), + Foo().data, + msg="It should be fine to have this metaclass but not define post" + ) + + self.assertTupleEqual( + (1, 0, 2), + Bar().data, + msg="Metaclass should be inherited, able to use input, and happen _after_ " + "__init__" + )