Fix node registration #53

Merged · 40 commits · Oct 31, 2023

Commits (40)
9436f82  Raise from the original error (liamhuber, Oct 26, 2023)
9630641  Add the type hint to channel presentation (liamhuber, Oct 26, 2023)
359ad26  Store the type hints right on the decorator-created classes (liamhuber, Oct 26, 2023)
5dffb5b  Store a reference to the import path (liamhuber, Oct 26, 2023)
f551f59  Pre-register the normal packages (liamhuber, Oct 26, 2023)
7ff73d7  Handle the absence of return hints more gracefully (liamhuber, Oct 26, 2023)
2960d6f  Refactor the function node decorators (liamhuber, Oct 26, 2023)
b607c87  Format black (pyiron-runner, Oct 26, 2023)
8040316  Add a unit test for registration (liamhuber, Oct 27, 2023)
cb43cc8  Refactor tests (liamhuber, Oct 27, 2023)
b1dda61  Allow re-registering the _same thing_ to the same place (liamhuber, Oct 27, 2023)
7b9837d  Expand tests (liamhuber, Oct 27, 2023)
4dbcf22  Merge remote-tracking branch 'origin/fix_node_registration' into fix_… (liamhuber, Oct 27, 2023)
7851e7b  Refactor: rename (liamhuber, Oct 27, 2023)
63f30c0  Add docstring (liamhuber, Oct 27, 2023)
0328c22  Fail early if the provided package identifier will fail (liamhuber, Oct 27, 2023)
f43958e  Refactor: extract the "just re-registering" logic (liamhuber, Oct 27, 2023)
3136905  Format black (pyiron-runner, Oct 27, 2023)
20fa4ae  Version guard package registration (liamhuber, Oct 30, 2023)
08d3dc0  Explain intent for future devs (liamhuber, Oct 30, 2023)
8059b57  Be more generous waiting for the timeout (liamhuber, Oct 30, 2023)
d058cd0  Give a shortcut to node registration (liamhuber, Oct 30, 2023)
39dec3f  Add an integration test for the thing that originally hurt us (liamhuber, Oct 30, 2023)
4289210  Format black (pyiron-runner, Oct 30, 2023)
bdcb100  Make standard node package really standard (liamhuber, Oct 30, 2023)
b00aedc  Don't register atomistics by default (liamhuber, Oct 30, 2023)
9e45028  Purge atomistics module from unit tests (liamhuber, Oct 30, 2023)
0ffd8c9  Give the demo nodes complex typing (liamhuber, Oct 30, 2023)
3c2433b  Make it easier to ensure you can import the static demo nodes (liamhuber, Oct 30, 2023)
33474ed  Use a node name that doesn't conflict with Composite methods (liamhuber, Oct 30, 2023)
44e3a07  Use demo instead of atomistics nodes (liamhuber, Oct 30, 2023)
19110a9  Make the registration shortcut a class method (liamhuber, Oct 30, 2023)
664b06b  Use the demo nodes instead of atomistics (liamhuber, Oct 30, 2023)
6e4de97  Merge remote-tracking branch 'origin/fix_node_registration' into fix_… (liamhuber, Oct 30, 2023)
7e20641  :bug: fix node name typo (liamhuber, Oct 30, 2023)
6189e02  :bug: when you fix it use the right damned name (liamhuber, Oct 30, 2023)
8722e33  PEP8 whitespace (liamhuber, Oct 30, 2023)
00f8690  Update and rerun example notebook (liamhuber, Oct 30, 2023)
d8695ae  Format black (pyiron-runner, Oct 30, 2023)
23b8827  Add a note on registration and macros (liamhuber, Oct 30, 2023)
84 changes: 50 additions & 34 deletions notebooks/workflow_example.ipynb

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions pyiron_workflow/_tests.py
@@ -0,0 +1,15 @@
"""
Tools specifically for the test suite, not intended for general use.
"""

from pathlib import Path
import sys


def ensure_tests_in_python_path():
"""So that you can import from the static module"""
path_to_tests = Path(__file__).parent.parent / "tests"
as_string = str(path_to_tests.resolve())

if as_string not in sys.path:
sys.path.append(as_string)
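
A minimal usage sketch, assuming importable test resources live under `tests/` (the `static.demo_nodes` module name is hypothetical):

from pyiron_workflow._tests import ensure_tests_in_python_path

ensure_tests_in_python_path()  # appends <repo>/tests to sys.path if it isn't there yet
from static import demo_nodes  # hypothetical module under tests/, importable only after the call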
1 change: 1 addition & 0 deletions pyiron_workflow/channels.py
@@ -398,6 +398,7 @@ def to_dict(self) -> dict:
d = super().to_dict()
d["value"] = repr(self.value)
d["ready"] = self.ready
d["type_hint"] = str(self.type_hint)
return d


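For a sense of the effect (the channel and its value are made up), the serialized representation of a data channel hinted as `int | float` and holding `5` would now include roughly:

# Rough illustration of channel.to_dict() after this change (other keys elided)
{
    "value": "5",                # repr of the current value
    "ready": True,
    "type_hint": "int | float",  # new: string form of the channel's type hint
}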
7 changes: 6 additions & 1 deletion pyiron_workflow/composite.py
@@ -6,7 +6,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import partial
from functools import partial, wraps
from typing import Literal, Optional, TYPE_CHECKING

from bidict import bidict
@@ -422,6 +422,11 @@ def replace(self, owned_node: Node | str, replacement: Node | type[Node]) -> Nod
self.starting_nodes.append(replacement)
return owned_node

@classmethod
@wraps(Creator.register)
def register(cls, domain: str, package_identifier: str) -> None:
cls.create.register(domain=domain, package_identifier=package_identifier)

def __setattr__(self, key: str, node: Node):
if isinstance(node, Node) and key != "parent":
self.add(node, label=key)
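Because the new classmethod just forwards to `Creator.register`, packages can now be registered straight from a workflow class. A minimal sketch, assuming `Workflow` is exported at the package top level and using a hypothetical package path and node name:

from pyiron_workflow import Workflow  # assuming the top-level export exists

# Hypothetical package path; it must be importable and expose a `nodes` list
Workflow.register("demo", "my_project.demo_nodes")

wf = Workflow("example")
wf.my_node = wf.create.demo.SomeDemoNode()  # hypothetical node from the registered package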
78 changes: 48 additions & 30 deletions pyiron_workflow/function.py
@@ -322,6 +322,7 @@ def __init__(
else:
# If a callable node function is received, use it
self.node_function = node_function
self._type_hints = get_type_hints(node_function)

super().__init__(
label=label if label is not None else self.node_function.__name__,
@@ -382,7 +383,7 @@ def outputs(self) -> Outputs:

def _build_input_channels(self):
channels = []
type_hints = get_type_hints(self.node_function)
type_hints = self._type_hints

for ii, (label, value) in enumerate(self._input_args.items()):
is_self = False
@@ -435,7 +436,7 @@ def _init_keywords(self):

def _build_output_channels(self, *return_labels: str):
try:
type_hints = get_type_hints(self.node_function)["return"]
type_hints = self._type_hints["return"]
if len(return_labels) > 1:
type_hints = get_args(type_hints)
if not isinstance(type_hints, tuple):
@@ -607,7 +608,13 @@ def __getitem__(self, item):
return self.single_value.__getitem__(item)

def __getattr__(self, item):
return getattr(self.single_value, item)
try:
return getattr(self.single_value, item)
except Exception as e:
raise AttributeError(
f"Could not find {item} as an attribute of the single value "
f"{self.single_value}"
) from e

def __repr__(self):
return self.single_value.__repr__()
@@ -618,35 +625,61 @@ def __str__(self):
)


def function_node(output_labels=None):
def _wrapper_factory(
parent_class: type[Function], output_labels: Optional[list[str]]
) -> callable:
"""
A decorator for dynamically creating node classes from functions.

Decorates a function.
Returns a `Function` subclass whose name is the camel-case version of the function
node, and whose signature is modified to exclude the node function and output labels
(which are explicitly defined in the process of using the decorator).

Optionally takes any keyword arguments of `Function`.
An abstract base for making decorators that wrap a function as `Function` or its
children.
"""

# One really subtle thing is that we manually parse the function type hints right
# here and include these as a class-level attribute.
# This is because on (de)(cloud)pickling a function node, somehow the node function
# method attached to it gets its `__globals__` attribute changed; it retains stuff
# _inside_ the function, but loses imports it used from the _outside_ -- i.e. type
# hints! I (@liamhuber) don't deeply understand _why_ (de)pickling is modifying the
# __globals__ in this way, but the result is that type hints cannot be parsed after
# the change.
# The final piece of the puzzle here is that because the node function is a _class_
# level attribute, if you (de)pickle a node, _new_ instances of that node wind up
# having their node function's `__globals__` trimmed down in this way!
# So to keep the type hint parsing working, we snag and interpret all the type hints
# at wrapping time, when we are guaranteed to have all the globals available, and
# also slap them on as a class-level attribute. These get safely packed and returned
# when (de)pickling so we can keep processing type hints without trouble.
def as_node(node_function: callable):
return type(
node_function.__name__.title().replace("_", ""), # fnc_name to CamelCase
(Function,), # Define parentage
(parent_class,), # Define parentage
{
"__init__": partialmethod(
Function.__init__,
parent_class.__init__,
None,
output_labels=output_labels,
),
"node_function": staticmethod(node_function),
"_type_hints": get_type_hints(node_function),
},
)

return as_node


def function_node(output_labels=None):
"""
A decorator for dynamically creating node classes from functions.

Decorates a function.
Returns a `Function` subclass whose name is the camel-case version of the function
node, and whose signature is modified to exclude the node function and output labels
(which are explicitly defined in the process of using the decorator).

Optionally takes any keyword arguments of `Function`.
"""
return _wrapper_factory(parent_class=Function, output_labels=output_labels)


def single_value_node(output_labels=None):
"""
A decorator for dynamically creating fast node classes from functions.
Expand All @@ -655,19 +688,4 @@ def single_value_node(output_labels=None):

Optionally takes any keyword arguments of `SingleValueNode`.
"""

def as_single_value_node(node_function: callable):
return type(
node_function.__name__.title().replace("_", ""), # fnc_name to CamelCase
(SingleValue,), # Define parentage
{
"__init__": partialmethod(
SingleValue.__init__,
None,
output_labels=output_labels,
),
"node_function": staticmethod(node_function),
},
)

return as_single_value_node
return _wrapper_factory(parent_class=SingleValue, output_labels=output_labels)
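
To see what the shared factory produces, a minimal sketch of decorating a function (the function itself is made up for illustration):

from typing import Optional

from pyiron_workflow.function import single_value_node


@single_value_node(output_labels="y")
def add_one(x: Optional[int] = None) -> int:
    return (x if x is not None else 0) + 1


# `add_one` is now a dynamically created `SingleValue` subclass named "AddOne".
# Its type hints were parsed once, at decoration time, and stored on the class, so
# hint-dependent machinery keeps working even if the wrapped function's
# __globals__ get trimmed during (de)pickling.
print(add_one._type_hints)  # e.g. {'x': typing.Optional[int], 'return': <class 'int'>}
node = add_one()  # instantiate as usual; inputs get a channel "x" and output "y"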
146 changes: 118 additions & 28 deletions pyiron_workflow/interfaces.py
@@ -4,7 +4,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from importlib import import_module
from sys import version_info

from pyiron_base.interfaces.singleton import Singleton

@@ -18,9 +19,6 @@
single_value_node,
)

if TYPE_CHECKING:
from pyiron_workflow.node import Node


class Creator(metaclass=Singleton):
"""
@@ -30,6 +28,8 @@ class Creator(metaclass=Singleton):
"""

def __init__(self):
self._node_packages = {}

self.Executor = Executor

self.Function = Function
@@ -40,6 +40,13 @@ def __init__(self):
self._workflow = None
self._meta = None

if version_info[0] == 3 and version_info[1] >= 10:
# These modules use syntactic sugar for type hinting that is only supported
# in python >=3.10
# If CI testing on 3.9 gets dropped, we can think about removing
# this if-clause and just letting users of python <3.10 hit an error.
self.register("standard", "pyiron_workflow.node_library.standard")

@property
def Macro(self):
if self._macro is None:
@@ -56,20 +63,6 @@ def Workflow(self):
self._workflow = Workflow
return self._workflow

@property
def standard(self):
from pyiron_workflow.node_package import NodePackage
from pyiron_workflow.node_library.standard import nodes

return NodePackage(*nodes)

@property
def atomistics(self):
from pyiron_workflow.node_package import NodePackage
from pyiron_workflow.node_library.atomistics import nodes

return NodePackage(*nodes)

@property
def meta(self):
if self._meta is None:
@@ -78,16 +71,113 @@ def meta(self):
self._meta = meta_nodes
return self._meta

def register(self, domain: str, *nodes: list[type[Node]]):
raise NotImplementedError(
"Registering new node packages is currently not playing well with "
"executors. We hope to return this feature soon."
)
# if domain in self.__dir__():
# raise AttributeError(f"{domain} is already an attribute of {self}")
# from pyiron_workflow.node_package import NodePackage
#
# setattr(self, domain, NodePackage(*nodes))
def __getattr__(self, item):
try:
module = import_module(self._node_packages[item])
from pyiron_workflow.node_package import NodePackage

return NodePackage(*module.nodes)
except KeyError as e:
raise AttributeError(
f"{self.__class__.__name__} could not find attribute {item} -- did you "
f"forget to register node package to this key?"
) from e

def __getstate__(self):
return self.__dict__

def __setstate__(self, state):
self.__dict__ = state

def register(self, domain: str, package_identifier: str) -> None:
"""
Add a new package of nodes under the provided attribute, e.g. after adding
nodes to the domain `"my_nodes"`, an instance of the creator can call things like
`creator.my_nodes.some_node_that_is_there()`.

Note: If a macro is going to use a creator, the node registration should be
_inside_ the macro definition to make sure the macro actually has access to
those nodes! It also needs to be _able_ to register those nodes, i.e. have
import access to that location, but we don't check for that.

Args:
domain (str): The attribute name at which to register the new package.
(Note: no sanitizing is done here, so if you provide a string that
won't work as an attribute name, that's your problem.)
package_identifier (str): An identifier for the node package. (Right now
that's just a string version of the path to the module, e.g.
`pyiron_workflow.node_library.standard`.)

Raises:
KeyError: If the domain already exists, but the identifier doesn't match
with the stored identifier.
AttributeError: If you try to register at a domain that is already another
method or attribute of the creator.
ValueError: If the identifier can't be parsed.
"""

if self._package_conflicts_with_existing(domain, package_identifier):
raise KeyError(
f"{domain} is already a registered node package, please choose a "
f"different domain to store these nodes under"
)
elif domain in self.__dir__():
raise AttributeError(f"{domain} is already an attribute of {self}")

self._verify_identifier(package_identifier)

self._node_packages[domain] = package_identifier

def _package_conflicts_with_existing(
self, domain: str, package_identifier: str
) -> bool:
"""
Check if the new package conflicts with an existing package at the requested
domain; if there isn't one, or if the new and old packages are identical, then
there is no conflict!

Args:
domain (str): The domain at which the new package is attempting to register.
package_identifier (str): The identifier for the new package.

Returns:
(bool): True iff there is a package already at that domain and it is not
the same as the new one.
"""
if domain in self._node_packages.keys():
# If it's already here, it had better be the same package
return package_identifier != self._node_packages[domain]
# We can make "sameness" logic more complex as we allow more sophisticated
# identifiers
else:
# If it's not here already, it can't conflict!
return False

@staticmethod
def _verify_identifier(package_identifier: str):
"""
Logic for verifying whether new package identifiers will actually be usable for
creating node packages when their domain is called. Lets us fail early in
registration.

Right now, we just make sure it's a string from which we can import a list of
nodes.
"""
from pyiron_workflow.node import Node

try:
module = import_module(package_identifier)
nodes = module.nodes
if not all(issubclass(node, Node) for node in nodes):
raise TypeError(
f"At least one node in {nodes} was not of the type {Node.__name__}"
)
except Exception as e:
raise ValueError(
f"The package identifier is {package_identifier} is not valid. Please "
f"ensure it is an importable module with a list of {Node.__name__} "
f"objects stored in the variable `nodes`."
) from e


class Wrappers(metaclass=Singleton):
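Putting the registration pieces together, a rough usage sketch (the `my_project.demo_nodes` path is hypothetical; any importable module exposing a `nodes` list of `Node` subclasses should work):

from pyiron_workflow.interfaces import Creator

creator = Creator()  # a Singleton, so every call returns the same registry

# Fails early: ValueError if the identifier can't be imported or lacks a valid
# `nodes` list, KeyError if "demo" already maps to a *different* identifier,
# AttributeError if "demo" would shadow an existing attribute of the creator.
creator.register("demo", "my_project.demo_nodes")

# Registering the exact same identifier at the same domain again is allowed.
creator.register("demo", "my_project.demo_nodes")

# Access is lazy: the module is only imported, and its `nodes` wrapped in a
# NodePackage, when the domain attribute is first looked up.
demo = creator.demo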
25 changes: 25 additions & 0 deletions pyiron_workflow/node_library/plotting.py
@@ -0,0 +1,25 @@
"""
For graphical representations of data.
"""

from __future__ import annotations

from typing import Optional

import numpy as np

from pyiron_workflow.function import single_value_node


@single_value_node(output_labels="fig")
def scatter(
x: Optional[list | np.ndarray] = None, y: Optional[list | np.ndarray] = None
):
from matplotlib import pyplot as plt

return plt.scatter(x, y)


nodes = [
scatter,
]
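
A rough sketch of using this node from a workflow, assuming the plotting module gets registered to a `plotting` domain (that registration is not part of this diff):

import numpy as np

from pyiron_workflow import Workflow  # assuming the top-level export exists

Workflow.register("plotting", "pyiron_workflow.node_library.plotting")

wf = Workflow("viz")
wf.scatter_plot = wf.create.plotting.scatter(x=np.arange(10), y=np.arange(10) ** 2)
wf.scatter_plot.run()  # the single "fig" output then holds the matplotlib PathCollection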