Skip to content

Commit

Permalink
Merge pull request #53 from pyiron/fix_node_registration
Browse files Browse the repository at this point in the history
  • Loading branch information
liamhuber authored Oct 31, 2023
2 parents 091abf1 + 23b8827 commit bb79722
Show file tree
Hide file tree
Showing 17 changed files with 436 additions and 124 deletions.
84 changes: 50 additions & 34 deletions notebooks/workflow_example.ipynb

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions pyiron_workflow/_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Tools specifically for the test suite, not intended for general use.
"""

from pathlib import Path
import sys


def ensure_tests_in_python_path():
"""So that you can import from the static module"""
path_to_tests = Path(__file__).parent.parent / "tests"
as_string = str(path_to_tests.resolve())

if as_string not in sys.path:
sys.path.append(as_string)
1 change: 1 addition & 0 deletions pyiron_workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def to_dict(self) -> dict:
d = super().to_dict()
d["value"] = repr(self.value)
d["ready"] = self.ready
d["type_hint"] = str(self.type_hint)
return d


Expand Down
7 changes: 6 additions & 1 deletion pyiron_workflow/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import partial
from functools import partial, wraps
from typing import Literal, Optional, TYPE_CHECKING

from bidict import bidict
Expand Down Expand Up @@ -422,6 +422,11 @@ def replace(self, owned_node: Node | str, replacement: Node | type[Node]) -> Nod
self.starting_nodes.append(replacement)
return owned_node

@classmethod
@wraps(Creator.register)
def register(cls, domain: str, package_identifier: str) -> None:
cls.create.register(domain=domain, package_identifier=package_identifier)

def __setattr__(self, key: str, node: Node):
if isinstance(node, Node) and key != "parent":
self.add(node, label=key)
Expand Down
78 changes: 48 additions & 30 deletions pyiron_workflow/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def __init__(
else:
# If a callable node function is received, use it
self.node_function = node_function
self._type_hints = get_type_hints(node_function)

super().__init__(
label=label if label is not None else self.node_function.__name__,
Expand Down Expand Up @@ -382,7 +383,7 @@ def outputs(self) -> Outputs:

def _build_input_channels(self):
channels = []
type_hints = get_type_hints(self.node_function)
type_hints = self._type_hints

for ii, (label, value) in enumerate(self._input_args.items()):
is_self = False
Expand Down Expand Up @@ -435,7 +436,7 @@ def _init_keywords(self):

def _build_output_channels(self, *return_labels: str):
try:
type_hints = get_type_hints(self.node_function)["return"]
type_hints = self._type_hints["return"]
if len(return_labels) > 1:
type_hints = get_args(type_hints)
if not isinstance(type_hints, tuple):
Expand Down Expand Up @@ -607,7 +608,13 @@ def __getitem__(self, item):
return self.single_value.__getitem__(item)

def __getattr__(self, item):
return getattr(self.single_value, item)
try:
return getattr(self.single_value, item)
except Exception as e:
raise AttributeError(
f"Could not find {item} as an attribute of the single value "
f"{self.single_value}"
) from e

def __repr__(self):
return self.single_value.__repr__()
Expand All @@ -618,35 +625,61 @@ def __str__(self):
)


def function_node(output_labels=None):
def _wrapper_factory(
parent_class: type[Function], output_labels: Optional[list[str]]
) -> callable:
"""
A decorator for dynamically creating node classes from functions.
Decorates a function.
Returns a `Function` subclass whose name is the camel-case version of the function
node, and whose signature is modified to exclude the node function and output labels
(which are explicitly defined in the process of using the decorator).
Optionally takes any keyword arguments of `Function`.
An abstract base for making decorators that wrap a function as `Function` or its
children.
"""

# One really subtle thing is that we manually parse the function type hints right
# here and include these as a class-level attribute.
# This is because on (de)(cloud)pickling a function node, somehow the node function
# method attached to it gets its `__globals__` attribute changed; it retains stuff
# _inside_ the function, but loses imports it used from the _outside_ -- i.e. type
# hints! I (@liamhuber) don't deeply understand _why_ (de)pickling is modifying the
# __globals__ in this way, but the result is that type hints cannot be parsed after
# the change.
# The final piece of the puzzle here is that because the node function is a _class_
# level attribute, if you (de)pickle a node, _new_ instances of that node wind up
# having their node function's `__globals__` trimmed down in this way!
# So to keep the type hint parsing working, we snag and interpret all the type hints
# at wrapping time, when we are guaranteed to have all the globals available, and
# also slap them on as a class-level attribute. These get safely packed and returned
# when (de)pickling so we can keep processing type hints without trouble.
def as_node(node_function: callable):
return type(
node_function.__name__.title().replace("_", ""), # fnc_name to CamelCase
(Function,), # Define parentage
(parent_class,), # Define parentage
{
"__init__": partialmethod(
Function.__init__,
parent_class.__init__,
None,
output_labels=output_labels,
),
"node_function": staticmethod(node_function),
"_type_hints": get_type_hints(node_function),
},
)

return as_node


def function_node(output_labels=None):
"""
A decorator for dynamically creating node classes from functions.
Decorates a function.
Returns a `Function` subclass whose name is the camel-case version of the function
node, and whose signature is modified to exclude the node function and output labels
(which are explicitly defined in the process of using the decorator).
Optionally takes any keyword arguments of `Function`.
"""
return _wrapper_factory(parent_class=Function, output_labels=output_labels)


def single_value_node(output_labels=None):
"""
A decorator for dynamically creating fast node classes from functions.
Expand All @@ -655,19 +688,4 @@ def single_value_node(output_labels=None):
Optionally takes any keyword arguments of `SingleValueNode`.
"""

def as_single_value_node(node_function: callable):
return type(
node_function.__name__.title().replace("_", ""), # fnc_name to CamelCase
(SingleValue,), # Define parentage
{
"__init__": partialmethod(
SingleValue.__init__,
None,
output_labels=output_labels,
),
"node_function": staticmethod(node_function),
},
)

return as_single_value_node
return _wrapper_factory(parent_class=SingleValue, output_labels=output_labels)
146 changes: 118 additions & 28 deletions pyiron_workflow/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from importlib import import_module
from sys import version_info

from pyiron_base.interfaces.singleton import Singleton

Expand All @@ -18,9 +19,6 @@
single_value_node,
)

if TYPE_CHECKING:
from pyiron_workflow.node import Node


class Creator(metaclass=Singleton):
"""
Expand All @@ -30,6 +28,8 @@ class Creator(metaclass=Singleton):
"""

def __init__(self):
self._node_packages = {}

self.Executor = Executor

self.Function = Function
Expand All @@ -40,6 +40,13 @@ def __init__(self):
self._workflow = None
self._meta = None

if version_info[0] == 3 and version_info[1] >= 10:
# These modules use syntactic sugar for type hinting that is only supported
# in python >=3.10
# If the CI skips testing on 3.9 gets dropped, we can think about removing
# this if-clause and just letting users of python <3.10 hit an error.
self.register("standard", "pyiron_workflow.node_library.standard")

@property
def Macro(self):
if self._macro is None:
Expand All @@ -56,20 +63,6 @@ def Workflow(self):
self._workflow = Workflow
return self._workflow

@property
def standard(self):
from pyiron_workflow.node_package import NodePackage
from pyiron_workflow.node_library.standard import nodes

return NodePackage(*nodes)

@property
def atomistics(self):
from pyiron_workflow.node_package import NodePackage
from pyiron_workflow.node_library.atomistics import nodes

return NodePackage(*nodes)

@property
def meta(self):
if self._meta is None:
Expand All @@ -78,16 +71,113 @@ def meta(self):
self._meta = meta_nodes
return self._meta

def register(self, domain: str, *nodes: list[type[Node]]):
raise NotImplementedError(
"Registering new node packages is currently not playing well with "
"executors. We hope to return this feature soon."
)
# if domain in self.__dir__():
# raise AttributeError(f"{domain} is already an attribute of {self}")
# from pyiron_workflow.node_package import NodePackage
#
# setattr(self, domain, NodePackage(*nodes))
def __getattr__(self, item):
try:
module = import_module(self._node_packages[item])
from pyiron_workflow.node_package import NodePackage

return NodePackage(*module.nodes)
except KeyError as e:
raise AttributeError(
f"{self.__class__.__name__} could not find attribute {item} -- did you "
f"forget to register node package to this key?"
) from e

def __getstate__(self):
return self.__dict__

def __setstate__(self, state):
self.__dict__ = state

def register(self, domain: str, package_identifier: str) -> None:
"""
Add a new package of nodes under the provided attribute, e.g. after adding
nodes to the domain `"my_nodes"`, and instance of creator can call things like
`creator.my_nodes.some_node_that_is_there()`.
Note: If a macro is going to use a creator, the node registration should be
_inside_ the macro definition to make sure the node actually has access to
those nodes! It also needs to be _able_ to register those nodes, i.e. have
import access to that location, but we don't for that check that.
Args:
domain (str): The attribute name at which to register the new package.
(Note: no sanitizing is done here, so if you provide a string that
won't work as an attribute name, that's your problem.)
package_identifier (str): An identifier for the node package. (Right now
that's just a string version of the path to the module, e.g.
`pyiron_workflow.node_library.standard`.)
Raises:
KeyError: If the domain already exists, but the identifier doesn't match
with the stored identifier.
AttributeError: If you try to register at a domain that is already another
method or attribute of the creator.
ValueError: If the identifier can't be parsed.
"""

if self._package_conflicts_with_existing(domain, package_identifier):
raise KeyError(
f"{domain} is already a registered node package, please choose a "
f"different domain to store these nodes under"
)
elif domain in self.__dir__():
raise AttributeError(f"{domain} is already an attribute of {self}")

self._verify_identifier(package_identifier)

self._node_packages[domain] = package_identifier

def _package_conflicts_with_existing(
self, domain: str, package_identifier: str
) -> bool:
"""
Check if the new package conflict with an existing package at the requested
domain; if there isn't one, or if the new and old packages are identical then
there is no conflict!
Args:
domain (str): The domain at which the new package is attempting to register.
package_identifier (str): The identifier for the new package.
Returns:
(bool): True iff there is a package already at that domain and it is not
the same as the new one.
"""
if domain in self._node_packages.keys():
# If it's already here, it had better be the same package
return package_identifier != self._node_packages[domain]
# We can make "sameness" logic more complex as we allow more sophisticated
# identifiers
else:
# If it's not here already, it can't conflict!
return False

@staticmethod
def _verify_identifier(package_identifier: str):
"""
Logic for verifying whether new package identifiers will actually be usable for
creating node packages when their domain is called. Lets us fail early in
registration.
Right now, we just make sure it's a string from which we can import a list of
nodes.
"""
from pyiron_workflow.node import Node

try:
module = import_module(package_identifier)
nodes = module.nodes
if not all(issubclass(node, Node) for node in nodes):
raise TypeError(
f"At least one node in {nodes} was not of the type {Node.__name__}"
)
except Exception as e:
raise ValueError(
f"The package identifier is {package_identifier} is not valid. Please "
f"ensure it is an importable module with a list of {Node.__name__} "
f"objects stored in the variable `nodes`."
) from e


class Wrappers(metaclass=Singleton):
Expand Down
25 changes: 25 additions & 0 deletions pyiron_workflow/node_library/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
For graphical representations of data.
"""

from __future__ import annotations

from typing import Optional

import numpy as np

from pyiron_workflow.function import single_value_node


@single_value_node(output_labels="fig")
def scatter(
x: Optional[list | np.ndarray] = None, y: Optional[list | np.ndarray] = None
):
from matplotlib import pyplot as plt

return plt.scatter(x, y)


nodes = [
scatter,
]
Loading

0 comments on commit bb79722

Please sign in to comment.