Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[minor] Make Function IO info available at the class level #266

Merged
merged 12 commits into from
Apr 11, 2024
216 changes: 122 additions & 94 deletions pyiron_workflow/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from abc import ABC, abstractmethod
import inspect
import warnings
from functools import partialmethod
from typing import Any, get_args, get_type_hints, Literal, Optional, TYPE_CHECKING

from pyiron_workflow.channels import InputData, OutputData, NOT_DATA
from pyiron_workflow.has_interface_mixins import HasChannel
from pyiron_workflow.channels import InputData, NOT_DATA
from pyiron_workflow.injection import OutputDataWithInjection
from pyiron_workflow.io import Inputs, Outputs
from pyiron_workflow.node import Node
Expand Down Expand Up @@ -142,7 +140,7 @@ class AbstractFunction(Node, ABC):
variety of common use cases.
Note that getting "good" (i.e. dot-accessible) output labels can be achieved by
using good variable names and returning those variables instead of using
:attr:`output_labels`.
:param:`output_labels`.
If we try to assign a value of the wrong type, it will raise an error:

>>> from typing import Union
Expand Down Expand Up @@ -196,11 +194,13 @@ class AbstractFunction(Node, ABC):
both that you are likely to have particular nodes that get heavily re-used, and
that you need the nodes to pass data to each other.

For reusable nodes, we want to create a sub-class of :class:`Function` that fixes some
of the node behaviour -- usually the :meth:`node_function` and :attr:`output_labels`.
For reusable nodes, we want to create a sub-class of :class:`AbstractFunction`
that fixes some of the node behaviour -- i.e. the :meth:`node_function`.

This can be done most easily with the :func:`function_node` decorator, which takes a function
and returns a node class:
This can be done most easily with the :func:`function_node` decorator, which
takes a function and returns a node class. It also allows us to provide labels
for the return values, :param:`output_labels`, which are otherwise scraped from
the text of the function definition:

>>> from pyiron_workflow.function import function_node
>>>
Expand Down Expand Up @@ -308,6 +308,8 @@ class AbstractFunction(Node, ABC):
guaranteed.
"""

_provided_output_labels: tuple[str] | None = None

def __init__(
self,
*args,
Expand All @@ -317,7 +319,6 @@ def __init__(
run_after_init: bool = False,
storage_backend: Optional[Literal["h5io", "tinybase"]] = None,
save_after_run: bool = False,
output_labels: Optional[str | list[str] | tuple[str]] = None,
**kwargs,
):
super().__init__(
Expand All @@ -330,8 +331,6 @@ def __init__(

self._inputs = None
self._outputs = None
self._output_labels = self._get_output_labels(output_labels)
# TODO: Parse output labels from the node function in case output_labels is None

self.set_input_values(*args, **kwargs)

Expand All @@ -345,54 +344,98 @@ def _type_hints(cls) -> dict:
"""The result of :func:`typing.get_type_hints` on the :meth:`node_function`."""
return get_type_hints(cls.node_function)

def _get_output_labels(self, output_labels: str | list[str] | tuple[str] | None):
@classmethod
def preview_output_channels(cls) -> dict[str, Any]:
"""
Gives a class-level peek at the expected output channels.

Returns:
dict[str, Any]: The channel name and its corresponding type hint.
"""
labels = cls._get_output_labels()
try:
type_hints = cls._type_hints()["return"]
if len(labels) > 1:
type_hints = get_args(type_hints)
if not isinstance(type_hints, tuple):
raise TypeError(
f"With multiple return labels expected to get a tuple of type "
f"hints, but got type {type(type_hints)}"
)
if len(type_hints) != len(labels):
raise ValueError(
f"Expected type hints and return labels to have matching "
f"lengths, but got {len(type_hints)} hints and "
f"{len(labels)} labels: {type_hints}, {labels}"
)
else:
# If there's only one hint, wrap it in a tuple, so we can zip it with
# `labels` and iterate over both at once
type_hints = (type_hints,)
except KeyError: # If there are no return hints
type_hints = [None] * len(labels)
# Note that this nicely differs from `NoneType`, which is the hint when
# `None` is actually the hint!
return {label: hint for label, hint in zip(labels, type_hints)}

@classmethod
def _get_output_labels(cls):
"""
If output labels are provided, convert them to a list if passed as a
string and return them, else scrape them from the source channel.
Return output labels provided on the class if not None, else scrape them from
:meth:`node_function`.

Note: When the user explicitly provides output channels, they are taking
responsibility that these are correct, e.g. in terms of quantity, order, etc.
"""
if output_labels is None:
return self._scrape_output_labels()
elif isinstance(output_labels, str):
return [output_labels]
if cls._provided_output_labels is None:
return cls._scrape_output_labels()
else:
return output_labels
return cls._provided_output_labels

def _scrape_output_labels(self):
@classmethod
def _scrape_output_labels(cls):
"""
Inspect the source code to scrape out strings representing the returned values.
_Only_ works for functions with a single `return` expression in their body.
Inspect :meth:`node_function` to scrape out strings representing the
returned values.

Will return expressions and function calls just fine, thus best practice is to
create well-named variables and return those so that the output labels stay
_Only_ works for functions with a single `return` expression in their body.

