Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[minor] Use factories in existing nodes #303

Merged
merged 19 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 42 additions & 58 deletions pyiron_workflow/function.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Literal, Optional, TYPE_CHECKING
from typing import Any

from pyiron_workflow.io_preview import DecoratedNode, decorated_node_decorator_factory
from pyiron_workflow.io_preview import StaticNode, ScrapesIO
from pyiron_workflow.snippets.colors import SeabornColors
from pyiron_workflow.snippets.factory import classfactory

if TYPE_CHECKING:
from pyiron_workflow.composite import Composite


class Function(DecoratedNode, ABC):
class Function(StaticNode, ScrapesIO, ABC):
"""
Function nodes wrap an arbitrary python function.

Expand Down Expand Up @@ -347,64 +345,50 @@ def color(self) -> str:
return SeabornColors.green


as_function_node = decorated_node_decorator_factory(Function, Function.node_function)


def function_node(
node_function: callable,
*args,
label: Optional[str] = None,
parent: Optional[Composite] = None,
overwrite_save: bool = False,
run_after_init: bool = False,
storage_backend: Optional[Literal["h5io", "tinybase"]] = None,
save_after_run: bool = False,
output_labels: Optional[str | tuple[str]] = None,
validate_output_labels: bool = True,
**kwargs,
@classfactory
def function_node_factory(
node_function: callable, validate_output_labels: bool, /, *output_labels
):
"""
Dynamically creates a new child of :class:`Function` using the
provided :func:`node_function` and returns an instance of that.

Beyond the standard :class:`Function`, initialization allows the args...

Args:
node_function (callable): The function determining the behaviour of the node.
output_labels (Optional[str | list[str] | tuple[str]]): A name for each return
value of the node function OR a single label. (Default is None, which
scrapes output labels automatically from the source code of the wrapped
function.) This can be useful when returned values are not well named, e.g.
to make the output channel dot-accessible if it would otherwise have a label
that requires item-string-based access. Additionally, specifying a _single_
label for a wrapped function that returns a tuple of values ensures that a
_single_ output channel (holding the tuple) is created, instead of one
channel for each return value. The default approach of extracting labels
from the function source code also requires that the function body contain
_at most_ one `return` expression, so providing explicit labels can be used
to circumvent this (at your own risk), or to circumvent un-inspectable
source code (e.g. a function that exists only in memory).
"""
return (
node_function.__name__,
(Function,), # Define parentage
{
"node_function": staticmethod(node_function),
"__module__": node_function.__module__,
"_output_labels": None if len(output_labels) == 0 else output_labels,
"_validate_output_labels": validate_output_labels,
},
{},
)

if not callable(node_function):
raise AttributeError(
f"Expected `node_function` to be callable but got {node_function}"

def as_function_node(*output_labels, validate_output_labels=True):
def decorator(node_function):
function_node_factory.clear(node_function.__name__) # Force a fresh class
factory_made = function_node_factory(
node_function, validate_output_labels, *output_labels
)
factory_made._class_returns_from_decorated_function = node_function
factory_made.preview_io()
return factory_made

return decorator


def function_node(
node_function,
*node_args,
output_labels=None,
validate_output_labels=True,
**node_kwargs,
):
if output_labels is None:
output_labels = ()
elif isinstance(output_labels, str):
output_labels = (output_labels,)

return as_function_node(
*output_labels, validate_output_labels=validate_output_labels
)(node_function)(
*args,
label=label,
parent=parent,
overwrite_save=overwrite_save,
run_after_init=run_after_init,
storage_backend=storage_backend,
save_after_run=save_after_run,
**kwargs,
function_node_factory.clear(node_function.__name__) # Force a fresh class
factory_made = function_node_factory(
node_function, validate_output_labels, *output_labels
)
factory_made.preview_io()
return factory_made(*node_args, **node_kwargs)
131 changes: 15 additions & 116 deletions pyiron_workflow/io_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@
from functools import lru_cache, wraps
from textwrap import dedent
from types import FunctionType
from typing import Any, get_args, get_type_hints, Literal, Optional, TYPE_CHECKING
from typing import (
Any,
ClassVar,
get_args,
get_type_hints,
Literal,
Optional,
TYPE_CHECKING,
)

from pyiron_workflow.channels import InputData, NOT_DATA
from pyiron_workflow.injection import OutputDataWithInjection, OutputsWithInjection
Expand Down Expand Up @@ -128,11 +136,13 @@ class ScrapesIO(HasIOPreview, ABC):
@classmethod
@abstractmethod
def _io_defining_function(cls) -> callable:
"""Must return a static class method."""
"""Must return a static method."""

_output_labels: tuple[str] | None = None # None: scrape them
_validate_output_labels: bool = True # True: validate against source code
_io_defining_function_uses_self: bool = False # False: use entire signature
_output_labels: ClassVar[tuple[str] | None] = None # None: scrape them
_validate_output_labels: ClassVar[bool] = True # True: validate against source code
_io_defining_function_uses_self: ClassVar[bool] = (
False # False: use entire signature
)

@classmethod
def _build_inputs_preview(cls):
Expand Down Expand Up @@ -354,114 +364,3 @@ def inputs(self) -> Inputs:
@property
def outputs(self) -> OutputsWithInjection:
return self._outputs


class DecoratedNode(StaticNode, ScrapesIO, ABC):
"""
A static node whose IO is defined by a function's information (and maybe output
labels).
"""


def decorated_node_decorator_factory(
parent_class: type[DecoratedNode],
io_static_method: callable,
decorator_docstring_additions: str = "",
**parent_class_attr_overrides,
):
"""
A decorator factory for building decorators to dynamically create new subclasses
of some subclass of :class:`DecoratedNode` using the function they decorate.

