Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Can run right away #96

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyiron_workflow/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
label: str,
*args,
parent: Optional[Composite] = None,
run_after_init: bool = False,
strict_naming: bool = True,
inputs_map: Optional[dict | bidict] = None,
outputs_map: Optional[dict | bidict] = None,
Expand Down
19 changes: 1 addition & 18 deletions pyiron_workflow/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def __init__(
*args,
label: Optional[str] = None,
parent: Optional[Composite] = None,
run_after_init: bool = False,
output_labels: Optional[str | list[str] | tuple[str]] = None,
**kwargs,
):
Expand Down Expand Up @@ -569,24 +570,6 @@ class SingleValue(Function, HasChannel):
`some_node.input.some_channel = my_svn_instance`.
"""

def __init__(
self,
node_function: callable,
*args,
label: Optional[str] = None,
parent: Optional[Workflow] = None,
output_labels: Optional[str | list[str] | tuple[str]] = None,
**kwargs,
):
super().__init__(
node_function,
*args,
label=label,
parent=parent,
output_labels=output_labels,
**kwargs,
)

def _get_output_labels(self, output_labels: str | list[str] | tuple[str] | None):
output_labels = super()._get_output_labels(output_labels)
if len(output_labels) > 1:
Expand Down
1 change: 1 addition & 0 deletions pyiron_workflow/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def __init__(
graph_creator: callable[[Macro], None],
label: Optional[str] = None,
parent: Optional[Composite] = None,
run_after_init: bool = False,
strict_naming: bool = True,
inputs_map: Optional[dict | bidict] = None,
outputs_map: Optional[dict | bidict] = None,
Expand Down
11 changes: 9 additions & 2 deletions pyiron_workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
get_nodes_in_data_tree,
set_run_connections_according_to_linear_dag,
)
from pyiron_workflow.util import SeabornColors
from pyiron_workflow.util import AbstractHasPost, SeabornColors

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -66,7 +66,7 @@ def wrapped_method(node: Node, *args, **kwargs): # rather node:Node
return wrapped_method


class Node(HasToDict, ABC):
class Node(HasToDict, ABC, metaclass=AbstractHasPost):
"""
Nodes are elements of a computational graph.
They have inputs and outputs to interface with the wider world, and perform some
Expand Down Expand Up @@ -194,6 +194,7 @@ def __init__(
label: str,
*args,
parent: Optional[Composite] = None,
run_after_init: bool = False,
**kwargs,
):
"""
Expand All @@ -203,6 +204,8 @@ def __init__(
Args:
label (str): A name for this node.
*args: Arguments passed on with `super`.
parent: (Composite|None): The composite node that owns this as a child.
run_after_init (bool): Whether to run at the end of initialization.
**kwargs: Keyword arguments passed on with `super`.
"""
super().__init__(*args, **kwargs)
Expand All @@ -220,6 +223,10 @@ def __init__(
# (or create) an executor process without ever trying to pickle a `_thread.lock`
self.future: None | Future = None

def __post__(self, *args, run_after_init: bool = False, **kwargs):
if run_after_init:
self.run()

@property
@abstractmethod
def inputs(self) -> Inputs:
Expand Down
23 changes: 23 additions & 0 deletions pyiron_workflow/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from abc import ABCMeta

from pyiron_base import state

logger = state.logger
Expand Down Expand Up @@ -43,3 +45,24 @@ class SeabornColors:
cyan = "#17becf"
white = "#ffffff"
black = "#000000"


class HasPost(type):
"""
A metaclass for adding a `__post__` method which has a compatible signature with
`__init__` (and indeed receives all its input), but is guaranteed to be called
only _after_ `__init__` is totally finished.

Based on @jsbueno's reply in [this discussion](https://discuss.python.org/t/add-a-post-method-equivalent-to-the-new-method-but-called-after-init/5449/11)
"""

def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
if post := getattr(cls, "__post__", False):
post(instance, *args, **kwargs)
return instance


class AbstractHasPost(HasPost, ABCMeta):
# Just for resolving metaclass conflic for ABC classes that have post
pass
1 change: 1 addition & 0 deletions pyiron_workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def __init__(
self,
label: str,
*nodes: Node,
run_after_init: bool = False,
strict_naming: bool = True,
inputs_map: Optional[dict | bidict] = None,
outputs_map: Optional[dict | bidict] = None,
Expand Down
30 changes: 18 additions & 12 deletions tests/unit/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sys import version_info
import unittest

from pyiron_workflow.channels import InputData, OutputData
from pyiron_workflow.channels import InputData, OutputData, NotData
from pyiron_workflow.files import DirectoryObject
from pyiron_workflow.io import Inputs, Outputs
from pyiron_workflow.node import Node
Expand All @@ -16,10 +16,12 @@ def add_one(x):
class ANode(Node):
"""To de-abstract the class"""

def __init__(self, label):
def __init__(self, label, run_after_init=False, x=None):
super().__init__(label=label)
self._inputs = Inputs(InputData("x", self, type_hint=int))
self._outputs = Outputs(OutputData("y", self, type_hint=int))
if x is not None:
self.inputs.x = x

@property
def inputs(self) -> Inputs:
Expand Down Expand Up @@ -48,15 +50,9 @@ def to_dict(self):
@unittest.skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestNode(unittest.TestCase):
def setUp(self):
n1 = ANode("start")
n2 = ANode("middle")
n3 = ANode("end")
n1.inputs.x = 0
n2.inputs.x = n1.outputs.y
n3.inputs.x = n2.outputs.y
self.n1 = n1
self.n2 = n2
self.n3 = n3
self.n1 = ANode("start", x=0)
self.n2 = ANode("middle", x=self.n1.outputs.y)
self.n3 = ANode("end", x=self.n2.outputs.y)

def test_set_input_values(self):
n = ANode("some_node")
Expand Down Expand Up @@ -334,4 +330,14 @@ def test_draw(self):
# No matter what happens in the tests, clean up after yourself
self.n1.working_directory.delete()


def test_run_after_init(self):
self.assertIs(
self.n1.outputs.y.value,
NotData,
msg="By default, nodes should not be getting run until asked"
)
self.assertEqual(
1,
ANode("right_away", run_after_init=True, x=0).outputs.y.value,
msg="With run_after_init, the node should run right away"
)
32 changes: 32 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,35 @@ def test_dot_dict(self):
self.assertEqual("towel", dd["bar"], msg="Dot assignment should be equivalent.")

self.assertListEqual(dd.to_list(), [42, "towel"])

def test_has_post_metaclass(self):
class Foo(metaclass=util.HasPost):
def __init__(self, x=0):
self.x = x
self.y = x
self.z = x
self.x += 1

@property
def data(self):
return self.x, self.y, self.z

class Bar(Foo):
def __init__(self, x=0, extra=1):
super().__init__(x)

def __post__(self, *args, extra=1, **kwargs):
self.z = self.x + extra

self.assertTupleEqual(
(1, 0, 0),
Foo().data,
msg="It should be fine to have this metaclass but not define post"
)

self.assertTupleEqual(
(1, 0, 2),
Bar().data,
msg="Metaclass should be inherited, able to use input, and happen _after_ "
"__init__"
)
Loading