Skip to content

Commit

Permalink
Leverage the new metaclass to allow running on instantiation
Browse files Browse the repository at this point in the history
  • Loading branch information
liamhuber committed Nov 29, 2023
1 parent 739fbad commit 39dca4b
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 32 deletions.
1 change: 1 addition & 0 deletions pyiron_workflow/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
label: str,
*args,
parent: Optional[Composite] = None,
run_after_init: bool = False,
strict_naming: bool = True,
inputs_map: Optional[dict | bidict] = None,
outputs_map: Optional[dict | bidict] = None,
Expand Down
19 changes: 1 addition & 18 deletions pyiron_workflow/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def __init__(
*args,
label: Optional[str] = None,
parent: Optional[Composite] = None,
run_after_init: bool = False,
output_labels: Optional[str | list[str] | tuple[str]] = None,
**kwargs,
):
Expand Down Expand Up @@ -569,24 +570,6 @@ class SingleValue(Function, HasChannel):
`some_node.input.some_channel = my_svn_instance`.
"""

def __init__(
self,
node_function: callable,
*args,
label: Optional[str] = None,
parent: Optional[Workflow] = None,
output_labels: Optional[str | list[str] | tuple[str]] = None,
**kwargs,
):
super().__init__(
node_function,
*args,
label=label,
parent=parent,
output_labels=output_labels,
**kwargs,
)

def _get_output_labels(self, output_labels: str | list[str] | tuple[str] | None):
output_labels = super()._get_output_labels(output_labels)
if len(output_labels) > 1:
Expand Down
1 change: 1 addition & 0 deletions pyiron_workflow/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def __init__(
graph_creator: callable[[Macro], None],
label: Optional[str] = None,
parent: Optional[Composite] = None,
run_after_init: bool = False,
strict_naming: bool = True,
inputs_map: Optional[dict | bidict] = None,
outputs_map: Optional[dict | bidict] = None,
Expand Down
11 changes: 9 additions & 2 deletions pyiron_workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
get_nodes_in_data_tree,
set_run_connections_according_to_linear_dag,
)
from pyiron_workflow.util import SeabornColors
from pyiron_workflow.util import AbstractHasPost, SeabornColors

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -66,7 +66,7 @@ def wrapped_method(node: Node, *args, **kwargs): # rather node:Node
return wrapped_method


class Node(HasToDict, ABC):
class Node(HasToDict, ABC, metaclass=AbstractHasPost):
"""
Nodes are elements of a computational graph.
They have inputs and outputs to interface with the wider world, and perform some
Expand Down Expand Up @@ -194,6 +194,7 @@ def __init__(
label: str,
*args,
parent: Optional[Composite] = None,
run_after_init: bool = False,
**kwargs,
):
"""
Expand All @@ -203,6 +204,8 @@ def __init__(
Args:
label (str): A name for this node.
*args: Arguments passed on with `super`.
parent: (Composite|None): The composite node that owns this as a child.
run_after_init (bool): Whether to run at the end of initialization.
**kwargs: Keyword arguments passed on with `super`.
"""
super().__init__(*args, **kwargs)
Expand All @@ -220,6 +223,10 @@ def __init__(
# (or create) an executor process without ever trying to pickle a `_thread.lock`
self.future: None | Future = None

def __post__(self, *args, run_after_init: bool = False, **kwargs):
if run_after_init:
self.run()

@property
@abstractmethod
def inputs(self) -> Inputs:
Expand Down
1 change: 1 addition & 0 deletions pyiron_workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def __init__(
self,
label: str,
*nodes: Node,
run_after_init: bool = False,
strict_naming: bool = True,
inputs_map: Optional[dict | bidict] = None,
outputs_map: Optional[dict | bidict] = None,
Expand Down
30 changes: 18 additions & 12 deletions tests/unit/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sys import version_info
import unittest

from pyiron_workflow.channels import InputData, OutputData
from pyiron_workflow.channels import InputData, OutputData, NotData
from pyiron_workflow.files import DirectoryObject
from pyiron_workflow.io import Inputs, Outputs
from pyiron_workflow.node import Node
Expand All @@ -16,10 +16,12 @@ def add_one(x):
class ANode(Node):
"""To de-abstract the class"""

def __init__(self, label):
def __init__(self, label, run_after_init=False, x=None):
super().__init__(label=label)
self._inputs = Inputs(InputData("x", self, type_hint=int))
self._outputs = Outputs(OutputData("y", self, type_hint=int))
if x is not None:
self.inputs.x = x

@property
def inputs(self) -> Inputs:
Expand Down Expand Up @@ -48,15 +50,9 @@ def to_dict(self):
@unittest.skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestNode(unittest.TestCase):
def setUp(self):
n1 = ANode("start")
n2 = ANode("middle")
n3 = ANode("end")
n1.inputs.x = 0
n2.inputs.x = n1.outputs.y
n3.inputs.x = n2.outputs.y
self.n1 = n1
self.n2 = n2
self.n3 = n3
self.n1 = ANode("start", x=0)
self.n2 = ANode("middle", x=self.n1.outputs.y)
self.n3 = ANode("end", x=self.n2.outputs.y)

def test_set_input_values(self):
n = ANode("some_node")
Expand Down Expand Up @@ -334,4 +330,14 @@ def test_draw(self):
# No matter what happens in the tests, clean up after yourself
self.n1.working_directory.delete()


def test_run_after_init(self):
self.assertIs(
self.n1.outputs.y.value,
NotData,
msg="By default, nodes should not be getting run until asked"
)
self.assertEqual(
1,
ANode("right_away", run_after_init=True, x=0).outputs.y.value,
msg="With run_after_init, the node should run right away"
)

0 comments on commit 39dca4b

Please sign in to comment.