New classes get their class name and module set using the decorated function's
name and module.

Args:
parent_class (type[DecoratedNode]): The base class for the new node class.
io_static_method: The static method on the :param:`parent_class` which will
store the io-defining function the resulting decorator will decorate.
:param:`parent_class` must override :meth:`_io_defining_function` inherited
from :class:`DecoratedNode` to return this method. This allows
:param:`parent_class` classes to have unique names for their io-defining
functions.
decorator_docstring_additions (str): Any extra text to add between the main
body of the docstring and the arguments.
**parent_class_attr_overrides: Any additional attributes to pass to the new,
dynamically created class created by the resulting decorator.

Returns:
(callable): A decorator that takes creates a new subclass of
:param:`parent_class` that uses the wrapped function as the return value of
:meth:`_io_defining_function` for the :class:`DecoratedNode` mixin.
"""
if getattr(parent_class, io_static_method.__name__) is not io_static_method:
raise ValueError(
f"{io_static_method.__name__} is not a method on {parent_class}"
)
if not isinstance(io_static_method, FunctionType):
raise TypeError(f"{io_static_method.__name__} should be a static method")

def as_decorated_node_decorator(
*output_labels: str,
validate_output_labels: bool = True,
):
output_labels = None if len(output_labels) == 0 else output_labels

@builds_class_io
def as_decorated_node(io_defining_function: callable):
if not callable(io_defining_function):
raise AttributeError(
f"Tried to create a new child class of {parent_class.__name__}, "
f"but got {io_defining_function} instead of a callable."
)

return type(
io_defining_function.__name__,
(parent_class,), # Define parentage
{
io_static_method.__name__: staticmethod(io_defining_function),
"__module__": io_defining_function.__module__,
"_output_labels": output_labels,
"_validate_output_labels": validate_output_labels,
**parent_class_attr_overrides,
},
)

return as_decorated_node

as_decorated_node_decorator.__doc__ = dedent(
f"""
A decorator for dynamically creating `{parent_class.__name__}` sub-classes by
wrapping a function as the `{io_static_method.__name__}`.

The returned subclass uses the wrapped function (and optionally any provided
:param:`output_labels`) to specify its IO.

{decorator_docstring_additions}

Args:
*output_labels (str): A name for each return value of the graph creating
function. When empty, scrapes output labels automatically from the
source code of the wrapped function. This can be useful when returned
values are not well named, e.g. to make the output channel
dot-accessible if it would otherwise have a label that requires
item-string-based access. Additionally, specifying a _single_ label for
a wrapped function that returns a tuple of values ensures that a
_single_ output channel (holding the tuple) is created, instead of one
channel for each return value. The default approach of extracting
labels from the function source code also requires that the function
body contain _at most_ one `return` expression, so providing explicit
labels can be used to circumvent this (at your own risk). (Default is
empty, try to scrape labels from the source code of the wrapped
function.)
validate_output_labels (bool): Whether to compare the provided output labels
(if any) against the source code (if available). (Default is True.)

Returns:
(callable[[callable], type[{parent_class.__name__}]]): A decorator that
transforms a function into a child class of `{parent_class.__name__}`
using the decorated function as
`{parent_class.__name__}.{io_static_method.__name__}`.
"""
)
return as_decorated_node_decorator
33 changes: 31 additions & 2 deletions pyiron_workflow/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@

from __future__ import annotations

import base64
import inspect
import os
import sys

import cloudpickle

from pyiron_base import TemplateJob, JOB_CLASS_DICT
from pyiron_base.jobs.flex.pythonfunctioncontainer import (
PythonFunctionContainerJob,
Expand Down Expand Up @@ -103,25 +106,51 @@ def validate_ready_to_run(self):
f"Node not ready:{nl}{self.input['node'].readiness_report}"
)

def save(self):
# DataContainer can't handle custom reconstructors, so convert the node to
# bytestream
self.input["node"] = base64.b64encode(
cloudpickle.dumps(self.input["node"])
).decode("utf-8")
super().save()

def run_static(self):
# Overrides the parent method
# Copy and paste except for the output update, which makes sure the output is
# flat and not tested beneath "result"

# Unpack the node
input_dict = self.input.to_builtin()
input_dict["node"] = cloudpickle.loads(base64.b64decode(self.input["node"]))

if (
self._executor_type is not None
and "executor" in inspect.signature(self._function).parameters.keys()
):
input_dict = self.input.to_builtin()
del input_dict["executor"]
output = self._function(
**input_dict, executor=self._get_executor(max_workers=self.server.cores)
)
else:
output = self._function(**self.input.to_builtin())
output = self._function(**input_dict)
self.output.update(output) # DIFFERS FROM PARENT METHOD
self.to_hdf()
self.status.finished = True

def get_input_node(self):
"""
On saving, we turn the input node into a bytestream so that the DataContainer
can store it. You might want to look at it again though, so you can use this
to unpack it

Returns:
(Node): The input node as a node again
"""
if isinstance(self.input["node"], Node):
return self.input["node"]
else:
return cloudpickle.loads(base64.b64decode(self.input["node"]))


JOB_CLASS_DICT[NodeOutputJob.__name__] = NodeOutputJob.__module__

Expand Down
Loading
Loading