diff --git a/pyiron_workflow/function.py b/pyiron_workflow/function.py index cc3295ce..18cc5198 100644 --- a/pyiron_workflow/function.py +++ b/pyiron_workflow/function.py @@ -1,16 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Literal, Optional, TYPE_CHECKING +from typing import Any -from pyiron_workflow.io_preview import DecoratedNode, decorated_node_decorator_factory +from pyiron_workflow.io_preview import StaticNode, ScrapesIO from pyiron_workflow.snippets.colors import SeabornColors +from pyiron_workflow.snippets.factory import classfactory -if TYPE_CHECKING: - from pyiron_workflow.composite import Composite - -class Function(DecoratedNode, ABC): +class Function(StaticNode, ScrapesIO, ABC): """ Function nodes wrap an arbitrary python function. @@ -347,64 +345,50 @@ def color(self) -> str: return SeabornColors.green -as_function_node = decorated_node_decorator_factory(Function, Function.node_function) - - -def function_node( - node_function: callable, - *args, - label: Optional[str] = None, - parent: Optional[Composite] = None, - overwrite_save: bool = False, - run_after_init: bool = False, - storage_backend: Optional[Literal["h5io", "tinybase"]] = None, - save_after_run: bool = False, - output_labels: Optional[str | tuple[str]] = None, - validate_output_labels: bool = True, - **kwargs, +@classfactory +def function_node_factory( + node_function: callable, validate_output_labels: bool, /, *output_labels ): - """ - Dynamically creates a new child of :class:`Function` using the - provided :func:`node_function` and returns an instance of that. - - Beyond the standard :class:`Function`, initialization allows the args... - - Args: - node_function (callable): The function determining the behaviour of the node. - output_labels (Optional[str | list[str] | tuple[str]]): A name for each return - value of the node function OR a single label. (Default is None, which - scrapes output labels automatically from the source code of the wrapped - function.) This can be useful when returned values are not well named, e.g. - to make the output channel dot-accessible if it would otherwise have a label - that requires item-string-based access. Additionally, specifying a _single_ - label for a wrapped function that returns a tuple of values ensures that a - _single_ output channel (holding the tuple) is created, instead of one - channel for each return value. The default approach of extracting labels - from the function source code also requires that the function body contain - _at most_ one `return` expression, so providing explicit labels can be used - to circumvent this (at your own risk), or to circumvent un-inspectable - source code (e.g. a function that exists only in memory). 
- """ + return ( + node_function.__name__, + (Function,), # Define parentage + { + "node_function": staticmethod(node_function), + "__module__": node_function.__module__, + "_output_labels": None if len(output_labels) == 0 else output_labels, + "_validate_output_labels": validate_output_labels, + }, + {}, + ) - if not callable(node_function): - raise AttributeError( - f"Expected `node_function` to be callable but got {node_function}" + +def as_function_node(*output_labels, validate_output_labels=True): + def decorator(node_function): + function_node_factory.clear(node_function.__name__) # Force a fresh class + factory_made = function_node_factory( + node_function, validate_output_labels, *output_labels ) + factory_made._class_returns_from_decorated_function = node_function + factory_made.preview_io() + return factory_made + + return decorator + +def function_node( + node_function, + *node_args, + output_labels=None, + validate_output_labels=True, + **node_kwargs, +): if output_labels is None: output_labels = () elif isinstance(output_labels, str): output_labels = (output_labels,) - - return as_function_node( - *output_labels, validate_output_labels=validate_output_labels - )(node_function)( - *args, - label=label, - parent=parent, - overwrite_save=overwrite_save, - run_after_init=run_after_init, - storage_backend=storage_backend, - save_after_run=save_after_run, - **kwargs, + function_node_factory.clear(node_function.__name__) # Force a fresh class + factory_made = function_node_factory( + node_function, validate_output_labels, *output_labels ) + factory_made.preview_io() + return factory_made(*node_args, **node_kwargs) diff --git a/pyiron_workflow/io_preview.py b/pyiron_workflow/io_preview.py index bc8a85e8..6f692077 100644 --- a/pyiron_workflow/io_preview.py +++ b/pyiron_workflow/io_preview.py @@ -18,7 +18,15 @@ from functools import lru_cache, wraps from textwrap import dedent from types import FunctionType -from typing import Any, get_args, get_type_hints, Literal, Optional, TYPE_CHECKING +from typing import ( + Any, + ClassVar, + get_args, + get_type_hints, + Literal, + Optional, + TYPE_CHECKING, +) from pyiron_workflow.channels import InputData, NOT_DATA from pyiron_workflow.injection import OutputDataWithInjection, OutputsWithInjection @@ -128,11 +136,13 @@ class ScrapesIO(HasIOPreview, ABC): @classmethod @abstractmethod def _io_defining_function(cls) -> callable: - """Must return a static class method.""" + """Must return a static method.""" - _output_labels: tuple[str] | None = None # None: scrape them - _validate_output_labels: bool = True # True: validate against source code - _io_defining_function_uses_self: bool = False # False: use entire signature + _output_labels: ClassVar[tuple[str] | None] = None # None: scrape them + _validate_output_labels: ClassVar[bool] = True # True: validate against source code + _io_defining_function_uses_self: ClassVar[bool] = ( + False # False: use entire signature + ) @classmethod def _build_inputs_preview(cls): @@ -354,114 +364,3 @@ def inputs(self) -> Inputs: @property def outputs(self) -> OutputsWithInjection: return self._outputs - - -class DecoratedNode(StaticNode, ScrapesIO, ABC): - """ - A static node whose IO is defined by a function's information (and maybe output - labels). 
- """ - - -def decorated_node_decorator_factory( - parent_class: type[DecoratedNode], - io_static_method: callable, - decorator_docstring_additions: str = "", - **parent_class_attr_overrides, -): - """ - A decorator factory for building decorators to dynamically create new subclasses - of some subclass of :class:`DecoratedNode` using the function they decorate. - - New classes get their class name and module set using the decorated function's - name and module. - - Args: - parent_class (type[DecoratedNode]): The base class for the new node class. - io_static_method: The static method on the :param:`parent_class` which will - store the io-defining function the resulting decorator will decorate. - :param:`parent_class` must override :meth:`_io_defining_function` inherited - from :class:`DecoratedNode` to return this method. This allows - :param:`parent_class` classes to have unique names for their io-defining - functions. - decorator_docstring_additions (str): Any extra text to add between the main - body of the docstring and the arguments. - **parent_class_attr_overrides: Any additional attributes to pass to the new, - dynamically created class created by the resulting decorator. - - Returns: - (callable): A decorator that takes creates a new subclass of - :param:`parent_class` that uses the wrapped function as the return value of - :meth:`_io_defining_function` for the :class:`DecoratedNode` mixin. - """ - if getattr(parent_class, io_static_method.__name__) is not io_static_method: - raise ValueError( - f"{io_static_method.__name__} is not a method on {parent_class}" - ) - if not isinstance(io_static_method, FunctionType): - raise TypeError(f"{io_static_method.__name__} should be a static method") - - def as_decorated_node_decorator( - *output_labels: str, - validate_output_labels: bool = True, - ): - output_labels = None if len(output_labels) == 0 else output_labels - - @builds_class_io - def as_decorated_node(io_defining_function: callable): - if not callable(io_defining_function): - raise AttributeError( - f"Tried to create a new child class of {parent_class.__name__}, " - f"but got {io_defining_function} instead of a callable." - ) - - return type( - io_defining_function.__name__, - (parent_class,), # Define parentage - { - io_static_method.__name__: staticmethod(io_defining_function), - "__module__": io_defining_function.__module__, - "_output_labels": output_labels, - "_validate_output_labels": validate_output_labels, - **parent_class_attr_overrides, - }, - ) - - return as_decorated_node - - as_decorated_node_decorator.__doc__ = dedent( - f""" - A decorator for dynamically creating `{parent_class.__name__}` sub-classes by - wrapping a function as the `{io_static_method.__name__}`. - - The returned subclass uses the wrapped function (and optionally any provided - :param:`output_labels`) to specify its IO. - - {decorator_docstring_additions} - - Args: - *output_labels (str): A name for each return value of the graph creating - function. When empty, scrapes output labels automatically from the - source code of the wrapped function. This can be useful when returned - values are not well named, e.g. to make the output channel - dot-accessible if it would otherwise have a label that requires - item-string-based access. Additionally, specifying a _single_ label for - a wrapped function that returns a tuple of values ensures that a - _single_ output channel (holding the tuple) is created, instead of one - channel for each return value. 
The default approach of extracting
-                labels from the function source code also requires that the function
-                body contain _at most_ one `return` expression, so providing explicit
-                labels can be used to circumvent this (at your own risk). (Default is
-                empty, try to scrape labels from the source code of the wrapped
-                function.)
-            validate_output_labels (bool): Whether to compare the provided output labels
-                (if any) against the source code (if available). (Default is True.)
-
-        Returns:
-            (callable[[callable], type[{parent_class.__name__}]]): A decorator that
-                transforms a function into a child class of `{parent_class.__name__}`
-                using the decorated function as
-                `{parent_class.__name__}.{io_static_method.__name__}`.
-        """
-    )
-    return as_decorated_node_decorator
diff --git a/pyiron_workflow/job.py b/pyiron_workflow/job.py
index 8a220d86..1b825dd2 100644
--- a/pyiron_workflow/job.py
+++ b/pyiron_workflow/job.py
@@ -20,10 +20,13 @@

 from __future__ import annotations

+import base64
 import inspect
 import os
 import sys

+import cloudpickle
+
 from pyiron_base import TemplateJob, JOB_CLASS_DICT
 from pyiron_base.jobs.flex.pythonfunctioncontainer import (
     PythonFunctionContainerJob,
@@ -103,25 +106,51 @@ def validate_ready_to_run(self):
                 f"Node not ready:{nl}{self.input['node'].readiness_report}"
             )

+    def save(self):
+        # DataContainer can't handle custom reconstructors, so convert the node to
+        # a bytestream
+        self.input["node"] = base64.b64encode(
+            cloudpickle.dumps(self.input["node"])
+        ).decode("utf-8")
+        super().save()
+
     def run_static(self):
         # Overrides the parent method
         # Copy and paste except for the output update, which makes sure the output is
         # flat and not nested beneath "result"
+
+        # Unpack the node
+        input_dict = self.input.to_builtin()
+        input_dict["node"] = cloudpickle.loads(base64.b64decode(self.input["node"]))
+
         if (
             self._executor_type is not None
             and "executor" in inspect.signature(self._function).parameters.keys()
         ):
-            input_dict = self.input.to_builtin()
             del input_dict["executor"]
             output = self._function(
                 **input_dict, executor=self._get_executor(max_workers=self.server.cores)
             )
         else:
-            output = self._function(**self.input.to_builtin())
+            output = self._function(**input_dict)
         self.output.update(output)  # DIFFERS FROM PARENT METHOD
         self.to_hdf()
         self.status.finished = True

+    def get_input_node(self):
+        """
+        On saving, we turn the input node into a bytestream so that the DataContainer
+        can store it.
You might want to look at it again though, so you can use this + to unpack it + + Returns: + (Node): The input node as a node again + """ + if isinstance(self.input["node"], Node): + return self.input["node"] + else: + return cloudpickle.loads(base64.b64decode(self.input["node"])) + JOB_CLASS_DICT[NodeOutputJob.__name__] = NodeOutputJob.__module__ diff --git a/pyiron_workflow/macro.py b/pyiron_workflow/macro.py index e1798f7d..4c3727ba 100644 --- a/pyiron_workflow/macro.py +++ b/pyiron_workflow/macro.py @@ -12,13 +12,14 @@ from pyiron_workflow.composite import Composite from pyiron_workflow.has_interface_mixins import HasChannel from pyiron_workflow.io import Outputs, Inputs -from pyiron_workflow.io_preview import DecoratedNode, decorated_node_decorator_factory +from pyiron_workflow.io_preview import StaticNode, ScrapesIO +from pyiron_workflow.snippets.factory import classfactory if TYPE_CHECKING: from pyiron_workflow.channels import Channel -class Macro(Composite, DecoratedNode, ABC): +class Macro(Composite, StaticNode, ScrapesIO, ABC): """ A macro is a composite node that holds a graph with a fixed interface, like a pre-populated workflow that is the same every time you instantiate it. @@ -470,74 +471,123 @@ def __setstate__(self, state): self.children[child].outputs[child_out].value_receiver = self.outputs[out] -as_macro_node = decorated_node_decorator_factory( - Macro, - Macro.graph_creator, - decorator_docstring_additions="The first argument in the wrapped function is " - "`self`-like and will receive the macro instance " - "itself, and thus is ignored in the IO.", -) +@classfactory +def macro_node_factory( + graph_creator: callable, validate_output_labels: bool, /, *output_labels +): + return ( + graph_creator.__name__, + (Macro,), # Define parentage + { + "graph_creator": staticmethod(graph_creator), + "__module__": graph_creator.__module__, + "_output_labels": None if len(output_labels) == 0 else output_labels, + "_validate_output_labels": validate_output_labels, + }, + {}, + ) -def macro_node( - graph_creator, - label: Optional[str] = None, - parent: Optional[Composite] = None, - overwrite_save: bool = False, - run_after_init: bool = False, - storage_backend: Optional[Literal["h5io", "tinybase"]] = None, - save_after_run: bool = False, - strict_naming: bool = True, - output_labels: Optional[str | list[str] | tuple[str]] = None, - validate_output_labels: bool = True, - **kwargs, -): - """ - Creates a new child of :class:`Macro` using the provided - :func:`graph_creator` and returns an instance of that. - - Quacks like a :class:`Composite` for the sake of creating and registering nodes. - - Beyond the standard :class:`Macro`, initialization allows the args... - - Args: - graph_creator (callable): The function defining macro's graph. - output_labels (Optional[str | list[str] | tuple[str]]): A name for each return - value of the node function OR a single label. (Default is None, which - scrapes output labels automatically from the source code of the wrapped - function.) This can be useful when returned values are not well named, e.g. - to make the output channel dot-accessible if it would otherwise have a label - that requires item-string-based access. Additionally, specifying a _single_ - label for a wrapped function that returns a tuple of values ensures that a - _single_ output channel (holding the tuple) is created, instead of one - channel for each return value. 
The default approach of extracting labels - from the function source code also requires that the function body contain - _at most_ one `return` expression, so providing explicit labels can be used - to circumvent this (at your own risk), or to circumvent un-inspectable - source code (e.g. a function that exists only in memory). - """ - if not callable(graph_creator): - # `function_node` quacks like a class, even though it's a function and - # dynamically creates children of `Macro` by providing the necessary - # callable to the decorator - raise AttributeError( - f"Expected `graph_creator` to be callable but got {graph_creator}" +def as_macro_node(*output_labels, validate_output_labels=True): + def decorator(node_function): + macro_node_factory.clear(node_function.__name__) # Force a fresh class + factory_made = macro_node_factory( + node_function, validate_output_labels, *output_labels ) + factory_made._class_returns_from_decorated_function = node_function + factory_made.preview_io() + return factory_made + return decorator + + +def macro_node( + node_function, + *node_args, + output_labels=None, + validate_output_labels=True, + **node_kwargs, +): if output_labels is None: output_labels = () elif isinstance(output_labels, str): output_labels = (output_labels,) - - return as_macro_node(*output_labels, validate_output_labels=validate_output_labels)( - graph_creator - )( - label=label, - parent=parent, - overwrite_save=overwrite_save, - run_after_init=run_after_init, - storage_backend=storage_backend, - save_after_run=save_after_run, - strict_naming=strict_naming, - **kwargs, + macro_node_factory.clear(node_function.__name__) # Force a fresh class + factory_made = macro_node_factory( + node_function, validate_output_labels, *output_labels ) + factory_made.preview_io() + return factory_made(*node_args, **node_kwargs) + + +# as_macro_node = decorated_node_decorator_factory( +# Macro, +# Macro.graph_creator, +# decorator_docstring_additions="The first argument in the wrapped function is " +# "`self`-like and will receive the macro instance " +# "itself, and thus is ignored in the IO.", +# ) +# +# +# def macro_node( +# graph_creator, +# label: Optional[str] = None, +# parent: Optional[Composite] = None, +# overwrite_save: bool = False, +# run_after_init: bool = False, +# storage_backend: Optional[Literal["h5io", "tinybase"]] = None, +# save_after_run: bool = False, +# strict_naming: bool = True, +# output_labels: Optional[str | list[str] | tuple[str]] = None, +# validate_output_labels: bool = True, +# **kwargs, +# ): +# """ +# Creates a new child of :class:`Macro` using the provided +# :func:`graph_creator` and returns an instance of that. +# +# Quacks like a :class:`Composite` for the sake of creating and registering nodes. +# +# Beyond the standard :class:`Macro`, initialization allows the args... +# +# Args: +# graph_creator (callable): The function defining macro's graph. +# output_labels (Optional[str | list[str] | tuple[str]]): A name for each return +# value of the node function OR a single label. (Default is None, which +# scrapes output labels automatically from the source code of the wrapped +# function.) This can be useful when returned values are not well named, e.g. +# to make the output channel dot-accessible if it would otherwise have a label +# that requires item-string-based access. 
Additionally, specifying a _single_
+#             label for a wrapped function that returns a tuple of values ensures that a
+#             _single_ output channel (holding the tuple) is created, instead of one
+#             channel for each return value. The default approach of extracting labels
+#             from the function source code also requires that the function body contain
+#             _at most_ one `return` expression, so providing explicit labels can be used
+#             to circumvent this (at your own risk), or to circumvent un-inspectable
+#             source code (e.g. a function that exists only in memory).
+#     """
+#     if not callable(graph_creator):
+#         # `function_node` quacks like a class, even though it's a function and
+#         # dynamically creates children of `Macro` by providing the necessary
+#         # callable to the decorator
+#         raise AttributeError(
+#             f"Expected `graph_creator` to be callable but got {graph_creator}"
+#         )
+#
+#     if output_labels is None:
+#         output_labels = ()
+#     elif isinstance(output_labels, str):
+#         output_labels = (output_labels,)
+#
+#     return as_macro_node(*output_labels, validate_output_labels=validate_output_labels)(
+#         graph_creator
+#     )(
+#         label=label,
+#         parent=parent,
+#         overwrite_save=overwrite_save,
+#         run_after_init=run_after_init,
+#         storage_backend=storage_backend,
+#         save_after_run=save_after_run,
+#         strict_naming=strict_naming,
+#         **kwargs,
+#     )
diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py
index 3faba287..453fcb18 100644
--- a/pyiron_workflow/node.py
+++ b/pyiron_workflow/node.py
@@ -126,18 +126,19 @@ class Node(
       - Nodes created from a registered package store their package identifier as a
         class attribute.
     - [ALPHA FEATURE] Nodes can be saved to and loaded from file if python >= 3.11.
+      - As long as you haven't put anything unpickleable on them, or defined them in
+        an unpicklable place (e.g. in the `<locals>` of another function), you can
+        simply (un)pickle nodes. There is no save/load interface for this right
+        now; just import pickle and do it.
       - Saving is triggered manually, or by setting a flag to save after the nodes
         runs.
-      - On instantiation, nodes will load automatically if they find saved content.
+      - At the end of instantiation, nodes will load automatically if they find saved
+        content.
        - Discovered content can instead be deleted with a kwarg.
        - You can't load saved content _and_ run after instantiation at once.
-      - The nodes must be somewhere importable, and the imported object must match
-        the type of the node being saved. This basically just rules out one edge
-        case where a node class is defined like
-        `SomeFunctionNode = Workflow.wrap.as_function_node()(some_function)`, since
-        then the new class gets the name `some_function`, which when imported is
-        the _function_ "some_function" and not the desired class "SomeFunctionNode".
-        This is checked for at save-time and will cause a nice early failure.
+      - The nodes must be defined somewhere importable, i.e. a module, `__main__`,
+        or a class property are all fine, but, e.g., inside the `<locals>` of
+        another function is not.
      - [ALPHA ISSUE] If the source code (cells, `.py` files...) for a saved graph is
        altered between saving and loading the graph, there are no guarantees about
        the loaded state; depending on the nature of the changes everything may
@@ -152,9 +153,14 @@ class Node(
        the entire graph may be saved at once.
- [ALPHA ISSUE] There are two possible back-ends for saving: one leaning on
          `tinybase.storage.GenericStorage` (in practice,
-          `H5ioStorage(GenericStorage)`), and the other, default back-end that uses
-          the `h5io` module directly. The backend used is always the one on the graph
-          root.
+          `H5ioStorage(GenericStorage)`), which is the default, and the other, which
+          uses the `h5io` module directly. The backend used is always the one on the
+          graph root.
+        - [ALPHA ISSUE] The `h5io` backend is deprecated -- it can't handle custom
+          reconstructors (i.e. when `__reduce__` returns a tuple with some
+          non-standard callable as its first entry), and basically all our nodes do
+          that now! `tinybase` gets around this by falling back on `cloudpickle` when
+          its own interactions with `h5io` fail.
        - [ALPHA ISSUE] Restrictions on data:
          - For the `h5io` backend: Most data that can be pickled will be fine, but
            some classes will hit an edge case and throw an exception from `h5io`
diff --git a/pyiron_workflow/node_library/standard.py b/pyiron_workflow/node_library/standard.py
index d091b8fe..8a3ab85c 100644
--- a/pyiron_workflow/node_library/standard.py
+++ b/pyiron_workflow/node_library/standard.py
@@ -5,6 +5,7 @@
 from __future__ import annotations

 import random
+from time import sleep

 from pyiron_workflow.channels import NOT_DATA, OutputSignal
 from pyiron_workflow.function import Function, as_function_node
@@ -52,6 +53,12 @@ def RandomFloat():
     return random.random()


+@as_function_node("time")
+def Sleep(t):
+    sleep(t)
+    return t
+
+
 @as_function_node("slice")
 def Slice(start=None, stop=NOT_DATA, step=None):
     if start is None:
@@ -291,6 +298,7 @@ def Round(obj):
     RandomFloat,
     RightMultiply,
     Round,
+    Sleep,
     Slice,
     String,
     Subtract,
diff --git a/pyiron_workflow/snippets/factory.py b/pyiron_workflow/snippets/factory.py
index 69c8797d..475954d9 100644
--- a/pyiron_workflow/snippets/factory.py
+++ b/pyiron_workflow/snippets/factory.py
@@ -160,20 +160,38 @@ def __call__(self, *args) -> type[_FactoryMade]:
         return factory_made

     @classmethod
-    def clear(cls):
+    def clear(cls, *class_names, skip_missing=True):
         """
-        Remove constructed classes.
+        Remove constructed class(es).

         Can be useful if you've updated the constructor and want to remove old
         instances.
+
+        Args:
+            *class_names (str): The names of classes to remove. Removes all of them
+                when empty.
+            skip_missing (bool): Whether to pass over key errors when a name is
+                requested that is not currently in the class registry. (Default is
+                True, let missing names pass silently.)
""" - cls.class_registry = {} + if len(class_names) == 0: + cls.class_registry = {} + else: + for name in class_names: + try: + cls.class_registry.pop(name) + except KeyError as e: + if skip_missing: + continue + else: + raise KeyError(f"Could not find class {name}") def _build_class( self, name, bases, class_dict, sc_init_kwargs, class_factory_args ) -> type[_FactoryMade]: - class_dict["__module__"] = self.factory_function.__module__ + if "__module__" not in class_dict.keys(): + class_dict["__module__"] = self.factory_function.__module__ sc_init_kwargs["class_factory"] = self sc_init_kwargs["class_factory_args"] = class_factory_args diff --git a/pyiron_workflow/storage.py b/pyiron_workflow/storage.py index 95a6a30c..bb93f227 100644 --- a/pyiron_workflow/storage.py +++ b/pyiron_workflow/storage.py @@ -372,10 +372,6 @@ def _storage_interfaces(cls): interfaces["h5io"] = H5ioStorage return interfaces - @classmethod - def default_backend(cls): - return "h5io" - class HasTinybaseStorage(HasStorage, ABC): @classmethod @@ -391,3 +387,7 @@ def to_storage(self, storage: TinybaseStorage): @abstractmethod def from_storage(self, storage: TinybaseStorage): pass + + @classmethod + def default_backend(cls): + return "tinybase" diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index 1891b224..7d922d1f 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -1,3 +1,4 @@ +import pickle from typing import Optional, Union import unittest @@ -505,6 +506,16 @@ def NoReturn(x): # Honestly, functions with no return should probably be made illegal to # encourage functional setups... + def test_pickle(self): + n = function_node(plus_one, 5, output_labels="p1") + n() + reloaded = pickle.loads(pickle.dumps(n)) + self.assertListEqual(n.outputs.labels, reloaded.outputs.labels) + self.assertDictEqual( + n.outputs.to_value_dict(), + reloaded.outputs.to_value_dict() + ) + if __name__ == '__main__': unittest.main() diff --git a/tests/unit/test_io_preview.py b/tests/unit/test_io_preview.py index fa366028..d7cdddae 100644 --- a/tests/unit/test_io_preview.py +++ b/tests/unit/test_io_preview.py @@ -3,26 +3,52 @@ import unittest from pyiron_workflow.channels import NOT_DATA -from pyiron_workflow.io_preview import ( - ScrapesIO, decorated_node_decorator_factory, OutputLabelsNotValidated -) +from pyiron_workflow.io_preview import ScrapesIO, OutputLabelsNotValidated +from pyiron_workflow.snippets.factory import classfactory -class ScraperParent(ScrapesIO, ABC): - - @staticmethod - @abstractmethod - def io_function(*args, **kwargs): - pass - +class ScrapesFromDecorated(ScrapesIO): @classmethod - def _io_defining_function(cls): - return cls.io_function - - -as_scraper = decorated_node_decorator_factory( - ScraperParent, ScraperParent.io_function -) + def _io_defining_function(cls) -> callable: + return cls._decorated_function + + +@classfactory +def scraper_factory( + io_defining_function, + validate_output_labels, + io_defining_function_uses_self, + /, + *output_labels, +): + return ( + io_defining_function.__name__, + (ScrapesFromDecorated,), # Define parentage + { + "_decorated_function": staticmethod(io_defining_function), + "__module__": io_defining_function.__module__, + "_output_labels": None if len(output_labels) == 0 else output_labels, + "_validate_output_labels": validate_output_labels, + "_io_defining_function_uses_self": io_defining_function_uses_self + }, + {}, + ) + + +def as_scraper( + *output_labels, + validate_output_labels=True, + io_defining_function_uses_self=False, +): 
def scraper_decorator(fnc):
+        scraper_factory.clear(fnc.__name__)  # Force a fresh class
+        factory_made = scraper_factory(
+            fnc, validate_output_labels, io_defining_function_uses_self, *output_labels
+        )
+        factory_made._class_returns_from_decorated_function = fnc
+        factory_made.preview_io()
+        return factory_made
+    return scraper_decorator


 class TestIOPreview(unittest.TestCase):
diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py
index b1a93a63..9baffa0c 100644
--- a/tests/unit/test_job.py
+++ b/tests/unit/test_job.py
@@ -7,12 +7,7 @@
 from pyiron_workflow import Workflow
 from pyiron_workflow.channels import NOT_DATA
 import pyiron_workflow.job  # To get the job classes registered
-
-
-@Workflow.wrap.as_function_node("t")
-def Sleep(t):
-    sleep(t)
-    return t
+from pyiron_workflow.node import Node


 class _WithAJob(unittest.TestCase, ABC):
@@ -48,6 +43,9 @@ def test_clean_failure(self):
     def test_node(self):
         node = Workflow.create.standard.UserInput(42)
         nj = self.make_a_job_from_node(node)
+
+        self.assertIsInstance(nj.get_input_node(), Node, msg="Sanity check")
+
         nj.run()
         self.assertEqual(
             42,
             nj.node.outputs.user_input.value,
             msg="A single node should run just as well as a workflow"
         )

+        self.assertIsInstance(
+            nj.input["node"],
+            str,
+            msg="On saving, we convert the input to a bytestream so DataContainer can "
+                "handle storing it."
+        )
+        self.assertIsInstance(
+            nj.get_input_node(),
+            Node,
+            msg="But we might want to look at it again, so make sure this convenience "
+                "method works."
+        )
+
     @unittest.skipIf(sys.version_info < (3, 11), "Storage will only work in 3.11+")
     def test_modal(self):
         modal_wf = Workflow("modal_wf")
-        modal_wf.sleep = Sleep(0)
+        modal_wf.sleep = Workflow.create.standard.Sleep(0)
         modal_wf.out = modal_wf.create.standard.UserInput(modal_wf.sleep)

         nj = self.make_a_job_from_node(modal_wf)
@@ -140,19 +151,13 @@ def not_importable_directy_from_module(x):
             return x + 1

         nj = self.make_a_job_from_node(not_importable_directy_from_module(42))
-        nj.run()
-        self.assertEqual(
-            43,
-            nj.output.y,
-            msg="Things should run fine locally"
-        )
         with self.assertRaises(
             AttributeError,
-            msg="We have promised that you'll hit trouble if you try to load a job "
-                "whose nodes are not all importable directly from their module"
-            # h5io also has this limitation, so I suspect that may be the source
+            msg="On saving we cloudpickle the node, then at run time we try to "
+                "recreate the dynamic class, but this doesn't work from the "
+                "<locals> scope, i.e. when the function is nested inside another function."
): - self.pr.load(nj.job_name) + nj.run() @unittest.skipIf(sys.version_info < (3, 11), "Storage will only work in 3.11+") def test_shorter_name(self): @@ -181,43 +186,52 @@ def test_clean_failure(self): def test_node(self): node = Workflow.create.standard.UserInput(42) nj = self.make_a_job_from_node(node) - nj.run() - self.assertEqual( - 42, - nj.node.outputs.user_input.value, - msg="A single node should run just as well as a workflow" - ) + try: + nj.run() + self.assertEqual( + 42, + nj.node.outputs.user_input.value, + msg="A single node should run just as well as a workflow" + ) + finally: + try: + node.storage.delete() + except FileNotFoundError: + pass @unittest.skipIf(sys.version_info < (3, 11), "Storage will only work in 3.11+") def test_modal(self): modal_wf = Workflow("modal_wf") - modal_wf.sleep = Sleep(0) + modal_wf.sleep = Workflow.create.standard.Sleep(0) modal_wf.out = modal_wf.create.standard.UserInput(modal_wf.sleep) nj = self.make_a_job_from_node(modal_wf) - nj.run() - self.assertTrue( - nj.status.finished, - msg="The interpreter should not release until the job is done" - ) - self.assertEqual( - 0, - nj.node.outputs.out__user_input.value, - msg="The node should have run, and since it's modal there's no need to " - "update the instance" - ) - - lj = self.pr.load(nj.job_name) - self.assertIsNot( - lj, - nj, - msg="The loaded job should be a new instance." - ) - self.assertEqual( - nj.node.outputs.out__user_input.value, - lj.node.outputs.out__user_input.value, - msg="The loaded job should still have all the same values" - ) + try: + nj.run() + self.assertTrue( + nj.status.finished, + msg="The interpreter should not release until the job is done" + ) + self.assertEqual( + 0, + nj.node.outputs.out__user_input.value, + msg="The node should have run, and since it's modal there's no need to " + "update the instance" + ) + + lj = self.pr.load(nj.job_name) + self.assertIsNot( + lj, + nj, + msg="The loaded job should be a new instance." + ) + self.assertEqual( + nj.node.outputs.out__user_input.value, + lj.node.outputs.out__user_input.value, + msg="The loaded job should still have all the same values" + ) + finally: + modal_wf.storage.delete() @unittest.skipIf(sys.version_info < (3, 11), "Storage will only work in 3.11+") def test_nonmodal(self): @@ -225,31 +239,35 @@ def test_nonmodal(self): nonmodal_node.out = Workflow.create.standard.UserInput(42) nj = self.make_a_job_from_node(nonmodal_node) - nj.run(run_mode="non_modal") - self.assertFalse( - nj.status.finished, - msg=f"The local process should released immediately per non-modal " - f"style, but got status {nj.status}" - ) - while not nj.status.finished: - sleep(0.1) - self.assertTrue( - nj.status.finished, - msg="The job status should update on completion" - ) - self.assertIs( - nj.node.outputs.out__user_input.value, - NOT_DATA, - msg="As usual with remote processes, we expect to require a data read " - "before the local instance reflects its new state." 
- ) - lj = self.pr.load(nj.job_name) - self.assertEqual( - 42, - lj.node.outputs.out__user_input.value, - msg="The loaded job should have the finished values" - ) + try: + nj.run(run_mode="non_modal") + self.assertFalse( + nj.status.finished, + msg=f"The local process should released immediately per non-modal " + f"style, but got status {nj.status}" + ) + while not nj.status.finished: + sleep(0.1) + self.assertTrue( + nj.status.finished, + msg="The job status should update on completion" + ) + self.assertIs( + nj.node.outputs.out__user_input.value, + NOT_DATA, + msg="As usual with remote processes, we expect to require a data read " + "before the local instance reflects its new state." + ) + + lj = self.pr.load(nj.job_name) + self.assertEqual( + 42, + lj.node.outputs.out__user_input.value, + msg="The loaded job should have the finished values" + ) + finally: + nonmodal_node.storage.delete() @unittest.skipIf(sys.version_info < (3, 11), "Storage will only work in 3.11+") def test_bad_workflow(self): diff --git a/tests/unit/test_macro.py b/tests/unit/test_macro.py index 27956ee7..d594e1d1 100644 --- a/tests/unit/test_macro.py +++ b/tests/unit/test_macro.py @@ -1,10 +1,9 @@ -import sys from concurrent.futures import Future - +import pickle +import sys from time import sleep import unittest - from pyiron_workflow._tests import ensure_tests_in_python_path from pyiron_workflow.channels import NOT_DATA from pyiron_workflow.function import function_node @@ -492,49 +491,48 @@ def test_storage_for_modified_macros(self): modified_result = macro() - macro.save() - reloaded = Macro.create.demo.AddThree( - label="m", storage_backend=backend - ) - self.assertDictEqual( - modified_result, - reloaded.outputs.to_value_dict(), - msg="Updated IO should have been (de)serialized" - ) - self.assertSetEqual( - set(macro.children.keys()), - set(reloaded.children.keys()), - msg="All nodes should have been (de)serialized." - ) # Note that this snags the _new_ one in the case of h5io! - self.assertEqual( - Macro.create.demo.AddThree.__name__, - reloaded.__class__.__name__, - msg=f"LOOK OUT! This all (de)serialized nicely, but what we " - f"loaded is _falsely_ claiming to be an " - f"{Macro.create.demo.AddThree.__name__}. This is " - f"not any sort of technical error -- what other class name " - f"would we load? -- but is a deeper problem with saving " - f"modified objects that we need ot figure out some better " - f"solution for later." - ) - rerun = reloaded() - if backend == "h5io": + with self.assertRaises( + TypeError, msg="h5io can't handle custom reconstructors" + ): + macro.save() + else: + macro.save() + reloaded = Macro.create.demo.AddThree( + label="m", storage_backend=backend + ) self.assertDictEqual( modified_result, - rerun, - msg="Rerunning should re-execute the _modified_ " - "functionality" + reloaded.outputs.to_value_dict(), + msg="Updated IO should have been (de)serialized" ) - elif backend == "tinybase": - self.assertDictEqual( - original_result, - rerun, - msg="Rerunning should re-execute the _original_ " - "functionality" + self.assertSetEqual( + set(macro.children.keys()), + set(reloaded.children.keys()), + msg="All nodes should have been (de)serialized." + ) # Note that this snags the _new_ one in the case of h5io! + self.assertEqual( + Macro.create.demo.AddThree.__name__, + reloaded.__class__.__name__, + msg=f"LOOK OUT! This all (de)serialized nicely, but what we " + f"loaded is _falsely_ claiming to be an " + f"{Macro.create.demo.AddThree.__name__}. 
This is "
+                                f"not any sort of technical error -- what other class name "
+                                f"would we load? -- but is a deeper problem with saving "
+                                f"modified objects that we need to figure out some better "
+                                f"solution for later."
+                        )
+                        rerun = reloaded()
+
+                        if backend == "tinybase":
+                            self.assertDictEqual(
+                                original_result,
+                                rerun,
+                                msg="Rerunning should re-execute the _original_ "
+                                    "functionality"
+                            )
+                        else:
+                            raise ValueError(f"Unexpected backend {backend}?")

                 finally:
                     macro.storage.delete()
@@ -561,6 +559,40 @@ def ReturnHasDot(macro):
             macro.foo = macro.create.standard.UserInput()
             return macro.foo.outputs.user_input

+    def test_pickle(self):
+        m = macro_node(add_three_macro)
+        m(1)
+        reloaded_m = pickle.loads(pickle.dumps(m))
+        self.assertTupleEqual(
+            m.child_labels,
+            reloaded_m.child_labels,
+            msg="Spot check values are getting reloaded correctly"
+        )
+        self.assertDictEqual(
+            m.outputs.to_value_dict(),
+            reloaded_m.outputs.to_value_dict(),
+            msg="Spot check values are getting reloaded correctly"
+        )
+        self.assertTrue(
+            reloaded_m.two.connected,
+            msg="The macro should reload with all its child connections"
+        )
+
+        self.assertTrue(m.two.connected, msg="Sanity check")
+        reloaded_two = pickle.loads(pickle.dumps(m.two))
+        self.assertFalse(
+            reloaded_two.connected,
+            msg="Children are expected to de-parent on serialization, so that if "
+                "we ship them off to another process, they don't drag their whole "
+                "graph with them"
+        )
+        self.assertEqual(
+            m.two.outputs.to_value_dict(),
+            reloaded_two.outputs.to_value_dict(),
+            msg="The remainder of the child node state should be recovering just "
+                "fine on (de)serialization; this is a spot-check"
+        )
+

 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py
index 6a66d9db..91337964 100644
--- a/tests/unit/test_workflow.py
+++ b/tests/unit/test_workflow.py
@@ -1,4 +1,5 @@
 from concurrent.futures import Future
+import pickle
 import sys
 from time import sleep
 import unittest
@@ -441,27 +442,34 @@ def add_three_macro(self, one__x):
     def test_storage_values(self):
         for backend in Workflow.allowed_backends():
             with self.subTest(backend):
-                wf = Workflow("wf", storage_backend=backend)
                 try:
+                    print("Trying", backend)
+                    wf = Workflow("wf", storage_backend=backend)
                     wf.register("static.demo_nodes", domain="demo")
                     wf.inp = wf.create.demo.AddThree(x=0)
                     wf.out = wf.inp.outputs.add_three + 1
                     wf_out = wf()
                     three_result = wf.inp.three.outputs.add.value
-                    wf.save()
-
-                    reloaded = Workflow("wf", storage_backend=backend)
-                    self.assertEqual(
-                        wf_out.out__add,
-                        reloaded.outputs.out__add.value,
-                        msg="Workflow-level data should get reloaded"
-                    )
-                    self.assertEqual(
-                        three_result,
-                        reloaded.inp.three.value,
-                        msg="Child data arbitrarily deep should get reloaded"
-                    )
+                    if backend == "h5io":
+                        with self.assertRaises(
+                            TypeError,
+                            msg="h5io can't handle custom reconstructors"
+                        ):
+                            wf.save()
+                    else:
+                        wf.save()
+                        reloaded = Workflow("wf", storage_backend=backend)
+                        self.assertEqual(
+                            wf_out.out__add,
+                            reloaded.outputs.out__add.value,
+                            msg="Workflow-level data should get reloaded"
+                        )
+                        self.assertEqual(
+                            three_result,
+                            reloaded.inp.three.value,
+                            msg="Child data arbitrarily deep should get reloaded"
+                        )
                 finally:
                     # Clean up after ourselves
                     wf.storage.delete()
@@ -479,9 +487,20 @@ def test_storage_scopes(self):
         for backend in Workflow.allowed_backends():
             with self.subTest(backend):
                 try:
-                    wf.storage_backend = backend
-                    wf.save()
-
Workflow(wf.label, storage_backend=backend) + for backend in Workflow.allowed_backends(): + if backend == "h5io": + with self.subTest(backend): + with self.assertRaises( + TypeError, + msg="h5io can't handle custom reconstructors" + ): + wf.storage_backend = backend + wf.save() + else: + with self.subTest(backend): + wf.storage_backend = backend + wf.save() + Workflow(wf.label, storage_backend=backend) finally: wf.storage.delete() @@ -499,24 +518,30 @@ def test_storage_scopes(self): wf.save() finally: wf.remove_child(wf.import_type_mismatch) + wf.storage.delete() if "h5io" in Workflow.allowed_backends(): wf.add_child(PlusOne(label="local_but_importable")) try: - wf.storage_backend = "h5io" - wf.save() - Workflow(wf.label, storage_backend="h5io") + with self.assertRaises( + TypeError, msg="h5io can't handle custom reconstructors" + ): + wf.storage_backend = "h5io" + wf.save() finally: wf.storage.delete() if "tinybase" in Workflow.allowed_backends(): - with self.assertRaises( - NotImplementedError, - msg="Storage docs for tinybase claim all children must be registered " - "nodes" - ): - wf.storage_backend = "tinybase" - wf.save() + try: + with self.assertRaises( + NotImplementedError, + msg="Storage docs for tinybase claim all children must be registered " + "nodes" + ): + wf.storage_backend = "tinybase" + wf.save() + finally: + wf.storage.delete() if "h5io" in Workflow.allowed_backends(): with self.subTest("Instanced node"): @@ -553,6 +578,19 @@ def UnimportableScope(x): wf.remove_child(wf.unimportable_scope) wf.storage.delete() + def test_pickle(self): + wf = Workflow("wf") + wf.register("static.demo_nodes", domain="demo") + wf.inp = wf.create.demo.AddThree(x=0) + wf.out = wf.inp.outputs.add_three + 1 + wf_out = wf() + reloaded = pickle.loads(pickle.dumps(wf)) + self.assertDictEqual( + wf_out, + reloaded.outputs.to_value_dict(), + msg="Pickling should work" + ) + if __name__ == '__main__': unittest.main()
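Taken together, the `classfactory` changes are what make the dynamically created node classes picklable: each class is registered under the wrapped function's name at module scope, so pickle can find it again on load instead of failing on a `<locals>` class. A minimal sketch of the round trip the new `test_pickle` cases exercise; `AddOne` is a hypothetical example function (it must live at module level or in `__main__`, per the `Node` docstring above), not something from this diff:

import pickle

from pyiron_workflow.function import as_function_node


@as_function_node("sum")
def AddOne(x: int = 0):
    # The decorator builds a fresh Function subclass named "AddOne" whose IO is
    # scraped from this signature and return expression.
    return x + 1


n = AddOne(x=5)
n()  # run the node so there is output state to serialize
reloaded = pickle.loads(pickle.dumps(n))
print(reloaded.outputs.to_value_dict())  # expected: {'sum': 6}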
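The `save()` / `get_input_node()` pair on `NodeOutputJob` shuttles the node through `cloudpickle` and `base64` so that the `DataContainer` only ever stores a plain string. The same round trip in isolation, with a stand-in payload in place of a real node:

import base64

import cloudpickle

payload = {"any": "cloudpickle-able object"}  # stand-in for self.input["node"]
# Encode: cloudpickle bytes, then base64, then decode to a storable str
encoded = base64.b64encode(cloudpickle.dumps(payload)).decode("utf-8")
# Decode: exactly what run_static() and get_input_node() do on the way back out
assert cloudpickle.loads(base64.b64decode(encoded)) == payload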
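The repeated `clear(...)  # Force a fresh class` calls in the decorators are there because the factory caches constructed classes in its `class_registry`; re-decorating an updated function of the same name would otherwise hand back the stale cached class. A sketch of that expected behavior with a stand-in `plus_one`; the identity semantics here are inferred from the `clear` docstring rather than demonstrated in this diff:

from pyiron_workflow.function import function_node_factory


def plus_one(x: int = 0):
    y = x + 1
    return y


# Signature is (node_function, validate_output_labels, /, *output_labels)
cls_a = function_node_factory(plus_one, True)
cls_b = function_node_factory(plus_one, True)
print(cls_a is cls_b)  # expected True: served from the registry cache
function_node_factory.clear(plus_one.__name__)
print(function_node_factory(plus_one, True) is cls_a)  # expected False: rebuilt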