It will return expressions and function calls just fine, thus good practice is
to create well-named variables and return those so that the output labels stay
dot-accessible.
"""
parsed_outputs = ParseOutput(self.node_function).output
parsed_outputs = ParseOutput(cls.node_function).output
return [None] if parsed_outputs is None else parsed_outputs

@property
def _input_args(self):
return inspect.signature(self.node_function).parameters

@property
def inputs(self) -> Inputs:
if self._inputs is None:
self._inputs = Inputs(*self._build_input_channels())
return self._inputs

@property
def outputs(self) -> Outputs:
if self._outputs is None:
self._outputs = Outputs(*self._build_output_channels(*self._output_labels))
self._outputs = Outputs(*self._build_output_channels())
return self._outputs

def _build_input_channels(self):
channels = []
type_hints = self._type_hints()
def _build_output_channels(self):
return [
OutputDataWithInjection(
label=label,
owner=self,
type_hint=hint,
)
for label, hint in self.preview_output_channels().items()
]

for ii, (label, value) in enumerate(self._input_args.items()):
@classmethod
def preview_input_channels(cls) -> dict[str, tuple[Any, Any]]:
"""
Gives a class-level peek at the expected input channels.

Returns:
dict[str, tuple[Any, Any]]: The channel name and a tuple of its
corresponding type hint and default value.
"""
type_hints = cls._type_hints()
scraped: dict[str, tuple[Any, Any]] = {}
for ii, (label, value) in enumerate(cls._input_args().items()):
is_self = False
if label == "self": # `self` is reserved for the node object
if ii == 0:
Expand All @@ -404,12 +447,12 @@ def _build_input_channels(self):
" argument. If it is to be treated as the node object,"
" use it as a first argument"
)
if label in self._init_keywords:
elif label in cls._init_keywords():
# We allow users to pass arbitrary kwargs as channel initialization
# So don't let them choose bad channel names
raise ValueError(
f"The Input channel name {label} is not valid. Please choose a "
f"name _not_ among {self._init_keywords}"
f"name _not_ among {cls._init_keywords()}"
)

try:
Expand All @@ -427,54 +470,33 @@ def _build_input_channels(self):
default = value.default

if not is_self:
channels.append(
InputData(
label=label,
owner=self,
default=default,
type_hint=type_hint,
)
)
return channels
scraped[label] = (type_hint, default)
return scraped

@classmethod
def _input_args(cls):
return inspect.signature(cls.node_function).parameters

@classmethod
def _init_keywords(cls):
return list(inspect.signature(cls.__init__).parameters.keys())

@property
def _init_keywords(self):
return list(inspect.signature(self.__init__).parameters.keys())
def inputs(self) -> Inputs:
if self._inputs is None:
self._inputs = Inputs(*self._build_input_channels())
return self._inputs

def _build_output_channels(self, *return_labels: str):
try:
type_hints = self._type_hints()["return"]
if len(return_labels) > 1:
type_hints = get_args(type_hints)
if not isinstance(type_hints, tuple):
raise TypeError(
f"With multiple return labels expected to get a tuple of type "
f"hints, but got type {type(type_hints)}"
)
if len(type_hints) != len(return_labels):
raise ValueError(
f"Expected type hints and return labels to have matching "
f"lengths, but got {len(type_hints)} hints and "
f"{len(return_labels)} labels: {type_hints}, {return_labels}"
)
else:
# If there's only one hint, wrap it in a tuple so we can zip it with
# *return_labels and iterate over both at once
type_hints = (type_hints,)
except KeyError:
type_hints = [None] * len(return_labels)

channels = []
for label, hint in zip(return_labels, type_hints):
channels.append(
OutputDataWithInjection(
label=label,
owner=self,
type_hint=hint,
)
def _build_input_channels(self):
return [
InputData(
label=label,
owner=self,
default=default,
type_hint=type_hint,
)

return channels
for label, (type_hint, default) in self.preview_input_channels().items()
]

@property
def on_run(self):
Expand All @@ -483,7 +505,7 @@ def on_run(self):
@property
def run_args(self) -> dict:
kwargs = self.inputs.to_value_dict()
if "self" in self._input_args:
if "self" in self._input_args():
if self.executor:
raise ValueError(
f"Function node {self.label} uses the `self` argument, but this "
Expand All @@ -504,7 +526,7 @@ def process_run_result(self, function_output: Any | tuple) -> Any | tuple:
return function_output

def _convert_input_args_and_kwargs_to_input_kwargs(self, *args, **kwargs):
reverse_keys = list(self._input_args.keys())[::-1]
reverse_keys = list(self._input_args().keys())[::-1]
if len(args) > len(reverse_keys):
raise ValueError(
f"Received {len(args)} positional arguments, but the node {self.label}"
Expand Down Expand Up @@ -603,7 +625,7 @@ def __new__(
run_after_init: bool = False,
storage_backend: Optional[Literal["h5io", "tinybase"]] = None,
save_after_run: bool = False,
output_labels: Optional[str | list[str] | tuple[str]] = None,
output_labels: Optional[str | tuple[str]] = None,
**kwargs,
):
if not callable(node_function):
Expand Down Expand Up @@ -670,17 +692,23 @@ def function_node(*output_labels: str):
# also slap them on as a class-level attribute. These get safely packed and returned
# when (de)pickling so we can keep processing type hints without trouble.
def as_node(node_function: callable):
return type(
node_class = type(
node_function.__name__,
(AbstractFunction,), # Define parentage
{
"__init__": partialmethod(
AbstractFunction.__init__,
output_labels=output_labels,
),
"node_function": staticmethod(node_function),
"_provided_output_labels": output_labels,
"__module__": node_function.__module__,
},
)
try:
node_class.preview_output_channels()
except ValueError as e:
raise ValueError(
f"Failed to create a new {AbstractFunction.__name__} child class "
f"dynamically from {node_function.__name__} -- probably due to a "
f"mismatch among output labels, returned values, and return type hints."
) from e
return node_class

return as_node
Loading
Loading