From ff3226bfbe25819c9784e4dbf57233649647fccb Mon Sep 17 00:00:00 2001 From: liamhuber Date: Tue, 7 Jan 2025 11:56:27 -0800 Subject: [PATCH 01/43] Use typing.Callable instead of callable Signed-off-by: liamhuber --- pyiron_workflow/executors/cloudpickleprocesspool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyiron_workflow/executors/cloudpickleprocesspool.py b/pyiron_workflow/executors/cloudpickleprocesspool.py index 038c4c45..cd11b072 100644 --- a/pyiron_workflow/executors/cloudpickleprocesspool.py +++ b/pyiron_workflow/executors/cloudpickleprocesspool.py @@ -1,6 +1,7 @@ from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool, _global_shutdown, _WorkItem from sys import version_info +from typing import Callable import cloudpickle @@ -14,7 +15,7 @@ def result(self, timeout=None): class _CloudPickledCallable: - def __init__(self, fnc: callable): + def __init__(self, fnc: Callable): self.fnc_serial = cloudpickle.dumps(fnc) def __call__(self, /, dumped_args, dumped_kwargs): From 5b7e9c7f3eeff046dd620ac16b5da448198e4c90 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Tue, 7 Jan 2025 12:07:53 -0800 Subject: [PATCH 02/43] Ignore erroneous error typing._UnionGenericAlias definitively _does_ exist. Signed-off-by: liamhuber --- pyiron_workflow/type_hinting.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index ccf15be2..563bfe2f 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -28,7 +28,11 @@ def valid_value(value, type_hint) -> bool: def type_hint_to_tuple(type_hint) -> tuple: - if isinstance(type_hint, types.UnionType | typing._UnionGenericAlias): + if isinstance( + type_hint, types.UnionType | typing._UnionGenericAlias # type: ignore + # mypy complains because it thinks typing._UnionGenericAlias doesn't exist + # It definitely does, and we may be able to remove this once mypy catches up + ): return typing.get_args(type_hint) else: return (type_hint,) From aa3c143b1df131bcf085697338e0546f650ab828 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Tue, 7 Jan 2025 14:54:02 -0800 Subject: [PATCH 03/43] Hint a tuple, don't return one Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 04e9e9f2..7dbe7be3 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -464,7 +464,9 @@ def _valid_connection(self, other: DataChannel) -> bool: def _both_typed(self, other: DataChannel) -> bool: return self._has_hint and other._has_hint - def _figure_out_who_is_who(self, other: DataChannel) -> (OutputData, InputData): + def _figure_out_who_is_who( + self, other: DataChannel + ) -> tuple[OutputData, InputData]: return (self, other) if isinstance(self, OutputData) else (other, self) def __str__(self): From 534a2c68d7c9a58333c127f89eb7b4dafc0614e4 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Tue, 7 Jan 2025 15:03:27 -0800 Subject: [PATCH 04/43] Hint typing.Callable instead of callable Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 6 +++--- pyiron_workflow/mixin/preview.py | 7 ++++--- pyiron_workflow/nodes/composite.py | 8 ++++---- pyiron_workflow/nodes/function.py | 10 +++++----- pyiron_workflow/nodes/macro.py | 10 +++++----- pyiron_workflow/nodes/standard.py | 3 ++- pyiron_workflow/nodes/transform.py | 4 ++-- pyiron_workflow/topology.py | 4 ++-- 8 files changed, 27 insertions(+), 25 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 7dbe7be3..7c0d3e90 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -569,7 +569,7 @@ def __init__( self, label: str, owner: HasIO, - callback: callable, + callback: typing.Callable, ): """ Make a new input signal channel. @@ -616,7 +616,7 @@ def _has_required_args(func): ) @property - def callback(self) -> callable: + def callback(self) -> typing.Callable: return getattr(self.owner, self._callback) def __call__(self, other: OutputSignal | None = None) -> None: @@ -639,7 +639,7 @@ def __init__( self, label: str, owner: HasIO, - callback: callable, + callback: typing.Callable, ): super().__init__(label=label, owner=owner, callback=callback) self.received_signals: set[str] = set() diff --git a/pyiron_workflow/mixin/preview.py b/pyiron_workflow/mixin/preview.py index 21463d4c..bec7dbe5 100644 --- a/pyiron_workflow/mixin/preview.py +++ b/pyiron_workflow/mixin/preview.py @@ -18,6 +18,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, ClassVar, get_args, get_type_hints, @@ -81,7 +82,7 @@ def preview_io(cls) -> DotDict[str, dict]: ) -def builds_class_io(subclass_factory: callable[..., type[HasIOPreview]]): +def builds_class_io(subclass_factory: Callable[..., type[HasIOPreview]]): """ A decorator for factories producing subclasses of `HasIOPreview` to invoke :meth:`preview_io` after the class is created, thus ensuring the IO has been @@ -129,7 +130,7 @@ class ScrapesIO(HasIOPreview, ABC): @classmethod @abstractmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: """Must return a static method.""" _output_labels: ClassVar[tuple[str] | None] = None # None: scrape them @@ -287,7 +288,7 @@ def _validate_return_count(cls): ) from type_error @staticmethod - def _io_defining_documentation(io_defining_function: callable, title: str): + def _io_defining_documentation(io_defining_function: Callable, title: str): """ A helper method for building a docstring for classes that have their IO defined by some function. diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 7f745e9b..74f4abb8 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -7,7 +7,7 @@ from abc import ABC from time import sleep -from typing import TYPE_CHECKING, Literal +from typing import Callable, Literal, TYPE_CHECKING from pyiron_snippets.colors import SeabornColors from pyiron_snippets.dotdict import DotDict @@ -450,7 +450,7 @@ def graph_as_dict(self) -> dict: return _get_graph_as_dict(self) def _get_connections_as_strings( - self, panel_getter: callable + self, panel_getter: Callable ) -> list[tuple[tuple[str, str], tuple[str, str]]]: """ Connections between children in string representation based on labels. @@ -520,8 +520,8 @@ def __setstate__(self, state): def _restore_connections_from_strings( nodes: dict[str, Node] | DotDict[str, Node], connections: list[tuple[tuple[str, str], tuple[str, str]]], - input_panel_getter: callable, - output_panel_getter: callable, + input_panel_getter: Callable, + output_panel_getter: Callable, ) -> None: """ Set connections among a dictionary of nodes. diff --git a/pyiron_workflow/nodes/function.py b/pyiron_workflow/nodes/function.py index cd3d9f31..484509a2 100644 --- a/pyiron_workflow/nodes/function.py +++ b/pyiron_workflow/nodes/function.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from inspect import getsource -from typing import Any +from typing import Any, Callable from pyiron_snippets.colors import SeabornColors from pyiron_snippets.factory import classfactory @@ -300,11 +300,11 @@ class Function(StaticNode, ScrapesIO, ABC): @staticmethod @abstractmethod - def node_function(**kwargs) -> callable: + def node_function(**kwargs) -> Callable: """What the node _does_.""" @classmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: return cls.node_function @classmethod @@ -351,7 +351,7 @@ def _extra_info(cls) -> str: @classfactory def function_node_factory( - node_function: callable, + node_function: Callable, validate_output_labels: bool, use_cache: bool = True, /, @@ -429,7 +429,7 @@ def decorator(node_function): def function_node( - node_function: callable, + node_function: Callable, *node_args, output_labels: str | tuple[str, ...] | None = None, validate_output_labels: bool = True, diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index 566605a5..b85d88ba 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -8,7 +8,7 @@ import re from abc import ABC, abstractmethod from inspect import getsource -from typing import TYPE_CHECKING +from typing import Callable, TYPE_CHECKING from pyiron_snippets.factory import classfactory @@ -271,11 +271,11 @@ def _setup_node(self) -> None: @staticmethod @abstractmethod - def graph_creator(self, *args, **kwargs) -> callable: + def graph_creator(self, *args, **kwargs) -> Callable: """Build the graph the node will run.""" @classmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: return cls.graph_creator _io_defining_function_uses_self = True @@ -466,7 +466,7 @@ def _extra_info(cls) -> str: @classfactory def macro_node_factory( - graph_creator: callable, + graph_creator: Callable, validate_output_labels: bool, use_cache: bool = True, /, @@ -536,7 +536,7 @@ def decorator(graph_creator): def macro_node( - graph_creator: callable, + graph_creator: Callable, *node_args, output_labels: str | tuple[str, ...] | None = None, validate_output_labels: bool = True, diff --git a/pyiron_workflow/nodes/standard.py b/pyiron_workflow/nodes/standard.py index 8c119944..f753a402 100644 --- a/pyiron_workflow/nodes/standard.py +++ b/pyiron_workflow/nodes/standard.py @@ -9,6 +9,7 @@ import shutil from pathlib import Path from time import sleep +from typing import Callable from pyiron_workflow.channels import NOT_DATA, OutputSignal from pyiron_workflow.nodes.function import Function, as_function_node @@ -167,7 +168,7 @@ def ChangeDirectory( @as_function_node -def PureCall(fnc: callable): +def PureCall(fnc: Callable): """ Return a call without any arguments diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 97befbbb..a1710416 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from dataclasses import MISSING from dataclasses import dataclass as as_dataclass -from typing import Any, ClassVar +from typing import Any, Callable, ClassVar from pandas import DataFrame from pyiron_snippets.colors import SeabornColors @@ -65,7 +65,7 @@ class ToManyOutputs(Transformer, ABC): # Must be commensurate with the dictionary returned by transform_to_output @abstractmethod - def _on_run(self, input_object) -> callable[..., Any | tuple]: + def _on_run(self, input_object) -> Callable[..., Any | tuple]: """Must take the single object to be transformed""" @property diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py index c60c9131..08db9139 100644 --- a/pyiron_workflow/topology.py +++ b/pyiron_workflow/topology.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Callable, TYPE_CHECKING from toposort import CircularDependencyError, toposort, toposort_flatten @@ -90,7 +90,7 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: def _set_new_run_connections_with_fallback_recovery( - connection_creator: callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] + connection_creator: Callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] ): """ Given a function that takes a dictionary of unconnected nodes, connects their From 85f95d28865e7ddddb748819b66aca486bdb6349 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 10:42:44 -0800 Subject: [PATCH 05/43] Expose the Self typing tool for all versions Signed-off-by: liamhuber --- pyiron_workflow/compatibility.py | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 pyiron_workflow/compatibility.py diff --git a/pyiron_workflow/compatibility.py b/pyiron_workflow/compatibility.py new file mode 100644 index 00000000..28b7f773 --- /dev/null +++ b/pyiron_workflow/compatibility.py @@ -0,0 +1,6 @@ +from sys import version_info + +if version_info.minor < 11: + from typing_extensions import Self as Self +else: + from typing import Self as Self From 9895187e1ce3bd6f59cbd65ed46fbe602cd1ba84 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 12:24:06 -0800 Subject: [PATCH 06/43] Add a mypy job Based on @jan-janssen's jobs for other pyiron repos Signed-off-by: liamhuber --- .github/workflows/push-pull.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/push-pull.yml b/.github/workflows/push-pull.yml index 9178f1da..b1d47dae 100644 --- a/.github/workflows/push-pull.yml +++ b/.github/workflows/push-pull.yml @@ -19,3 +19,18 @@ jobs: alternate-tests-env-files: .ci_support/lower_bound.yml alternate-tests-python-version: '3.10' alternate-tests-dir: tests/unit + + mypy: + runs-on: ubuntu-latest + steps: + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + architecture: x64 + - name: Checkout + uses: actions/checkout@v4 + - name: Install mypy + run: pip install mypy + - name: Test + run: mypy --ignore-missing-imports ${{ github.event.repository.name }} From c5947607d53c006b2c3f18f9d4487e5ac81934e2 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Wed, 8 Jan 2025 12:50:51 -0800 Subject: [PATCH 07/43] `mypy` channels (#534) * Leverage generics for connection partners Signed-off-by: liamhuber * Break apart connection error message So we only reference type hints when they're there Signed-off-by: liamhuber * Hint connections type more specifically Signed-off-by: liamhuber * Hint disconnect more specifically Signed-off-by: liamhuber * Use Self in disconnection hints Signed-off-by: liamhuber * Use Self to hint value_receiver Signed-off-by: liamhuber * Devolve responsibility for connection validity Otherwise mypy has trouble telling that data channels really are operating on a connection partner, since the `super()` call could wind up pointing anywhere. Signed-off-by: liamhuber * Fix typing in channel tests Signed-off-by: liamhuber * :bug: Return the message Signed-off-by: liamhuber * Fix typing in figuring out who is I/O Signed-off-by: liamhuber * Recast connection parters as class method mypy complained about the class-level attribute access I was using to get around circular references. This is a bit more verbose, but otherwise a fine alternative. Signed-off-by: liamhuber * Match Accumulating input signal call to parent Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 167 +++++++++++++++++++++++------------- tests/unit/test_channels.py | 63 +++++++++----- 2 files changed, 148 insertions(+), 82 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 7c0d3e90..a95f3806 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -14,6 +14,7 @@ from pyiron_snippets.singleton import Singleton +from pyiron_workflow.compatibility import Self from pyiron_workflow.mixin.display_state import HasStateDisplay from pyiron_workflow.mixin.has_interface_mixins import HasChannel, HasLabel from pyiron_workflow.type_hinting import ( @@ -25,11 +26,24 @@ from pyiron_workflow.io import HasIO -class ChannelConnectionError(Exception): +class ChannelError(Exception): pass -class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): +class ChannelConnectionError(ChannelError): + pass + + +ConnectionPartner = typing.TypeVar("ConnectionPartner", bound="Channel") + + +class Channel( + HasChannel, + HasLabel, + HasStateDisplay, + typing.Generic[ConnectionPartner], + ABC +): """ Channels facilitate the flow of information (data or control signals) into and out of :class:`HasIO` objects (namely nodes). @@ -37,12 +51,9 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): They must have an identifier (`label: str`) and belong to an `owner: pyiron_workflow.io.HasIO`. - Non-abstract channel classes should come in input/output pairs and specify the - a necessary ancestor for instances they can connect to - (`connection_partner_type: type[Channel]`). Channels may form (:meth:`connect`/:meth:`disconnect`) and store - (:attr:`connections: list[Channel]`) connections with other channels. + (:attr:`connections`) connections with other channels. This connection information is reflexive, and is duplicated to be stored on _both_ channels in the form of a reference to their counterpart in the connection. @@ -51,10 +62,10 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): these (dis)connections is guaranteed to be handled, and new connections are subjected to a validity test. - In this abstract class the only requirement is that the connecting channels form a - "conjugate pair" of classes, i.e. they are children of each other's partner class - (:attr:`connection_partner_type: type[Channel]`) -- input/output connects to - output/input. + In this abstract class the only requirements are that the connecting channels form + a "conjugate pair" of classes, i.e. they are children of each other's partner class + and thus have the same "flavor", but are an input/output pair; and that they define + a string representation. Iterating over channels yields their connections. @@ -80,7 +91,7 @@ def __init__( """ self._label = label self.owner: HasIO = owner - self.connections: list[Channel] = [] + self.connections: list[ConnectionPartner] = [] @property def label(self) -> str: @@ -90,12 +101,12 @@ def label(self) -> str: def __str__(self): pass - @property + @classmethod @abstractmethod - def connection_partner_type(self) -> type[Channel]: + def connection_partner_type(cls) -> type[ConnectionPartner]: """ - Input and output class pairs must specify a parent class for their valid - connection partners. + The class forming a conjugate pair with this channel class -- i.e. the same + "flavor" of channel, but opposite in I/O. """ @property @@ -108,21 +119,18 @@ def full_label(self) -> str: """A label combining the channel's usual label and its owner's semantic path""" return f"{self.owner.full_label}.{self.label}" - def _valid_connection(self, other: Channel) -> bool: + @abstractmethod + def _valid_connection(self, other: object) -> bool: """ Logic for determining if a connection is valid. - - Connections only allowed to instances with the right parent type -- i.e. - connection pairs should be an input/output. """ - return isinstance(other, self.connection_partner_type) - def connect(self, *others: Channel) -> None: + def connect(self, *others: ConnectionPartner) -> None: """ Form a connection between this and one or more other channels. Connections are reflexive, and should only occur between input and output channels, i.e. they are instances of each others - :attr:`connection_partner_type`. + :meth:`connection_partner_type()`. New connections get _prepended_ to the connection lists, so they appear first when searching over connections. @@ -145,24 +153,28 @@ def connect(self, *others: Channel) -> None: self.connections.insert(0, other) other.connections.insert(0, self) else: - if isinstance(other, self.connection_partner_type): + if isinstance(other, self.connection_partner_type()): raise ChannelConnectionError( - f"The channel {other.full_label} ({other.__class__.__name__}" - f") has the correct type " - f"({self.connection_partner_type.__name__}) to connect with " - f"{self.full_label} ({self.__class__.__name__}), but is not " - f"a valid connection. Please check type hints, etc." - f"{other.full_label}.type_hint = {other.type_hint}; " - f"{self.full_label}.type_hint = {self.type_hint}" + self._connection_partner_failure_message(other) ) from None else: raise TypeError( - f"Can only connect to {self.connection_partner_type.__name__} " - f"objects, but {self.full_label} ({self.__class__.__name__}) " + f"Can only connect to {self.connection_partner_type()} " + f"objects, but {self.full_label} ({self.__class__}) " f"got {other} ({type(other)})" ) - def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: + def _connection_partner_failure_message(self, other: ConnectionPartner) -> str: + return ( + f"The channel {other.full_label} ({other.__class__}) has the " + f"correct type ({self.connection_partner_type()}) to connect with " + f"{self.full_label} ({self.__class__}), but is not a valid " + f"connection." + ) + + def disconnect( + self, *others: ConnectionPartner + ) -> list[tuple[Self, ConnectionPartner]]: """ If currently connected to any others, removes this and the other from eachothers respective connections lists. @@ -182,7 +194,9 @@ def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all(self) -> list[tuple[Channel, Channel]]: + def disconnect_all( + self + ) -> list[tuple[Self, ConnectionPartner]]: """ Disconnect from all other channels currently in the connections list. """ @@ -257,8 +271,9 @@ def __bool__(self): NOT_DATA = NotData() +DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel") -class DataChannel(Channel, ABC): +class DataChannel(Channel[DataConnectionPartner], ABC): """ Data channels control the flow of data on the graph. @@ -331,7 +346,7 @@ class DataChannel(Channel, ABC): when this channel is a value receiver. This can potentially be expensive, so consider deactivating strict hints everywhere for production runs. (Default is True, raise exceptions when type hints get violated.) - value_receiver (pyiron_workflow.channel.DataChannel|None): Another channel of + value_receiver (pyiron_workflow.compatibility.Self|None): Another channel of the same class whose value will always get updated when this channel's value gets updated. """ @@ -343,7 +358,7 @@ def __init__( default: typing.Any | None = NOT_DATA, type_hint: typing.Any | None = None, strict_hints: bool = True, - value_receiver: InputData | None = None, + value_receiver: Self | None = None, ): super().__init__(label=label, owner=owner) self._value = NOT_DATA @@ -352,7 +367,7 @@ def __init__( self.strict_hints = strict_hints self.default = default self.value = default # Implicitly type check your default by assignment - self.value_receiver = value_receiver + self.value_receiver: Self = value_receiver @property def value(self): @@ -379,7 +394,7 @@ def _type_check_new_value(self, new_value): ) @property - def value_receiver(self) -> InputData | OutputData | None: + def value_receiver(self) -> Self | None: """ Another data channel of the same type to whom new values are always pushed (without type checking of any sort, not even when forming the couple!) @@ -390,7 +405,7 @@ def value_receiver(self) -> InputData | OutputData | None: return self._value_receiver @value_receiver.setter - def value_receiver(self, new_partner: InputData | OutputData | None): + def value_receiver(self, new_partner: Self | None): if new_partner is not None: if not isinstance(new_partner, self.__class__): raise TypeError( @@ -445,8 +460,8 @@ def _value_is_data(self) -> bool: def _has_hint(self) -> bool: return self.type_hint is not None - def _valid_connection(self, other: DataChannel) -> bool: - if super()._valid_connection(other): + def _valid_connection(self, other: object) -> bool: + if isinstance(other, self.connection_partner_type()): if self._both_typed(other): out, inp = self._figure_out_who_is_who(other) if not inp.strict_hints: @@ -461,13 +476,32 @@ def _valid_connection(self, other: DataChannel) -> bool: else: return False - def _both_typed(self, other: DataChannel) -> bool: + def _connection_partner_failure_message(self, other: DataConnectionPartner) -> str: + msg = super()._connection_partner_failure_message(other) + msg += ( + f"Please check type hints, etc. {other.full_label}.type_hint = " + f"{other.type_hint}; {self.full_label}.type_hint = {self.type_hint}" + ) + return msg + + def _both_typed(self, other: DataConnectionPartner | Self) -> bool: return self._has_hint and other._has_hint def _figure_out_who_is_who( - self, other: DataChannel + self, other: DataConnectionPartner ) -> tuple[OutputData, InputData]: - return (self, other) if isinstance(self, OutputData) else (other, self) + if isinstance(self, InputData) and isinstance(other, OutputData): + return other, self + elif isinstance(self, OutputData) and isinstance(other, InputData): + return self, other + else: + raise ChannelError( + f"This should be unreachable; data channel conjugate pairs should " + f"always be input/output, but got {type(self)} for {self.full_label} " + f"and {type(other)} for {other.full_label}. If you don't believe you " + f"are responsible for this error, please contact the maintainers via " + f"GitHub." + ) def __str__(self): return str(self.value) @@ -491,9 +525,10 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class InputData(DataChannel): - @property - def connection_partner_type(self): +class InputData(DataChannel["OutputData"]): + + @classmethod + def connection_partner_type(cls) -> type[OutputData]: return OutputData def fetch(self) -> None: @@ -530,13 +565,17 @@ def value(self, new_value): self._value = new_value -class OutputData(DataChannel): - @property - def connection_partner_type(self): +class OutputData(DataChannel["InputData"]): + @classmethod + def connection_partner_type(cls) -> type[InputData]: return InputData -class SignalChannel(Channel, ABC): +SignalConnectionPartner = typing.TypeVar( + "SignalConnectionPartner", bound="SignalChannel" +) + +class SignalChannel(Channel[SignalConnectionPartner], ABC): """ Signal channels give the option control execution flow by triggering callback functions when the channel is called. @@ -555,15 +594,15 @@ class SignalChannel(Channel, ABC): def __call__(self) -> None: pass + def _valid_connection(self, other: object) -> bool: + return isinstance(other, self.connection_partner_type()) + class BadCallbackError(ValueError): pass -class InputSignal(SignalChannel): - @property - def connection_partner_type(self): - return OutputSignal +class InputSignal(SignalChannel["OutputSignal"]): def __init__( self, @@ -591,6 +630,10 @@ def __init__( f"all args are optional: {self._all_args_arg_optional(callback)} " ) + @classmethod + def connection_partner_type(cls) -> type[OutputSignal]: + return OutputSignal + def _is_method_on_owner(self, callback): try: return callback == getattr(self.owner, callback.__name__) @@ -644,14 +687,15 @@ def __init__( super().__init__(label=label, owner=owner, callback=callback) self.received_signals: set[str] = set() - def __call__(self, other: OutputSignal) -> None: + def __call__(self, other: OutputSignal | None = None) -> None: """ Fire callback iff you have received at least one signal from each of your current connections. Resets the collection of received signals when firing. """ - self.received_signals.update([other.scoped_label]) + if isinstance(other, OutputSignal): + self.received_signals.update([other.scoped_label]) if ( len( set(c.scoped_label for c in self.connections).difference( @@ -675,9 +719,10 @@ def __lshift__(self, others): other._connect_accumulating_input_signal(self) -class OutputSignal(SignalChannel): - @property - def connection_partner_type(self): +class OutputSignal(SignalChannel["InputSignal"]): + + @classmethod + def connection_partner_type(cls) -> type[InputSignal]: return InputSignal def __call__(self) -> None: diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index eaeb4a85..151a97fa 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from pyiron_workflow.channels import ( @@ -6,6 +8,7 @@ BadCallbackError, Channel, ChannelConnectionError, + ConnectionPartner, InputData, InputSignal, OutputData, @@ -30,25 +33,24 @@ def data_input_locked(self): return self.locked -class InputChannel(Channel): +class DummyChannel(Channel[ConnectionPartner]): """Just to de-abstract the base class""" - def __str__(self): return "non-abstract input" - @property - def connection_partner_type(self) -> type[Channel]: - return OutputChannel + def _valid_connection(self, other: object) -> bool: + return isinstance(other, self.connection_partner_type()) -class OutputChannel(Channel): - """Just to de-abstract the base class""" +class InputChannel(DummyChannel["OutputChannel"]): + @classmethod + def connection_partner_type(cls) -> type[OutputChannel]: + return OutputChannel - def __str__(self): - return "non-abstract output" - @property - def connection_partner_type(self) -> type[Channel]: +class OutputChannel(DummyChannel["InputChannel"]): + @classmethod + def connection_partner_type(cls) -> type[InputChannel]: return InputChannel @@ -389,26 +391,44 @@ def test_aggregating_call(self): owner = DummyOwner() agg = AccumulatingInputSignal(label="agg", owner=owner, callback=owner.update) - with self.assertRaises( - TypeError, - msg="For an aggregating input signal, it _matters_ who called it, so " - "receiving an output signal is not optional", - ): - agg() - out2 = OutputSignal(label="out2", owner=DummyOwner()) agg.connect(self.out, out2) + out_unrelated = OutputSignal(label="out_unrelated", owner=DummyOwner()) + + signals_sent = 0 self.assertEqual( 2, len(agg.connections), msg="Sanity check on initial conditions" ) self.assertEqual( - 0, len(agg.received_signals), msg="Sanity check on initial conditions" + signals_sent, + len(agg.received_signals), + msg="Sanity check on initial conditions" ) self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions") + agg() + signals_sent += 0 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection" + ) + agg(out_unrelated) + signals_sent += 1 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection" + ) + self.out() - self.assertEqual(1, len(agg.received_signals), msg="Signal should be received") + signals_sent += 1 + self.assertEqual( + signals_sent, + len(agg.received_signals), + msg="Signals from other channels should be received" + ) self.assertListEqual( [0], owner.foo, @@ -416,8 +436,9 @@ def test_aggregating_call(self): ) self.out() + signals_sent += 0 self.assertEqual( - 1, + signals_sent, len(agg.received_signals), msg="Repeatedly receiving the same signal should have no effect", ) From 6279797c391b4631f1bfa374008f68c8ee62aad9 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 12:58:21 -0800 Subject: [PATCH 08/43] Move Ruff jobs into the main push-pull script This is just a little QoL thing; the current script runs the jobs twice every time I push, and it's annoying me. Signed-off-by: liamhuber --- .github/workflows/push-pull.yml | 16 ++++++++++++++++ .github/workflows/ruff.yml | 17 ----------------- 2 files changed, 16 insertions(+), 17 deletions(-) delete mode 100644 .github/workflows/ruff.yml diff --git a/.github/workflows/push-pull.yml b/.github/workflows/push-pull.yml index b1d47dae..6418e835 100644 --- a/.github/workflows/push-pull.yml +++ b/.github/workflows/push-pull.yml @@ -34,3 +34,19 @@ jobs: run: pip install mypy - name: Test run: mypy --ignore-missing-imports ${{ github.event.repository.name }} + + ruff-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + with: + args: check + + ruff-sort-imports: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + with: + args: check --select I --fix --diff \ No newline at end of file diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml deleted file mode 100644 index 68c0ec0d..00000000 --- a/.github/workflows/ruff.yml +++ /dev/null @@ -1,17 +0,0 @@ -name: Ruff -on: [ push, pull_request ] -jobs: - ruff-check: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: check - ruff-sort-imports: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: check --select I --fix --diff From e13caf5e0683fb977a9d58036fa40de3d2e8ca0c Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 13:00:30 -0800 Subject: [PATCH 09/43] Ruff: import Callable from collections.abc Signed-off-by: liamhuber --- pyiron_workflow/executors/cloudpickleprocesspool.py | 2 +- pyiron_workflow/mixin/preview.py | 2 +- pyiron_workflow/nodes/composite.py | 3 ++- pyiron_workflow/nodes/function.py | 3 ++- pyiron_workflow/nodes/macro.py | 3 ++- pyiron_workflow/nodes/standard.py | 2 +- pyiron_workflow/nodes/transform.py | 3 ++- pyiron_workflow/topology.py | 3 ++- 8 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pyiron_workflow/executors/cloudpickleprocesspool.py b/pyiron_workflow/executors/cloudpickleprocesspool.py index cd11b072..983dc525 100644 --- a/pyiron_workflow/executors/cloudpickleprocesspool.py +++ b/pyiron_workflow/executors/cloudpickleprocesspool.py @@ -1,7 +1,7 @@ +from collections.abc import Callable from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool, _global_shutdown, _WorkItem from sys import version_info -from typing import Callable import cloudpickle diff --git a/pyiron_workflow/mixin/preview.py b/pyiron_workflow/mixin/preview.py index bec7dbe5..556af8cd 100644 --- a/pyiron_workflow/mixin/preview.py +++ b/pyiron_workflow/mixin/preview.py @@ -14,11 +14,11 @@ import inspect from abc import ABC, abstractmethod +from collections.abc import Callable from functools import lru_cache, wraps from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, get_args, get_type_hints, diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 74f4abb8..11d50583 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -6,8 +6,9 @@ from __future__ import annotations from abc import ABC +from collections.abc import Callable from time import sleep -from typing import Callable, Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from pyiron_snippets.colors import SeabornColors from pyiron_snippets.dotdict import DotDict diff --git a/pyiron_workflow/nodes/function.py b/pyiron_workflow/nodes/function.py index 484509a2..8000a6d9 100644 --- a/pyiron_workflow/nodes/function.py +++ b/pyiron_workflow/nodes/function.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import getsource -from typing import Any, Callable +from typing import Any from pyiron_snippets.colors import SeabornColors from pyiron_snippets.factory import classfactory diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index b85d88ba..527bd5de 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -7,8 +7,9 @@ import re from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import getsource -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING from pyiron_snippets.factory import classfactory diff --git a/pyiron_workflow/nodes/standard.py b/pyiron_workflow/nodes/standard.py index f753a402..e9b4c683 100644 --- a/pyiron_workflow/nodes/standard.py +++ b/pyiron_workflow/nodes/standard.py @@ -7,9 +7,9 @@ import os import random import shutil +from collections.abc import Callable from pathlib import Path from time import sleep -from typing import Callable from pyiron_workflow.channels import NOT_DATA, OutputSignal from pyiron_workflow.nodes.function import Function, as_function_node diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index a1710416..8852b426 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -6,9 +6,10 @@ import itertools from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import MISSING from dataclasses import dataclass as as_dataclass -from typing import Any, Callable, ClassVar +from typing import Any, ClassVar from pandas import DataFrame from pyiron_snippets.colors import SeabornColors diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py index 08db9139..bffd590f 100644 --- a/pyiron_workflow/topology.py +++ b/pyiron_workflow/topology.py @@ -6,7 +6,8 @@ from __future__ import annotations -from typing import Callable, TYPE_CHECKING +from collections.abc import Callable +from typing import TYPE_CHECKING from toposort import CircularDependencyError, toposort, toposort_flatten From 4d242b6fd8554b026917623eb0222a40a4d8fa8d Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 13:03:44 -0800 Subject: [PATCH 10/43] black Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 16 ++++++---------- pyiron_workflow/type_hinting.py | 3 ++- tests/unit/test_channels.py | 9 +++++---- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index a95f3806..a7883326 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -38,11 +38,7 @@ class ChannelConnectionError(ChannelError): class Channel( - HasChannel, - HasLabel, - HasStateDisplay, - typing.Generic[ConnectionPartner], - ABC + HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConnectionPartner], ABC ): """ Channels facilitate the flow of information (data or control signals) into and @@ -173,7 +169,7 @@ def _connection_partner_failure_message(self, other: ConnectionPartner) -> str: ) def disconnect( - self, *others: ConnectionPartner + self, *others: ConnectionPartner ) -> list[tuple[Self, ConnectionPartner]]: """ If currently connected to any others, removes this and the other from eachothers @@ -194,9 +190,7 @@ def disconnect( destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all( - self - ) -> list[tuple[Self, ConnectionPartner]]: + def disconnect_all(self) -> list[tuple[Self, ConnectionPartner]]: """ Disconnect from all other channels currently in the connections list. """ @@ -273,6 +267,7 @@ def __bool__(self): DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel") + class DataChannel(Channel[DataConnectionPartner], ABC): """ Data channels control the flow of data on the graph. @@ -488,7 +483,7 @@ def _both_typed(self, other: DataConnectionPartner | Self) -> bool: return self._has_hint and other._has_hint def _figure_out_who_is_who( - self, other: DataConnectionPartner + self, other: DataConnectionPartner ) -> tuple[OutputData, InputData]: if isinstance(self, InputData) and isinstance(other, OutputData): return other, self @@ -575,6 +570,7 @@ def connection_partner_type(cls) -> type[InputData]: "SignalConnectionPartner", bound="SignalChannel" ) + class SignalChannel(Channel[SignalConnectionPartner], ABC): """ Signal channels give the option control execution flow by triggering callback diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index 563bfe2f..28af408b 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -29,7 +29,8 @@ def valid_value(value, type_hint) -> bool: def type_hint_to_tuple(type_hint) -> tuple: if isinstance( - type_hint, types.UnionType | typing._UnionGenericAlias # type: ignore + type_hint, + types.UnionType | typing._UnionGenericAlias, # type: ignore # mypy complains because it thinks typing._UnionGenericAlias doesn't exist # It definitely does, and we may be able to remove this once mypy catches up ): diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index 151a97fa..dce71e3c 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -35,6 +35,7 @@ def data_input_locked(self): class DummyChannel(Channel[ConnectionPartner]): """Just to de-abstract the base class""" + def __str__(self): return "non-abstract input" @@ -403,7 +404,7 @@ def test_aggregating_call(self): self.assertEqual( signals_sent, len(agg.received_signals), - msg="Sanity check on initial conditions" + msg="Sanity check on initial conditions", ) self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions") @@ -412,14 +413,14 @@ def test_aggregating_call(self): self.assertListEqual( [0], owner.foo, - msg="Aggregating calls should only matter when they come from a connection" + msg="Aggregating calls should only matter when they come from a connection", ) agg(out_unrelated) signals_sent += 1 self.assertListEqual( [0], owner.foo, - msg="Aggregating calls should only matter when they come from a connection" + msg="Aggregating calls should only matter when they come from a connection", ) self.out() @@ -427,7 +428,7 @@ def test_aggregating_call(self): self.assertEqual( signals_sent, len(agg.received_signals), - msg="Signals from other channels should be received" + msg="Signals from other channels should be received", ) self.assertListEqual( [0], From 71c46da0e0f65f601e847c8a00076c5e284f722d Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Thu, 9 Jan 2025 09:42:14 -0800 Subject: [PATCH 11/43] Drop the private type hint (#535) It was necessary for python<3.10, but we dropped support for that, so we can get rid of the ugly, non-public hint. Signed-off-by: liamhuber --- pyiron_workflow/type_hinting.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index 28af408b..66ae0ce2 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -28,15 +28,9 @@ def valid_value(value, type_hint) -> bool: def type_hint_to_tuple(type_hint) -> tuple: - if isinstance( - type_hint, - types.UnionType | typing._UnionGenericAlias, # type: ignore - # mypy complains because it thinks typing._UnionGenericAlias doesn't exist - # It definitely does, and we may be able to remove this once mypy catches up - ): + if isinstance(type_hint, types.UnionType): return typing.get_args(type_hint) - else: - return (type_hint,) + return (type_hint,) def type_hint_is_as_or_more_specific_than(hint, other) -> bool: From 214c6e2bdd43ffaf3bfb3245f0d8806afdfba1f3 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 10 Jan 2025 09:55:49 -0800 Subject: [PATCH 12/43] `mypy` channels redux (#536) * Refactor: rename Move from "partner" language to "conjugate" language Signed-off-by: liamhuber * Explicitly decompose conjugate behaviour Into flavor and IO components Signed-off-by: liamhuber * Tidying Signed-off-by: liamhuber * Narrow hint on connection copying Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 93 ++++++++++++++++++++----------------- tests/unit/test_channels.py | 10 ++-- 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index a7883326..af602a31 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -34,11 +34,14 @@ class ChannelConnectionError(ChannelError): pass -ConnectionPartner = typing.TypeVar("ConnectionPartner", bound="Channel") +ConjugateType = typing.TypeVar("ConjugateType", bound="Channel") +InputType = typing.TypeVar("InputType", bound="InputChannel") +OutputType = typing.TypeVar("OutputType", bound="OutputChannel") +FlavorType = typing.TypeVar("FlavorType", bound="FlavorChannel") class Channel( - HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConnectionPartner], ABC + HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConjugateType], ABC ): """ Channels facilitate the flow of information (data or control signals) into and @@ -58,10 +61,10 @@ class Channel( these (dis)connections is guaranteed to be handled, and new connections are subjected to a validity test. - In this abstract class the only requirements are that the connecting channels form - a "conjugate pair" of classes, i.e. they are children of each other's partner class - and thus have the same "flavor", but are an input/output pair; and that they define - a string representation. + Child classes must specify a conjugate class in order to enforce connection + conjugate pairs which have the same "flavor" (e.g. "data" or "signal"), and + opposite "direction" ("input" vs "output"). And they must define a string + representation. Iterating over channels yields their connections. @@ -87,7 +90,7 @@ def __init__( """ self._label = label self.owner: HasIO = owner - self.connections: list[ConnectionPartner] = [] + self.connections: list[ConjugateType] = [] @property def label(self) -> str: @@ -99,7 +102,7 @@ def __str__(self): @classmethod @abstractmethod - def connection_partner_type(cls) -> type[ConnectionPartner]: + def connection_conjugate(cls) -> type[ConjugateType]: """ The class forming a conjugate pair with this channel class -- i.e. the same "flavor" of channel, but opposite in I/O. @@ -121,12 +124,12 @@ def _valid_connection(self, other: object) -> bool: Logic for determining if a connection is valid. """ - def connect(self, *others: ConnectionPartner) -> None: + def connect(self, *others: ConjugateType) -> None: """ Form a connection between this and one or more other channels. Connections are reflexive, and should only occur between input and output channels, i.e. they are instances of each others - :meth:`connection_partner_type()`. + :meth:`connection_conjugate()`. New connections get _prepended_ to the connection lists, so they appear first when searching over connections. @@ -149,28 +152,26 @@ def connect(self, *others: ConnectionPartner) -> None: self.connections.insert(0, other) other.connections.insert(0, self) else: - if isinstance(other, self.connection_partner_type()): + if isinstance(other, self.connection_conjugate()): raise ChannelConnectionError( - self._connection_partner_failure_message(other) + self._connection_conjugate_failure_message(other) ) from None else: raise TypeError( - f"Can only connect to {self.connection_partner_type()} " + f"Can only connect to {self.connection_conjugate()} " f"objects, but {self.full_label} ({self.__class__}) " f"got {other} ({type(other)})" ) - def _connection_partner_failure_message(self, other: ConnectionPartner) -> str: + def _connection_conjugate_failure_message(self, other: ConjugateType) -> str: return ( f"The channel {other.full_label} ({other.__class__}) has the " - f"correct type ({self.connection_partner_type()}) to connect with " + f"correct type ({self.connection_conjugate()}) to connect with " f"{self.full_label} ({self.__class__}), but is not a valid " f"connection." ) - def disconnect( - self, *others: ConnectionPartner - ) -> list[tuple[Self, ConnectionPartner]]: + def disconnect(self, *others: ConjugateType) -> list[tuple[Self, ConjugateType]]: """ If currently connected to any others, removes this and the other from eachothers respective connections lists. @@ -190,7 +191,7 @@ def disconnect( destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all(self) -> list[tuple[Self, ConnectionPartner]]: + def disconnect_all(self) -> list[tuple[Self, ConjugateType]]: """ Disconnect from all other channels currently in the connections list. """ @@ -207,10 +208,10 @@ def __iter__(self): return self.connections.__iter__() @property - def channel(self) -> Channel: + def channel(self) -> Self: return self - def copy_connections(self, other: Channel) -> None: + def copy_connections(self, other: Self) -> None: """ Adds all the connections in another channel to this channel's connections. @@ -243,6 +244,18 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) +class FlavorChannel(Channel[FlavorType], ABC): + """Abstract base for all flavor-specific channels.""" + + +class InputChannel(Channel[OutputType], ABC): + """Mixin for input channels.""" + + +class OutputChannel(Channel[InputType], ABC): + """Mixin for output channels.""" + + class NotData(metaclass=Singleton): """ This class exists purely to initialize data channel values where no default value @@ -265,10 +278,8 @@ def __bool__(self): NOT_DATA = NotData() -DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel") - -class DataChannel(Channel[DataConnectionPartner], ABC): +class DataChannel(FlavorChannel["DataChannel"], ABC): """ Data channels control the flow of data on the graph. @@ -456,7 +467,7 @@ def _has_hint(self) -> bool: return self.type_hint is not None def _valid_connection(self, other: object) -> bool: - if isinstance(other, self.connection_partner_type()): + if isinstance(other, self.connection_conjugate()): if self._both_typed(other): out, inp = self._figure_out_who_is_who(other) if not inp.strict_hints: @@ -471,19 +482,19 @@ def _valid_connection(self, other: object) -> bool: else: return False - def _connection_partner_failure_message(self, other: DataConnectionPartner) -> str: - msg = super()._connection_partner_failure_message(other) + def _connection_conjugate_failure_message(self, other: DataChannel) -> str: + msg = super()._connection_conjugate_failure_message(other) msg += ( f"Please check type hints, etc. {other.full_label}.type_hint = " f"{other.type_hint}; {self.full_label}.type_hint = {self.type_hint}" ) return msg - def _both_typed(self, other: DataConnectionPartner | Self) -> bool: + def _both_typed(self, other: DataChannel) -> bool: return self._has_hint and other._has_hint def _figure_out_who_is_who( - self, other: DataConnectionPartner + self, other: DataChannel ) -> tuple[OutputData, InputData]: if isinstance(self, InputData) and isinstance(other, OutputData): return other, self @@ -520,10 +531,10 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class InputData(DataChannel["OutputData"]): +class InputData(DataChannel, InputChannel["OutputData"]): @classmethod - def connection_partner_type(cls) -> type[OutputData]: + def connection_conjugate(cls) -> type[OutputData]: return OutputData def fetch(self) -> None: @@ -560,18 +571,16 @@ def value(self, new_value): self._value = new_value -class OutputData(DataChannel["InputData"]): +class OutputData(DataChannel, OutputChannel["InputData"]): @classmethod - def connection_partner_type(cls) -> type[InputData]: + def connection_conjugate(cls) -> type[InputData]: return InputData -SignalConnectionPartner = typing.TypeVar( - "SignalConnectionPartner", bound="SignalChannel" -) +SignalType = typing.TypeVar("SignalType", bound="SignalChannel") -class SignalChannel(Channel[SignalConnectionPartner], ABC): +class SignalChannel(FlavorChannel[SignalType], ABC): """ Signal channels give the option control execution flow by triggering callback functions when the channel is called. @@ -591,14 +600,14 @@ def __call__(self) -> None: pass def _valid_connection(self, other: object) -> bool: - return isinstance(other, self.connection_partner_type()) + return isinstance(other, self.connection_conjugate()) class BadCallbackError(ValueError): pass -class InputSignal(SignalChannel["OutputSignal"]): +class InputSignal(SignalChannel["OutputSignal"], InputChannel["OutputSignal"]): def __init__( self, @@ -627,7 +636,7 @@ def __init__( ) @classmethod - def connection_partner_type(cls) -> type[OutputSignal]: + def connection_conjugate(cls) -> type[OutputSignal]: return OutputSignal def _is_method_on_owner(self, callback): @@ -715,10 +724,10 @@ def __lshift__(self, others): other._connect_accumulating_input_signal(self) -class OutputSignal(SignalChannel["InputSignal"]): +class OutputSignal(SignalChannel["InputSignal"], OutputChannel["InputSignal"]): @classmethod - def connection_partner_type(cls) -> type[InputSignal]: + def connection_conjugate(cls) -> type[InputSignal]: return InputSignal def __call__(self) -> None: diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index dce71e3c..bb6c4690 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -8,7 +8,7 @@ BadCallbackError, Channel, ChannelConnectionError, - ConnectionPartner, + ConjugateType, InputData, InputSignal, OutputData, @@ -33,25 +33,25 @@ def data_input_locked(self): return self.locked -class DummyChannel(Channel[ConnectionPartner]): +class DummyChannel(Channel[ConjugateType]): """Just to de-abstract the base class""" def __str__(self): return "non-abstract input" def _valid_connection(self, other: object) -> bool: - return isinstance(other, self.connection_partner_type()) + return isinstance(other, self.connection_conjugate()) class InputChannel(DummyChannel["OutputChannel"]): @classmethod - def connection_partner_type(cls) -> type[OutputChannel]: + def connection_conjugate(cls) -> type[OutputChannel]: return OutputChannel class OutputChannel(DummyChannel["InputChannel"]): @classmethod - def connection_partner_type(cls) -> type[InputChannel]: + def connection_conjugate(cls) -> type[InputChannel]: return InputChannel From fc41dfa296f5f0046c302a810be0afd4baf018a0 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 10 Jan 2025 09:56:12 -0800 Subject: [PATCH 13/43] Apply hints to IO panels (#537) * Refactor: rename Move from "partner" language to "conjugate" language Signed-off-by: liamhuber * Explicitly decompose conjugate behaviour Into flavor and IO components Signed-off-by: liamhuber * Tidying Signed-off-by: liamhuber * Narrow hint on connection copying Signed-off-by: liamhuber * Apply hints to IO panels Signed-off-by: liamhuber * Narrow type Signed-off-by: liamhuber * Don't reuse variable Signed-off-by: liamhuber * Ruff: sort imports Signed-off-by: liamhuber * :bug: fix type hint Signed-off-by: liamhuber * Add more hints Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/io.py | 73 ++++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py index 0293bb6c..d1cb442d 100644 --- a/pyiron_workflow/io.py +++ b/pyiron_workflow/io.py @@ -9,7 +9,7 @@ import contextlib from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Generic, TypeVar from pyiron_snippets.dotdict import DotDict @@ -20,8 +20,10 @@ DataChannel, InputData, InputSignal, + InputType, OutputData, OutputSignal, + OutputType, SignalChannel, ) from pyiron_workflow.logging import logger @@ -32,8 +34,11 @@ HasRun, ) +OwnedType = TypeVar("OwnedType", bound=Channel) +OwnedConjugate = TypeVar("OwnedConjugate", bound=Channel) -class IO(HasStateDisplay, ABC): + +class IO(HasStateDisplay, Generic[OwnedType, OwnedConjugate], ABC): """ IO is a convenience layer for holding and accessing multiple input/output channels. It allows key and dot-based access to the underlying channels. @@ -52,7 +57,9 @@ class IO(HasStateDisplay, ABC): be assigned with a simple `=`. """ - def __init__(self, *channels: Channel): + channel_dict: DotDict[str, OwnedType] + + def __init__(self, *channels: OwnedType): self.__dict__["channel_dict"] = DotDict( { channel.label: channel @@ -63,15 +70,15 @@ def __init__(self, *channels: Channel): @property @abstractmethod - def _channel_class(self) -> type(Channel): + def _channel_class(self) -> type[OwnedType]: pass @abstractmethod - def _assign_a_non_channel_value(self, channel: Channel, value) -> None: + def _assign_a_non_channel_value(self, channel: OwnedType, value) -> None: """What to do when some non-channel value gets assigned to a channel""" pass - def __getattr__(self, item) -> Channel: + def __getattr__(self, item) -> OwnedType: try: return self.channel_dict[item] except KeyError as key_error: @@ -97,20 +104,20 @@ def __setattr__(self, key, value): f"attribute {key} got assigned {value} of type {type(value)}" ) - def _assign_value_to_existing_channel(self, channel: Channel, value) -> None: + def _assign_value_to_existing_channel(self, channel: OwnedType, value) -> None: if isinstance(value, HasChannel): channel.connect(value.channel) else: self._assign_a_non_channel_value(channel, value) - def __getitem__(self, item) -> Channel: + def __getitem__(self, item) -> OwnedType: return self.__getattr__(item) def __setitem__(self, key, value): self.__setattr__(key, value) @property - def connections(self) -> list[Channel]: + def connections(self) -> list[OwnedConjugate]: """All the unique connections across all channels""" return list( set([connection for channel in self for connection in channel.connections]) @@ -124,7 +131,7 @@ def connected(self): def fully_connected(self): return all([c.connected for c in self]) - def disconnect(self) -> list[tuple[Channel, Channel]]: + def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]: """ Disconnect all connections that owned channels have. @@ -173,7 +180,15 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class DataIO(IO, ABC): +class InputsIO(IO[InputType, OutputType], ABC): + pass + + +class OutputsIO(IO[OutputType, InputType], ABC): + pass + + +class DataIO(IO[DataChannel, DataChannel], ABC): def _assign_a_non_channel_value(self, channel: DataChannel, value) -> None: channel.value = value @@ -195,9 +210,9 @@ def deactivate_strict_hints(self): [c.deactivate_strict_hints() for c in self] -class Inputs(DataIO): +class Inputs(InputsIO, DataIO): @property - def _channel_class(self) -> type(InputData): + def _channel_class(self) -> type[InputData]: return InputData def fetch(self): @@ -205,13 +220,13 @@ def fetch(self): c.fetch() -class Outputs(DataIO): +class Outputs(OutputsIO, DataIO): @property - def _channel_class(self) -> type(OutputData): + def _channel_class(self) -> type[OutputData]: return OutputData -class SignalIO(IO, ABC): +class SignalIO(IO[SignalChannel, SignalChannel], ABC): def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None: raise TypeError( f"Tried to assign {value} ({type(value)} to the {channel.full_label}, " @@ -220,12 +235,12 @@ def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None: ) -class InputSignals(SignalIO): +class InputSignals(InputsIO, SignalIO): @property - def _channel_class(self) -> type(InputSignal): + def _channel_class(self) -> type[InputSignal]: return InputSignal - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: """Disconnect all `run` and `accumulate_and_run` signals, if they exist.""" disconnected = [] with contextlib.suppress(AttributeError): @@ -235,9 +250,9 @@ def disconnect_run(self) -> list[tuple[Channel, Channel]]: return disconnected -class OutputSignals(SignalIO): +class OutputSignals(OutputsIO, SignalIO): @property - def _channel_class(self) -> type(OutputSignal): + def _channel_class(self) -> type[OutputSignal]: return OutputSignal @@ -254,7 +269,7 @@ def __init__(self): self.input = InputSignals() self.output = OutputSignals() - def disconnect(self) -> list[tuple[Channel, Channel]]: + def disconnect(self) -> list[tuple[SignalChannel, SignalChannel]]: """ Disconnect all connections in input and output signals. @@ -264,7 +279,7 @@ def disconnect(self) -> list[tuple[Channel, Channel]]: """ return self.input.disconnect() + self.output.disconnect() - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: return self.input.disconnect_run() @property @@ -326,14 +341,14 @@ def connected(self) -> bool: return self.inputs.connected or self.outputs.connected or self.signals.connected @property - def fully_connected(self): + def fully_connected(self) -> bool: return ( self.inputs.fully_connected and self.outputs.fully_connected and self.signals.fully_connected ) - def disconnect(self): + def disconnect(self) -> list[tuple[Channel, Channel]]: """ Disconnect all connections belonging to inputs, outputs, and signals channels. @@ -360,7 +375,7 @@ def deactivate_strict_hints(self): def _connect_output_signal(self, signal: OutputSignal): self.signals.input.run.connect(signal) - def __rshift__(self, other: InputSignal | HasIO): + def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO: """ Allows users to connect run and ran signals like: `first >> second`. """ @@ -458,8 +473,8 @@ def copy_io( try: self._copy_values(other, fail_hard=values_fail_hard) except Exception as e: - for this, other in new_connections: - this.disconnect(other) + for owned, conjugate in new_connections: + owned.disconnect(conjugate) raise e def _copy_connections( @@ -522,7 +537,7 @@ def _copy_values( self, other: HasIO, fail_hard: bool = False, - ) -> list[tuple[Channel, Any]]: + ) -> list[tuple[DataChannel, Any]]: """ Copies all data from input and output channels in the other object onto this one. From c77bcbd6483504b8278dde035edc8547d014854e Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 10 Jan 2025 10:04:00 -0800 Subject: [PATCH 14/43] Refactor connection validity The instance check to see if a connection candidate has the correct (conjugate) type now occurs only _once_ in the parent `Channel` class. `Channel._valid_connection` is the repurposed to check for validity inside the scope of the classes already lining up, and defaults to simply returning `True` in the base class. `DataChannel` overrides it to do the type hint comparison. Changes inspired by [conversation](https://github.com/pyiron/pyiron_workflow/pull/533#discussion_r1908526844) with @XzzX. Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 66 +++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index af602a31..f96bb48c 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -118,12 +118,6 @@ def full_label(self) -> str: """A label combining the channel's usual label and its owner's semantic path""" return f"{self.owner.full_label}.{self.label}" - @abstractmethod - def _valid_connection(self, other: object) -> bool: - """ - Logic for determining if a connection is valid. - """ - def connect(self, *others: ConjugateType) -> None: """ Form a connection between this and one or more other channels. @@ -146,22 +140,30 @@ def connect(self, *others: ConjugateType) -> None: for other in others: if other in self.connections: continue - elif self._valid_connection(other): - # Prepend new connections - # so that connection searches run newest to oldest - self.connections.insert(0, other) - other.connections.insert(0, self) - else: - if isinstance(other, self.connection_conjugate()): + elif isinstance(other, self.connection_conjugate()): + if self._valid_connection(other): + # Prepend new connections + # so that connection searches run newest to oldest + self.connections.insert(0, other) + other.connections.insert(0, self) + else: raise ChannelConnectionError( self._connection_conjugate_failure_message(other) ) from None - else: - raise TypeError( - f"Can only connect to {self.connection_conjugate()} " - f"objects, but {self.full_label} ({self.__class__}) " - f"got {other} ({type(other)})" - ) + else: + raise TypeError( + f"Can only connect to {self.connection_conjugate()} " + f"objects, but {self.full_label} ({self.__class__}) " + f"got {other} ({type(other)})" + ) + + def _valid_connection(self, other: ConjugateType) -> bool: + """ + Logic for determining if a connection to a conjugate partner is valid. + + Override in child classes as necessary. + """ + return True def _connection_conjugate_failure_message(self, other: ConjugateType) -> str: return ( @@ -466,21 +468,18 @@ def _value_is_data(self) -> bool: def _has_hint(self) -> bool: return self.type_hint is not None - def _valid_connection(self, other: object) -> bool: - if isinstance(other, self.connection_conjugate()): - if self._both_typed(other): - out, inp = self._figure_out_who_is_who(other) - if not inp.strict_hints: - return True - else: - return type_hint_is_as_or_more_specific_than( - out.type_hint, inp.type_hint - ) - else: - # If either is untyped, don't do type checking + def _valid_connection(self, other: DataChannel) -> bool: + if self._both_typed(other): + out, inp = self._figure_out_who_is_who(other) + if not inp.strict_hints: return True + else: + return type_hint_is_as_or_more_specific_than( + out.type_hint, inp.type_hint + ) else: - return False + # If either is untyped, don't do type checking + return True def _connection_conjugate_failure_message(self, other: DataChannel) -> str: msg = super()._connection_conjugate_failure_message(other) @@ -599,9 +598,6 @@ class SignalChannel(FlavorChannel[SignalType], ABC): def __call__(self) -> None: pass - def _valid_connection(self, other: object) -> bool: - return isinstance(other, self.connection_conjugate()) - class BadCallbackError(ValueError): pass From ff6a984cfec3301729c898f85dc6dbd1679335a3 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 10 Jan 2025 16:06:53 -0800 Subject: [PATCH 15/43] `mypy` run (#541) * Hint init properties Signed-off-by: liamhuber * Hint local function Signed-off-by: liamhuber * Add stricter return and hint Signed-off-by: liamhuber * :bug: Hint tuple[] not () Signed-off-by: liamhuber * Black Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/mixin/run.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/pyiron_workflow/mixin/run.py b/pyiron_workflow/mixin/run.py index b704abc7..b7b11c7b 100644 --- a/pyiron_workflow/mixin/run.py +++ b/pyiron_workflow/mixin/run.py @@ -7,6 +7,7 @@ import contextlib from abc import ABC, abstractmethod +from collections.abc import Callable from concurrent.futures import Executor as StdLibExecutor from concurrent.futures import Future, ThreadPoolExecutor from functools import partial @@ -51,14 +52,14 @@ class Runnable(UsesState, HasLabel, HasRun, ABC): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.running = False - self.failed = False - self.executor = None - # We call it an executor, but it's just whether to use one. - # This is a simply stop-gap as we work out more sophisticated ways to reference - # (or create) an executor process without ever trying to pickle a `_thread.lock` + self.running: bool = False + self.failed: bool = False + self.executor: ( + StdLibExecutor | tuple[Callable[..., StdLibExecutor], tuple, dict] | None + ) = None + # We call it an executor, but it can also be instructions on making one self.future: None | Future = None - self._thread_pool_sleep_time = 1e-6 + self._thread_pool_sleep_time: float = 1e-6 @abstractmethod def on_run(self, *args, **kwargs) -> Any: # callable[..., Any | tuple]: @@ -135,7 +136,7 @@ def run( :attr:`running`. (Default is True.) """ - def _none_to_dict(inp): + def _none_to_dict(inp: dict | None) -> dict: return {} if inp is None else inp before_run_kwargs = _none_to_dict(before_run_kwargs) @@ -275,7 +276,7 @@ def _finish_run( run_exception_kwargs: dict, run_finally_kwargs: dict, **kwargs, - ) -> Any | tuple: + ) -> Any | tuple | None: """ Switch the status, then process and return the run result. """ @@ -288,6 +289,7 @@ def _finish_run( self._run_exception(**run_exception_kwargs) if raise_run_exceptions: raise e + return None finally: self._run_finally(**run_finally_kwargs) @@ -308,7 +310,7 @@ def _readiness_error_message(self) -> str: @staticmethod def _parse_executor( - executor: StdLibExecutor | (callable[..., StdLibExecutor], tuple, dict), + executor: StdLibExecutor | tuple[Callable[..., StdLibExecutor], tuple, dict], ) -> StdLibExecutor: """ If you've already got an executor, you're done. But if you get callable and From 3577158224af0707f2a6433e780482be027e6a14 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 10 Jan 2025 16:12:16 -0800 Subject: [PATCH 16/43] `mypy` topology and find (#542) * Don't overload typed variable Signed-off-by: liamhuber * Add (and more specific) return hint(s) To the one function missing one Signed-off-by: liamhuber * Add module docstring Signed-off-by: liamhuber * Catch module spec failures Signed-off-by: liamhuber * Force mypy to accept the design feature That we _want_ callers to be able to get abstract classes if they request them Signed-off-by: liamhuber * Black Signed-off-by: liamhuber * Ruff import sort Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/find.py | 17 ++++++++++++++--- pyiron_workflow/topology.py | 14 +++++++------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/pyiron_workflow/find.py b/pyiron_workflow/find.py index eea058a3..7111ef47 100644 --- a/pyiron_workflow/find.py +++ b/pyiron_workflow/find.py @@ -1,3 +1,9 @@ +""" +A utility for finding public `pyiron_workflow.node.Node` objects. + +Supports the idea of node developers writing independent node packages. +""" + from __future__ import annotations import importlib.util @@ -5,23 +11,28 @@ import sys from pathlib import Path from types import ModuleType +from typing import TypeVar, cast from pyiron_workflow.node import Node +NodeType = TypeVar("NodeType", bound=Node) + def _get_subclasses( source: str | Path | ModuleType, - base_class: type, + base_class: type[NodeType], get_private: bool = False, get_abstract: bool = False, get_imports_too: bool = False, -): +) -> list[type[NodeType]]: if isinstance(source, str | Path): source = Path(source) if source.is_file(): # Load the module from the file module_name = source.stem spec = importlib.util.spec_from_file_location(module_name, str(source)) + if spec is None or spec.loader is None: + raise ImportError(f"Could not create a ModuleSpec for {source}") module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module spec.loader.exec_module(module) @@ -54,4 +65,4 @@ def find_nodes(source: str | Path | ModuleType) -> list[type[Node]]: """ Get a list of all public, non-abstract nodes defined in the source. """ - return _get_subclasses(source, Node) + return cast(list[type[Node]], _get_subclasses(source, Node)) diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py index bffd590f..a621cc20 100644 --- a/pyiron_workflow/topology.py +++ b/pyiron_workflow/topology.py @@ -12,7 +12,7 @@ from toposort import CircularDependencyError, toposort, toposort_flatten if TYPE_CHECKING: - from pyiron_workflow.channels import SignalChannel + from pyiron_workflow.channels import InputSignal, OutputSignal from pyiron_workflow.node import Node @@ -75,8 +75,8 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: ) locally_scoped_dependencies.append(upstream.owner.label) node_dependencies.extend(locally_scoped_dependencies) - node_dependencies = set(node_dependencies) - if node.label in node_dependencies: + node_dependencies_set = set(node_dependencies) + if node.label in node_dependencies_set: # the toposort library has a # [known issue](https://gitlab.com/ericvsmith/toposort/-/issues/3) # That self-dependency isn't caught, so we catch it manually here. @@ -85,14 +85,14 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: f"the execution of non-DAGs: {node.full_label} appears in its own " f"input." ) - digraph[node.label] = node_dependencies + digraph[node.label] = node_dependencies_set return digraph def _set_new_run_connections_with_fallback_recovery( connection_creator: Callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] -): +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a function that takes a dictionary of unconnected nodes, connects their execution graph, and returns the new starting nodes, this wrapper makes sure that @@ -144,7 +144,7 @@ def _set_run_connections_according_to_linear_dag(nodes: dict[str, Node]) -> list def set_run_connections_according_to_linear_dag( nodes: dict[str, Node], -) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a set of nodes that all have the same parent, have no upstream data connections outside the nodes provided, and have acyclic data flow, disconnects all @@ -196,7 +196,7 @@ def _set_run_connections_according_to_dag(nodes: dict[str, Node]) -> list[Node]: def set_run_connections_according_to_dag( nodes: dict[str, Node], -) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a set of nodes that all have the same parent, have no upstream data connections outside the nodes provided, and have acyclic data flow, disconnects all From 9c260ddd91168ccf137f0e70d35cdf1b4be7bf31 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Thu, 16 Jan 2025 09:40:31 -0800 Subject: [PATCH 17/43] `mypy` semantics (#538) * Initialize _label to a string Signed-off-by: liamhuber * Hint the delimiter Signed-off-by: liamhuber * Make SemanticParent a Generic Signed-off-by: liamhuber * Purge `ParentMost` If subclasses of `Semantic` want to limit their `parent` attribute beyond the standard requirement that it be a `SemanticParent`, they can handle that by overriding the `parent` setter and getter. The only place this was used was in `Workflow`, and so such handling is now exactly the case. Signed-off-by: liamhuber * Update comment Signed-off-by: liamhuber * Use generic type Signed-off-by: liamhuber * Don't use generic in static method Signed-off-by: liamhuber * Jump through mypy hoops It doesn't recognize the __set__ for fset methods on the property, so my usual routes for super'ing the setter are failing. This is annoying, but I don't see it being particularly harmful as the method is private. Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Add dev note Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/mixin/semantics.py | 87 ++++++++++++++---------------- pyiron_workflow/nodes/composite.py | 8 +-- pyiron_workflow/workflow.py | 21 +++++++- tests/unit/mixin/test_semantics.py | 25 ++++----- tests/unit/test_workflow.py | 5 +- 5 files changed, 74 insertions(+), 72 deletions(-) diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index de083b87..e207ab92 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -13,9 +13,10 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from difflib import get_close_matches from pathlib import Path +from typing import Generic, TypeVar from bidict import bidict @@ -31,12 +32,12 @@ class Semantic(UsesState, HasLabel, HasParent, ABC): accessible. """ - semantic_delimiter = "/" + semantic_delimiter: str = "/" def __init__( self, label: str, *args, parent: SemanticParent | None = None, **kwargs ): - self._label = None + self._label = "" self._parent = None self._detached_parent_path = None self.label = label @@ -61,6 +62,13 @@ def parent(self) -> SemanticParent | None: @parent.setter def parent(self, new_parent: SemanticParent | None) -> None: + self._set_parent(new_parent) + + def _set_parent(self, new_parent: SemanticParent | None): + """ + mypy is uncooperative with super calls for setters, so we pull the behaviour + out. + """ if new_parent is self._parent: # Exit early if nothing is changing return @@ -157,7 +165,10 @@ class CyclicPathError(ValueError): """ -class SemanticParent(Semantic, ABC): +ChildType = TypeVar("ChildType", bound=Semantic) + + +class SemanticParent(Semantic, Generic[ChildType], ABC): """ A semantic object with a collection of uniquely-named semantic children. @@ -182,19 +193,29 @@ def __init__( strict_naming: bool = True, **kwargs, ): - self._children = bidict() + self._children: bidict[str, ChildType] = bidict() self.strict_naming = strict_naming super().__init__(*args, label=label, parent=parent, **kwargs) + @classmethod + @abstractmethod + def child_type(cls) -> type[ChildType]: + # Dev note: In principle, this could be a regular attribute + # However, in other situations this is precluded (e.g. in channels) + # since it would result in circular references. + # Here we favour consistency over brevity, + # and maintain the X_type() class method pattern + pass + @property - def children(self) -> bidict[str, Semantic]: + def children(self) -> bidict[str, ChildType]: return self._children @property def child_labels(self) -> tuple[str]: return tuple(child.label for child in self) - def __getattr__(self, key): + def __getattr__(self, key) -> ChildType: try: return self._children[key] except KeyError as key_error: @@ -210,7 +231,7 @@ def __getattr__(self, key): def __iter__(self): return self.children.values().__iter__() - def __len__(self): + def __len__(self) -> int: return len(self.children) def __dir__(self): @@ -218,15 +239,15 @@ def __dir__(self): def add_child( self, - child: Semantic, + child: ChildType, label: str | None = None, strict_naming: bool | None = None, - ) -> Semantic: + ) -> ChildType: """ Add a child, optionally assigning it a new label in the process. Args: - child (Semantic): The child to add. + child (ChildType): The child to add. label (str|None): A (potentially) new label to assign the child. (Default is None, leave the child's label alone.) strict_naming (bool|None): Whether to append a suffix to the label if @@ -234,7 +255,7 @@ def add_child( use the class-level flag.) Returns: - (Semantic): The child being added. + (ChildType): The child being added. Raises: TypeError: When the child is not of an allowed class. @@ -244,18 +265,12 @@ def add_child( `strict_naming` is true. """ - if not isinstance(child, Semantic): + if not isinstance(child, self.child_type()): raise TypeError( - f"{self.label} expected a new child of type {Semantic.__name__} " + f"{self.label} expected a new child of type {self.child_type()} " f"but got {child}" ) - if isinstance(child, ParentMost): - raise ParentMostError( - f"{child.label} is {ParentMost.__name__} and may only take None as a " - f"parent but was added as a child to {self.label}" - ) - self._ensure_path_is_not_cyclic(self, child) self._ensure_child_has_no_other_parent(child) @@ -339,15 +354,15 @@ def _add_suffix_to_label(self, label): ) return new_label - def remove_child(self, child: Semantic | str) -> Semantic: + def remove_child(self, child: ChildType | str) -> ChildType: if isinstance(child, str): child = self.children.pop(child) - elif isinstance(child, Semantic): + elif isinstance(child, self.child_type()): self.children.inv.pop(child) else: raise TypeError( f"{self.label} expected to remove a child of type str or " - f"{Semantic.__name__} but got {child}" + f"{self.child_type()} but got {child}" ) child.parent = None @@ -361,7 +376,7 @@ def parent(self) -> SemanticParent | None: @parent.setter def parent(self, new_parent: SemanticParent | None) -> None: self._ensure_path_is_not_cyclic(new_parent, self) - super(SemanticParent, type(self)).parent.__set__(self, new_parent) + self._set_parent(new_parent) def __getstate__(self): state = super().__getstate__() @@ -396,27 +411,3 @@ def __setstate__(self, state): # children). So, now return their parent to them: for child in self: child.parent = self - - -class ParentMostError(TypeError): - """ - To be raised when assigning a parent to a parent-most object - """ - - -class ParentMost(SemanticParent, ABC): - """ - A semantic parent that cannot have any other parent. - """ - - @property - def parent(self) -> None: - return None - - @parent.setter - def parent(self, new_parent: None): - if new_parent is not None: - raise ParentMostError( - f"{self.label} is {ParentMost.__name__} and may only take None as a " - f"parent but got {type(new_parent)}" - ) diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 11d50583..e3a06cab 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -54,7 +54,7 @@ class FailedChildError(RuntimeError): """Raise when one or more child nodes raise exceptions.""" -class Composite(SemanticParent, HasCreator, Node, ABC): +class Composite(SemanticParent[Node], HasCreator, Node, ABC): """ A base class for nodes that have internal graph structure -- i.e. they hold a collection of child nodes and their computation is to execute that graph. @@ -154,6 +154,10 @@ def __init__( **kwargs, ) + @classmethod + def child_type(cls) -> type[Node]: + return Node + def activate_strict_hints(self): super().activate_strict_hints() for node in self: @@ -420,8 +424,6 @@ def executor_shutdown(self, wait=True, *, cancel_futures=False): def __setattr__(self, key: str, node: Node): if isinstance(node, Composite) and key in ["_parent", "parent"]: # This is an edge case for assigning a node to an attribute - # We either defer to the setter with super, or directly assign the private - # variable (as requested in the setter) super().__setattr__(key, node) elif isinstance(node, Node): self.add_child(node, label=key) diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 791e17c8..8a5ddb29 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -11,7 +11,6 @@ from bidict import bidict from pyiron_workflow.io import Inputs, Outputs -from pyiron_workflow.mixin.semantics import ParentMost from pyiron_workflow.nodes.composite import Composite if TYPE_CHECKING: @@ -20,7 +19,13 @@ from pyiron_workflow.storage import StorageInterface -class Workflow(ParentMost, Composite): +class ParentMostError(TypeError): + """ + To be raised when assigning a parent to a parent-most object + """ + + +class Workflow(Composite): """ Workflows are a dynamic composite node -- i.e. they hold and run a collection of nodes (a subgraph) which can be dynamically modified (adding and removing nodes, @@ -495,3 +500,15 @@ def replace_child( raise e return owned_node + + @property + def parent(self) -> None: + return None + + @parent.setter + def parent(self, new_parent: None): + if new_parent is not None: + raise ParentMostError( + f"{self.label} is a {self.__class__} and may only take None as a " + f"parent but got {type(new_parent)}" + ) diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index dd40c5f0..0b63b94f 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -3,18 +3,23 @@ from pyiron_workflow.mixin.semantics import ( CyclicPathError, - ParentMost, Semantic, SemanticParent, ) +class ConcreteParent(SemanticParent[Semantic]): + @classmethod + def child_type(cls) -> type[Semantic]: + return Semantic + + class TestSemantics(unittest.TestCase): def setUp(self): - self.root = ParentMost("root") + self.root = ConcreteParent("root") self.child1 = Semantic("child1", parent=self.root) - self.middle1 = SemanticParent("middle", parent=self.root) - self.middle2 = SemanticParent("middle_sub", parent=self.middle1) + self.middle1 = ConcreteParent("middle", parent=self.root) + self.middle2 = ConcreteParent("middle_sub", parent=self.middle1) self.child2 = Semantic("child2", parent=self.middle2) def test_getattr(self): @@ -58,18 +63,6 @@ def test_parent(self): self.assertEqual(self.child1.parent, self.root) self.assertEqual(self.root.parent, None) - with self.subTest(f"{ParentMost.__name__} exceptions"): - with self.assertRaises( - TypeError, msg=f"{ParentMost.__name__} instances can't have parent" - ): - self.root.parent = SemanticParent(label="foo") - - with self.assertRaises( - TypeError, msg=f"{ParentMost.__name__} instances can't be children" - ): - some_parent = SemanticParent(label="bar") - some_parent.add_child(self.root) - with self.subTest("Cyclicity exceptions"): with self.assertRaises(CyclicPathError): self.middle1.parent = self.middle2 diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index 174013d3..f19032b7 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -9,9 +9,8 @@ from pyiron_workflow._tests import ensure_tests_in_python_path from pyiron_workflow.channels import NOT_DATA -from pyiron_workflow.mixin.semantics import ParentMostError from pyiron_workflow.storage import TypeNotFoundError, available_backends -from pyiron_workflow.workflow import Workflow +from pyiron_workflow.workflow import ParentMostError, Workflow ensure_tests_in_python_path() @@ -155,7 +154,7 @@ def test_io_map_bijectivity(self): self.assertEqual(3, len(wf.inputs_map), msg="All entries should be stored") self.assertEqual(0, len(wf.inputs), msg="No IO should be left exposed") - def test_is_parentmost(self): + def test_takes_no_parent(self): wf = Workflow("wf") wf2 = Workflow("wf2") From acc8739047721b0484e3a850fc9c7246b5744876 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Thu, 16 Jan 2025 09:58:59 -0800 Subject: [PATCH 18/43] Semantics generic parent (#544) * Make SemanticParent a Generic Signed-off-by: liamhuber * Don't use generic in static method Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Remove HasParent An interface class guaranteeing the (Any-typed) attribute is too vague to be super useful, and redundant when it's _only_ used in `Semantic`. Having a `parent` will just be a direct feature of being semantic. Signed-off-by: liamhuber * Pull out static method Signed-off-by: liamhuber * Pull cyclicity check up to Semantic Signed-off-by: liamhuber * De-parent SemanticParent from Semantic Because of the label arg vs kwarg problem, there is still a vestigial label arg in the SemanticParent init signature. Signed-off-by: liamhuber * Remove redundant type check This is handled in the super class Signed-off-by: liamhuber * Give Semantic a generic parent type Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Black Signed-off-by: liamhuber * Ruff sort imports Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Update docstrings Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/mixin/has_interface_mixins.py | 13 +-- pyiron_workflow/mixin/semantics.py | 85 ++++++++++--------- pyiron_workflow/node.py | 8 +- pyiron_workflow/nodes/composite.py | 5 -- tests/unit/mixin/test_semantics.py | 29 ++++--- 5 files changed, 70 insertions(+), 70 deletions(-) diff --git a/pyiron_workflow/mixin/has_interface_mixins.py b/pyiron_workflow/mixin/has_interface_mixins.py index 2828ce7e..72943176 100644 --- a/pyiron_workflow/mixin/has_interface_mixins.py +++ b/pyiron_workflow/mixin/has_interface_mixins.py @@ -11,7 +11,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING if TYPE_CHECKING: from pyiron_workflow.channels import Channel @@ -53,17 +53,6 @@ def full_label(self) -> str: return self.label -class HasParent(ABC): - """ - A mixin to guarantee the parent interface exists. - """ - - @property - @abstractmethod - def parent(self) -> Any: - """A parent for the object.""" - - class HasChannel(ABC): """ A mix-in class for use with the :class:`Channel` class. diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index e207ab92..8c1ef6d4 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -2,10 +2,11 @@ Classes for "semantic" reasoning. The motivation here is to be able to provide the object with a unique identifier -in the context of other semantic objects. Each object may have exactly one parent -and an arbitrary number of children, and each child's name must be unique in the -scope of that parent. In this way, the path from the parent-most object to any -child is completely unique. The typical filesystem on a computer is an excellent +in the context of other semantic objects. Each object may have at most one parent, +while semantic parents may have an arbitrary number of children, and each child's name +must be unique in the scope of that parent. In this way, when semantic parents are also +themselves semantic, we can build a path from the parent-most object to any child that +is completely unique. The typical filesystem on a computer is an excellent example and fulfills our requirements, the only reason we depart from it is so that we are free to have objects stored in different locations (possibly even on totally different drives or machines) belong to the same semantic group. @@ -21,10 +22,12 @@ from bidict import bidict from pyiron_workflow.logging import logger -from pyiron_workflow.mixin.has_interface_mixins import HasLabel, HasParent, UsesState +from pyiron_workflow.mixin.has_interface_mixins import HasLabel, UsesState +ParentType = TypeVar("ParentType", bound="SemanticParent") -class Semantic(UsesState, HasLabel, HasParent, ABC): + +class Semantic(UsesState, HasLabel, Generic[ParentType], ABC): """ An object with a unique semantic path. @@ -34,9 +37,7 @@ class Semantic(UsesState, HasLabel, HasParent, ABC): semantic_delimiter: str = "/" - def __init__( - self, label: str, *args, parent: SemanticParent | None = None, **kwargs - ): + def __init__(self, label: str, *args, parent: ParentType | None = None, **kwargs): self._label = "" self._parent = None self._detached_parent_path = None @@ -44,6 +45,11 @@ def __init__( self.parent = parent super().__init__(*args, **kwargs) + @classmethod + @abstractmethod + def parent_type(cls) -> type[ParentType]: + pass + @property def label(self) -> str: return self._label @@ -57,14 +63,14 @@ def label(self, new_label: str) -> None: self._label = new_label @property - def parent(self) -> SemanticParent | None: + def parent(self) -> ParentType | None: return self._parent @parent.setter - def parent(self, new_parent: SemanticParent | None) -> None: + def parent(self, new_parent: ParentType | None) -> None: self._set_parent(new_parent) - def _set_parent(self, new_parent: SemanticParent | None): + def _set_parent(self, new_parent: ParentType | None): """ mypy is uncooperative with super calls for setters, so we pull the behaviour out. @@ -73,12 +79,14 @@ def _set_parent(self, new_parent: SemanticParent | None): # Exit early if nothing is changing return - if new_parent is not None and not isinstance(new_parent, SemanticParent): + if new_parent is not None and not isinstance(new_parent, self.parent_type()): raise ValueError( - f"Expected None or a {SemanticParent.__name__} for the parent of " + f"Expected None or a {self.parent_type()} for the parent of " f"{self.label}, but got {new_parent}" ) + _ensure_path_is_not_cyclic(new_parent, self) + if ( self._parent is not None and new_parent is not self._parent @@ -134,7 +142,10 @@ def full_label(self) -> str: @property def semantic_root(self) -> Semantic: """The parent-most object in this semantic path; may be self.""" - return self.parent.semantic_root if isinstance(self.parent, Semantic) else self + if isinstance(self.parent, Semantic): + return self.parent.semantic_root + else: + return self def as_path(self, root: Path | str | None = None) -> Path: """ @@ -168,9 +179,9 @@ class CyclicPathError(ValueError): ChildType = TypeVar("ChildType", bound=Semantic) -class SemanticParent(Semantic, Generic[ChildType], ABC): +class SemanticParent(Generic[ChildType], ABC): """ - A semantic object with a collection of uniquely-named semantic children. + An with a collection of uniquely-named semantic children. Children should be added or removed via the :meth:`add_child` and :meth:`remove_child` methods and _not_ by direct manipulation of the @@ -187,15 +198,14 @@ class SemanticParent(Semantic, Generic[ChildType], ABC): def __init__( self, - label: str, + label: str | None, # Vestigial while the label order is broken *args, - parent: SemanticParent | None = None, strict_naming: bool = True, **kwargs, ): self._children: bidict[str, ChildType] = bidict() self.strict_naming = strict_naming - super().__init__(*args, label=label, parent=parent, **kwargs) + super().__init__(*args, label=label, **kwargs) @classmethod @abstractmethod @@ -271,7 +281,7 @@ def add_child( f"but got {child}" ) - self._ensure_path_is_not_cyclic(self, child) + _ensure_path_is_not_cyclic(self, child) self._ensure_child_has_no_other_parent(child) @@ -292,18 +302,6 @@ def add_child( child.parent = self return child - @staticmethod - def _ensure_path_is_not_cyclic(parent: SemanticParent | None, child: Semantic): - if parent is not None and parent.semantic_path.startswith( - child.semantic_path + child.semantic_delimiter - ): - raise CyclicPathError( - f"{parent.label} cannot be the parent of {child.label}, because its " - f"semantic path is already in {child.label}'s path and cyclic paths " - f"are not allowed. (i.e. {child.semantic_path} is in " - f"{parent.semantic_path})" - ) - def _ensure_child_has_no_other_parent(self, child: Semantic): if child.parent is not None and child.parent is not self: raise ValueError( @@ -369,15 +367,6 @@ def remove_child(self, child: ChildType | str) -> ChildType: return child - @property - def parent(self) -> SemanticParent | None: - return self._parent - - @parent.setter - def parent(self, new_parent: SemanticParent | None) -> None: - self._ensure_path_is_not_cyclic(new_parent, self) - self._set_parent(new_parent) - def __getstate__(self): state = super().__getstate__() @@ -411,3 +400,15 @@ def __setstate__(self, state): # children). So, now return their parent to them: for child in self: child.parent = self + + +def _ensure_path_is_not_cyclic(parent, child: Semantic): + if isinstance(parent, Semantic) and parent.semantic_path.startswith( + child.semantic_path + child.semantic_delimiter + ): + raise CyclicPathError( + f"{parent.label} cannot be the parent of {child.label}, because its " + f"semantic path is already in {child.label}'s path and cyclic paths " + f"are not allowed. (i.e. {child.semantic_path} is in " + f"{parent.semantic_path})" + ) diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 3b86a5e4..6e19a704 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -41,7 +41,7 @@ class Node( HasIOWithInjection, - Semantic, + Semantic["Composite"], Runnable, ExploitsSingleOutput, ABC, @@ -319,6 +319,12 @@ def __init__( **kwargs, ) + @classmethod + def parent_type(cls) -> type[Composite]: + from pyiron_workflow.nodes.composite import Composite + + return Composite + def _setup_node(self) -> None: """ Called _before_ :meth:`Node.__init__` finishes. diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index e3a06cab..e5e05ed4 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -304,11 +304,6 @@ def add_child( label: str | None = None, strict_naming: bool | None = None, ) -> Node: - if not isinstance(child, Node): - raise TypeError( - f"Only new {Node.__name__} instances may be added, but got " - f"{type(child)}." - ) self._cached_inputs = None # Reset cache after graph change return super().add_child(child, label=label, strict_naming=strict_naming) diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index 0b63b94f..874928f7 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from pathlib import Path @@ -8,19 +10,25 @@ ) -class ConcreteParent(SemanticParent[Semantic]): +class ConcreteSemantic(Semantic["ConcreteParent"]): + @classmethod + def parent_type(cls) -> type[ConcreteParent]: + return ConcreteParent + + +class ConcreteParent(SemanticParent[ConcreteSemantic], ConcreteSemantic): @classmethod - def child_type(cls) -> type[Semantic]: - return Semantic + def child_type(cls) -> type[ConcreteSemantic]: + return ConcreteSemantic class TestSemantics(unittest.TestCase): def setUp(self): self.root = ConcreteParent("root") - self.child1 = Semantic("child1", parent=self.root) + self.child1 = ConcreteSemantic("child1", parent=self.root) self.middle1 = ConcreteParent("middle", parent=self.root) self.middle2 = ConcreteParent("middle_sub", parent=self.middle1) - self.child2 = Semantic("child2", parent=self.middle2) + self.child2 = ConcreteSemantic("child2", parent=self.middle2) def test_getattr(self): with self.assertRaises(AttributeError) as context: @@ -40,18 +48,19 @@ def test_getattr(self): def test_label_validity(self): with self.assertRaises(TypeError, msg="Label must be a string"): - Semantic(label=123) + ConcreteSemantic(label=123) def test_label_delimiter(self): with self.assertRaises( - ValueError, msg=f"Delimiter '{Semantic.semantic_delimiter}' not allowed" + ValueError, + msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", ): - Semantic(f"invalid{Semantic.semantic_delimiter}label") + ConcreteSemantic(f"invalid{ConcreteSemantic.semantic_delimiter}label") def test_semantic_delimiter(self): self.assertEqual( "/", - Semantic.semantic_delimiter, + ConcreteSemantic.semantic_delimiter, msg="This is just a hard-code to the current value, update it freely so " "the test passes; if it fails it's just a reminder that your change is " "not backwards compatible, and the next release number should reflect " @@ -105,7 +114,7 @@ def test_as_path(self): ) def test_detached_parent_path(self): - orphan = Semantic("orphan") + orphan = ConcreteSemantic("orphan") orphan.__setstate__(self.child2.__getstate__()) self.assertIsNone( orphan.parent, msg="We still should not explicitly have a parent" From 794291076302c0d04e58f1bfb563791eb57450e5 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Thu, 16 Jan 2025 10:06:57 -0800 Subject: [PATCH 19/43] Improvements to semantic labeling (#547) * Initialize _label to a string Signed-off-by: liamhuber * Hint the delimiter Signed-off-by: liamhuber * Make SemanticParent a Generic Signed-off-by: liamhuber * Purge `ParentMost` If subclasses of `Semantic` want to limit their `parent` attribute beyond the standard requirement that it be a `SemanticParent`, they can handle that by overriding the `parent` setter and getter. The only place this was used was in `Workflow`, and so such handling is now exactly the case. Signed-off-by: liamhuber * Update comment Signed-off-by: liamhuber * Use generic type Signed-off-by: liamhuber * Don't use generic in static method Signed-off-by: liamhuber * Jump through mypy hoops It doesn't recognize the __set__ for fset methods on the property, so my usual routes for super'ing the setter are failing. This is annoying, but I don't see it being particularly harmful as the method is private. Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Add dev note Signed-off-by: liamhuber * Remove HasParent An interface class guaranteeing the (Any-typed) attribute is too vague to be super useful, and redundant when it's _only_ used in `Semantic`. Having a `parent` will just be a direct feature of being semantic. Signed-off-by: liamhuber * Pull out static method Signed-off-by: liamhuber * Pull cyclicity check up to Semantic Signed-off-by: liamhuber * De-parent SemanticParent from Semantic Because of the label arg vs kwarg problem, there is still a vestigial label arg in the SemanticParent init signature. Signed-off-by: liamhuber * Remove redundant type check This is handled in the super class Signed-off-by: liamhuber * Give Semantic a generic parent type Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Black Signed-off-by: liamhuber * Ruff sort imports Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Update docstrings Signed-off-by: liamhuber * Guarantee that semantic parents have a label Signed-off-by: liamhuber * :bug: don't assume parents have semantic_path But we can now safely assume they have a label Signed-off-by: liamhuber * Pull label default up into Semantic This way it is allowed to be a keyword argument everywhere, except for Workflow which makes it positional and adjusts its `super().__init__` call accordingly. Signed-off-by: liamhuber * Refactor: label validity check Pull it up from semantic into an extensible method on the mixin class Signed-off-by: liamhuber * Refactor: rename class Signed-off-by: liamhuber * Add label restrictions To semantic parent based on its child type's semantic delimiter Signed-off-by: liamhuber * Improve error messages Signed-off-by: liamhuber * Make SemanticParent a Generic Signed-off-by: liamhuber * Don't use generic in static method Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Remove HasParent An interface class guaranteeing the (Any-typed) attribute is too vague to be super useful, and redundant when it's _only_ used in `Semantic`. Having a `parent` will just be a direct feature of being semantic. Signed-off-by: liamhuber * Pull out static method Signed-off-by: liamhuber * Pull cyclicity check up to Semantic Signed-off-by: liamhuber * De-parent SemanticParent from Semantic Because of the label arg vs kwarg problem, there is still a vestigial label arg in the SemanticParent init signature. Signed-off-by: liamhuber * Remove redundant type check This is handled in the super class Signed-off-by: liamhuber * Give Semantic a generic parent type Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Black Signed-off-by: liamhuber * Ruff sort imports Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Update docstrings Signed-off-by: liamhuber * Annotate some extra returns (#548) Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 4 -- pyiron_workflow/mixin/has_interface_mixins.py | 16 ++++- pyiron_workflow/mixin/semantics.py | 68 +++++++++++-------- pyiron_workflow/node.py | 5 +- pyiron_workflow/nodes/composite.py | 2 +- tests/unit/mixin/test_run.py | 4 +- tests/unit/mixin/test_semantics.py | 33 ++++++--- 7 files changed, 82 insertions(+), 50 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 82970878..153932fd 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -92,10 +92,6 @@ def __init__( self.owner: HasIO = owner self.connections: list[ConjugateType] = [] - @property - def label(self) -> str: - return self._label - @abstractmethod def __str__(self): pass diff --git a/pyiron_workflow/mixin/has_interface_mixins.py b/pyiron_workflow/mixin/has_interface_mixins.py index 72943176..5183c83e 100644 --- a/pyiron_workflow/mixin/has_interface_mixins.py +++ b/pyiron_workflow/mixin/has_interface_mixins.py @@ -39,10 +39,24 @@ class HasLabel(ABC): A mixin to guarantee the label interface exists. """ + _label: str + @property - @abstractmethod def label(self) -> str: """A label for the object.""" + return self._label + + @label.setter + def label(self, new_label: str): + self._check_label(new_label) + self._label = new_label + + def _check_label(self, new_label: str) -> None: + """ + Extensible checking routine for label validity. + """ + if not isinstance(new_label, str): + raise TypeError(f"Expected a string label but got {new_label}") @property def full_label(self) -> str: diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index 8c1ef6d4..1ecfd8fb 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from difflib import get_close_matches from pathlib import Path -from typing import Generic, TypeVar +from typing import ClassVar, Generic, TypeVar from bidict import bidict @@ -35,13 +35,19 @@ class Semantic(UsesState, HasLabel, Generic[ParentType], ABC): accessible. """ - semantic_delimiter: str = "/" + semantic_delimiter: ClassVar[str] = "/" - def __init__(self, label: str, *args, parent: ParentType | None = None, **kwargs): + def __init__( + self, + *args, + label: str | None = None, + parent: ParentType | None = None, + **kwargs, + ): self._label = "" self._parent = None self._detached_parent_path = None - self.label = label + self.label = self.__class__.__name__ if label is None else label self.parent = parent super().__init__(*args, **kwargs) @@ -50,17 +56,13 @@ def __init__(self, label: str, *args, parent: ParentType | None = None, **kwargs def parent_type(cls) -> type[ParentType]: pass - @property - def label(self) -> str: - return self._label - - @label.setter - def label(self, new_label: str) -> None: - if not isinstance(new_label, str): - raise TypeError(f"Expected a string label but got {new_label}") + def _check_label(self, new_label: str) -> None: + super()._check_label(new_label) if self.semantic_delimiter in new_label: - raise ValueError(f"{self.semantic_delimiter} cannot be in the label") - self._label = new_label + raise ValueError( + f"Semantic delimiter {self.semantic_delimiter} cannot be in new label " + f"{new_label}" + ) @property def parent(self) -> ParentType | None: @@ -104,18 +106,22 @@ def semantic_path(self) -> str: The path of node labels from the graph root (parent-most node) down to this node. """ + prefix: str if self.parent is None and self.detached_parent_path is None: prefix = "" elif self.parent is None and self.detached_parent_path is not None: prefix = self.detached_parent_path elif self.parent is not None and self.detached_parent_path is None: - prefix = self.parent.semantic_path + if isinstance(self.parent, Semantic): + prefix = self.parent.semantic_path + else: + prefix = self.semantic_delimiter + self.parent.label else: raise ValueError( f"The parent and detached path should not be able to take non-None " f"values simultaneously, but got {self.parent} and " - f"{self.detached_parent_path}, respectively. Please raise an issue on GitHub " - f"outlining how your reached this state." + f"{self.detached_parent_path}, respectively. Please raise an issue on " + f"GitHub outlining how your reached this state." ) return prefix + self.semantic_delimiter + self.label @@ -179,9 +185,9 @@ class CyclicPathError(ValueError): ChildType = TypeVar("ChildType", bound=Semantic) -class SemanticParent(Generic[ChildType], ABC): +class SemanticParent(HasLabel, Generic[ChildType], ABC): """ - An with a collection of uniquely-named semantic children. + A labeled object with a collection of uniquely-named semantic children. Children should be added or removed via the :meth:`add_child` and :meth:`remove_child` methods and _not_ by direct manipulation of the @@ -198,14 +204,13 @@ class SemanticParent(Generic[ChildType], ABC): def __init__( self, - label: str | None, # Vestigial while the label order is broken *args, strict_naming: bool = True, **kwargs, ): self._children: bidict[str, ChildType] = bidict() self.strict_naming = strict_naming - super().__init__(*args, label=label, **kwargs) + super().__init__(*args, **kwargs) @classmethod @abstractmethod @@ -225,6 +230,15 @@ def children(self) -> bidict[str, ChildType]: def child_labels(self) -> tuple[str]: return tuple(child.label for child in self) + def _check_label(self, new_label: str) -> None: + super()._check_label(new_label) + if self.child_type().semantic_delimiter in new_label: + raise ValueError( + f"Child type ({self.child_type()}) semantic delimiter " + f"{self.child_type().semantic_delimiter} cannot be in new label " + f"{new_label}" + ) + def __getattr__(self, key) -> ChildType: try: return self._children[key] @@ -302,7 +316,7 @@ def add_child( child.parent = self return child - def _ensure_child_has_no_other_parent(self, child: Semantic): + def _ensure_child_has_no_other_parent(self, child: Semantic) -> None: if child.parent is not None and child.parent is not self: raise ValueError( f"The child ({child.label}) already belongs to the parent " @@ -310,17 +324,17 @@ def _ensure_child_has_no_other_parent(self, child: Semantic): f"add it to this parent ({self.label})." ) - def _this_child_is_already_at_this_label(self, child: Semantic, label: str): + def _this_child_is_already_at_this_label(self, child: Semantic, label: str) -> bool: return ( label == child.label and label in self.child_labels and self.children[label] is child ) - def _this_child_is_already_at_a_different_label(self, child, label): + def _this_child_is_already_at_a_different_label(self, child, label) -> bool: return child.parent is self and label != child.label - def _get_unique_label(self, label: str, strict_naming: bool): + def _get_unique_label(self, label: str, strict_naming: bool) -> str: if label in self.__dir__(): if label in self.child_labels: if strict_naming: @@ -337,7 +351,7 @@ def _get_unique_label(self, label: str, strict_naming: bool): ) return label - def _add_suffix_to_label(self, label): + def _add_suffix_to_label(self, label: str) -> str: i = 0 new_label = label while new_label in self.__dir__(): @@ -402,7 +416,7 @@ def __setstate__(self, state): child.parent = self -def _ensure_path_is_not_cyclic(parent, child: Semantic): +def _ensure_path_is_not_cyclic(parent, child: Semantic) -> None: if isinstance(parent, Semantic) and parent.semantic_path.startswith( child.semantic_path + child.semantic_delimiter ): diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 6e19a704..ead473f7 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -297,10 +297,7 @@ def __init__( **kwargs: Interpreted as node input data, with keys corresponding to channel labels. """ - super().__init__( - label=self.__class__.__name__ if label is None else label, - parent=parent, - ) + super().__init__(label=label, parent=parent) self.checkpoint = checkpoint self.recovery: Literal["pickle"] | StorageInterface | None = "pickle" self._serialize_result = False # Advertised, but private to indicate diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index e5e05ed4..1689da92 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -143,8 +143,8 @@ def __init__( # empty but the running_children list is not super().__init__( - label, *args, + label=label, parent=parent, delete_existing_savefiles=delete_existing_savefiles, autoload=autoload, diff --git a/tests/unit/mixin/test_run.py b/tests/unit/mixin/test_run.py index 009e03b5..bfa32ece 100644 --- a/tests/unit/mixin/test_run.py +++ b/tests/unit/mixin/test_run.py @@ -8,9 +8,7 @@ class ConcreteRunnable(Runnable): - @property - def label(self) -> str: - return "child_class_with_all_methods_implemented" + _label = "child_class_with_all_methods_implemented" def on_run(self, **kwargs): return kwargs diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index 874928f7..fbb9bac4 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -12,23 +12,29 @@ class ConcreteSemantic(Semantic["ConcreteParent"]): @classmethod - def parent_type(cls) -> type[ConcreteParent]: - return ConcreteParent + def parent_type(cls) -> type[ConcreteSemanticParent]: + return ConcreteSemanticParent -class ConcreteParent(SemanticParent[ConcreteSemantic], ConcreteSemantic): +class ConcreteParent(SemanticParent[ConcreteSemantic]): + _label = "concrete_parent_default_label" + @classmethod def child_type(cls) -> type[ConcreteSemantic]: return ConcreteSemantic +class ConcreteSemanticParent(ConcreteParent, ConcreteSemantic): + pass + + class TestSemantics(unittest.TestCase): def setUp(self): - self.root = ConcreteParent("root") - self.child1 = ConcreteSemantic("child1", parent=self.root) - self.middle1 = ConcreteParent("middle", parent=self.root) - self.middle2 = ConcreteParent("middle_sub", parent=self.middle1) - self.child2 = ConcreteSemantic("child2", parent=self.middle2) + self.root = ConcreteSemanticParent(label="root") + self.child1 = ConcreteSemantic(label="child1", parent=self.root) + self.middle1 = ConcreteSemanticParent(label="middle", parent=self.root) + self.middle2 = ConcreteSemanticParent(label="middle_sub", parent=self.middle1) + self.child2 = ConcreteSemantic(label="child2", parent=self.middle2) def test_getattr(self): with self.assertRaises(AttributeError) as context: @@ -55,7 +61,14 @@ def test_label_delimiter(self): ValueError, msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", ): - ConcreteSemantic(f"invalid{ConcreteSemantic.semantic_delimiter}label") + ConcreteSemantic(label=f"invalid{ConcreteSemantic.semantic_delimiter}label") + + non_semantic_parent = ConcreteParent() + with self.assertRaises( + ValueError, + msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", + ): + non_semantic_parent.label = f"contains_{non_semantic_parent.child_type().semantic_delimiter}_delimiter" def test_semantic_delimiter(self): self.assertEqual( @@ -114,7 +127,7 @@ def test_as_path(self): ) def test_detached_parent_path(self): - orphan = ConcreteSemantic("orphan") + orphan = ConcreteSemantic(label="orphan") orphan.__setstate__(self.child2.__getstate__()) self.assertIsNone( orphan.parent, msg="We still should not explicitly have a parent" From 52fe191dfcfb00514f8b8076d8d0bb1cd140edcb Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 07:17:32 -0800 Subject: [PATCH 20/43] Make the `HasChannel` interface generic on the `Channel` type (#550) * :bug: hint with [] for type args Signed-off-by: liamhuber * Make a generic version of HasChannel Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/mixin/has_interface_mixins.py | 16 ++++++++++++---- pyiron_workflow/mixin/injection.py | 2 +- pyiron_workflow/mixin/single_output.py | 4 ++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pyiron_workflow/mixin/has_interface_mixins.py b/pyiron_workflow/mixin/has_interface_mixins.py index 5183c83e..049159ee 100644 --- a/pyiron_workflow/mixin/has_interface_mixins.py +++ b/pyiron_workflow/mixin/has_interface_mixins.py @@ -11,7 +11,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic, TypeVar if TYPE_CHECKING: from pyiron_workflow.channels import Channel @@ -69,9 +69,7 @@ def full_label(self) -> str: class HasChannel(ABC): """ - A mix-in class for use with the :class:`Channel` class. - A :class:`Channel` is able to (attempt to) connect to any child instance of :class:`HasConnection` - by looking at its :attr:`connection` attribute. + A mix-in class for use with the :class:`Channel` class and its children. This is useful for letting channels attempt to connect to non-channel objects directly by pointing them to some channel that object holds. @@ -83,6 +81,16 @@ def channel(self) -> Channel: pass +ChannelType = TypeVar("ChannelType", bound="Channel") + + +class HasGenericChannel(HasChannel, Generic[ChannelType], ABC): + @property + @abstractmethod + def channel(self) -> ChannelType: + pass + + class HasRun(ABC): """ A mixin to guarantee that the :meth:`run` method exists. diff --git a/pyiron_workflow/mixin/injection.py b/pyiron_workflow/mixin/injection.py index 3b22c589..a0cf13e4 100644 --- a/pyiron_workflow/mixin/injection.py +++ b/pyiron_workflow/mixin/injection.py @@ -277,7 +277,7 @@ def __round__(self): class OutputsWithInjection(Outputs): @property - def _channel_class(self) -> type(OutputDataWithInjection): + def _channel_class(self) -> type[OutputDataWithInjection]: return OutputDataWithInjection diff --git a/pyiron_workflow/mixin/single_output.py b/pyiron_workflow/mixin/single_output.py index 4272d870..948a36bb 100644 --- a/pyiron_workflow/mixin/single_output.py +++ b/pyiron_workflow/mixin/single_output.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod -from pyiron_workflow.mixin.has_interface_mixins import HasChannel, HasLabel +from pyiron_workflow.mixin.has_interface_mixins import HasGenericChannel, HasLabel from pyiron_workflow.mixin.injection import ( OutputDataWithInjection, OutputsWithInjection, @@ -18,7 +18,7 @@ class AmbiguousOutputError(ValueError): """Raised when searching for exactly one output, but multiple are found.""" -class ExploitsSingleOutput(HasLabel, HasChannel, ABC): +class ExploitsSingleOutput(HasLabel, HasGenericChannel[OutputDataWithInjection], ABC): @property @abstractmethod def outputs(self) -> OutputsWithInjection: From 4c9af5e6b966cae75a8312337cc33622144e56d2 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 09:07:23 -0800 Subject: [PATCH 21/43] Generic `HasIO` classes to specify data output panel types (#551) * :bug: hint with [] for type args Signed-off-by: liamhuber * Make a generic version of HasChannel Signed-off-by: liamhuber * Make HasIO generic on the output panel Signed-off-by: liamhuber * Refactor: introduce generic data outputs panel Signed-off-by: liamhuber * Remove unnecessary concrete class To reduce misdirection. We barely use it in the super-class and never need to hint it. In contrast, I kept `OutputsWithInjection` around exactly because it shows up in type hints everywhere, so the shorthand version is nice to have. Signed-off-by: liamhuber * Fix type hints and unused imports Signed-off-by: liamhuber * More return hints (#552) * Fix returned type of __dir__ Conventionally it returns a list, not a set, of strings Signed-off-by: liamhuber * Add hints to io Signed-off-by: liamhuber * Adjust run_finally signature Signed-off-by: liamhuber * Hint user data Signed-off-by: liamhuber * Hint Workflow.automate_execution Signed-off-by: liamhuber * Provide a type-compliant default It never actually matters with the current logic, because of all the checks if parent is None and the fact that it is otherwise hinted to be at least a `Composite`, but it shuts mypy up and it does zero harm. Signed-off-by: liamhuber * black Signed-off-by: liamhuber * `mypy` storage (#553) * Add return hints Signed-off-by: liamhuber * End clause with else Signed-off-by: liamhuber * Explicitly raise an error After narrowing our search to files, actually throw an error right away if you never found one to load. Signed-off-by: liamhuber * Resolve method extension complaints Signed-off-by: liamhuber * `mypy` signature compliance (#554) * Extend runnable signatures Signed-off-by: liamhuber * Align Workflow.run with superclass signature Signed-off-by: liamhuber * Relax FromManyInputs._on_run constraint It was too strict for the DataFrame subclass, so just keep the superclass reference instead of narrowing the constraints. Signed-off-by: liamhuber * black Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/io.py | 94 +++++++++++++++----------- pyiron_workflow/mixin/display_state.py | 3 +- pyiron_workflow/mixin/injection.py | 12 +--- pyiron_workflow/mixin/run.py | 11 +-- pyiron_workflow/mixin/single_output.py | 7 +- pyiron_workflow/node.py | 13 ++-- pyiron_workflow/nodes/macro.py | 5 +- pyiron_workflow/nodes/transform.py | 4 -- pyiron_workflow/storage.py | 34 ++++++---- pyiron_workflow/workflow.py | 25 +++++-- tests/unit/test_io.py | 2 +- tests/unit/test_workflow.py | 8 ++- 12 files changed, 131 insertions(+), 87 deletions(-) diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py index c3c8dada..e5890884 100644 --- a/pyiron_workflow/io.py +++ b/pyiron_workflow/io.py @@ -9,6 +9,7 @@ import contextlib from abc import ABC, abstractmethod +from collections.abc import ItemsView, Iterator from typing import Any, Generic, TypeVar from pyiron_snippets.dotdict import DotDict @@ -59,7 +60,7 @@ class IO(HasStateDisplay, Generic[OwnedType, OwnedConjugate], ABC): channel_dict: DotDict[str, OwnedType] - def __init__(self, *channels: OwnedType): + def __init__(self, *channels: OwnedType) -> None: self.__dict__["channel_dict"] = DotDict( { channel.label: channel @@ -74,11 +75,11 @@ def _channel_class(self) -> type[OwnedType]: pass @abstractmethod - def _assign_a_non_channel_value(self, channel: OwnedType, value) -> None: + def _assign_a_non_channel_value(self, channel: OwnedType, value: Any) -> None: """What to do when some non-channel value gets assigned to a channel""" pass - def __getattr__(self, item) -> OwnedType: + def __getattr__(self, item: str) -> OwnedType: try: return self.channel_dict[item] except KeyError as key_error: @@ -88,7 +89,7 @@ def __getattr__(self, item) -> OwnedType: f"nor in its channels ({self.labels})" ) from key_error - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: if key in self.channel_dict: self._assign_value_to_existing_channel(self.channel_dict[key], value) elif isinstance(value, self._channel_class): @@ -104,16 +105,16 @@ def __setattr__(self, key, value): f"attribute {key} got assigned {value} of type {type(value)}" ) - def _assign_value_to_existing_channel(self, channel: OwnedType, value) -> None: + def _assign_value_to_existing_channel(self, channel: OwnedType, value: Any) -> None: if isinstance(value, HasChannel): channel.connect(value.channel) else: self._assign_a_non_channel_value(channel, value) - def __getitem__(self, item) -> OwnedType: + def __getitem__(self, item: str) -> OwnedType: return self.__getattr__(item) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: self.__setattr__(key, value) @property @@ -124,11 +125,11 @@ def connections(self) -> list[OwnedConjugate]: ) @property - def connected(self): + def connected(self) -> bool: return any(c.connected for c in self) @property - def fully_connected(self): + def fully_connected(self) -> bool: return all(c.connected for c in self) def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]: @@ -145,34 +146,36 @@ def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]: return destroyed_connections @property - def labels(self): + def labels(self) -> list[str]: return list(self.channel_dict.keys()) - def items(self): + def items(self) -> ItemsView[str, OwnedType]: return self.channel_dict.items() - def __iter__(self): + def __iter__(self) -> Iterator[OwnedType]: return self.channel_dict.values().__iter__() - def __len__(self): + def __len__(self) -> int: return len(self.channel_dict) def __dir__(self): - return set(super().__dir__() + self.labels) + return list(set(super().__dir__() + self.labels)) - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__} {self.labels}" - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: # Compatibility with python <3.11 return dict(self.__dict__) - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: # Because we override getattr, we need to use __dict__ assignment directly in # __setstate__ the same way we need it in __init__ self.__dict__["channel_dict"] = state["channel_dict"] - def display_state(self, state=None, ignore_private=True): + def display_state( + self, state: dict[str, Any] | None = None, ignore_private: bool = True + ) -> dict[str, Any]: state = dict(self.__getstate__()) if state is None else state for k, v in state["channel_dict"].items(): state[k] = v @@ -192,15 +195,15 @@ class DataIO(IO[DataChannel, DataChannel], ABC): def _assign_a_non_channel_value(self, channel: DataChannel, value) -> None: channel.value = value - def to_value_dict(self): + def to_value_dict(self) -> dict[str, Any]: return {label: channel.value for label, channel in self.channel_dict.items()} - def to_list(self): + def to_list(self) -> list[Any]: """A list of channel values (order not guaranteed)""" return [channel.value for channel in self.channel_dict.values()] @property - def ready(self): + def ready(self) -> bool: return all(c.ready for c in self) def activate_strict_hints(self): @@ -215,19 +218,29 @@ class Inputs(InputsIO, DataIO): def _channel_class(self) -> type[InputData]: return InputData - def fetch(self): + def fetch(self) -> None: for c in self: c.fetch() -class Outputs(OutputsIO, DataIO): +OutputDataType = TypeVar("OutputDataType", bound=OutputData) + + +class GenericOutputs(OutputsIO, DataIO, Generic[OutputDataType], ABC): + @property + @abstractmethod + def _channel_class(self) -> type[OutputDataType]: + pass + + +class Outputs(GenericOutputs[OutputData]): @property def _channel_class(self) -> type[OutputData]: return OutputData class SignalIO(IO[SignalChannel, SignalChannel], ABC): - def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None: + def _assign_a_non_channel_value(self, channel: SignalChannel, value: Any) -> None: raise TypeError( f"Tried to assign {value} ({type(value)} to the {channel.full_label}, " f"which is already a {type(channel)}. Only other signal channels may be " @@ -265,9 +278,9 @@ class Signals(HasStateDisplay): output (OutputSignals): An empty input signals IO container. """ - def __init__(self): - self.input = InputSignals() - self.output = OutputSignals() + def __init__(self) -> None: + self.input: InputSignals = InputSignals() + self.output: OutputSignals = OutputSignals() def disconnect(self) -> list[tuple[SignalChannel, SignalChannel]]: """ @@ -283,18 +296,21 @@ def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: return self.input.disconnect_run() @property - def connected(self): + def connected(self) -> bool: return self.input.connected or self.output.connected @property - def fully_connected(self): + def fully_connected(self) -> bool: return self.input.fully_connected and self.output.fully_connected - def __str__(self): + def __str__(self) -> str: return f"{str(self.input)}\n{str(self.output)}" -class HasIO(HasStateDisplay, HasLabel, HasRun, ABC): +OutputsType = TypeVar("OutputsType", bound=GenericOutputs) + + +class HasIO(HasStateDisplay, HasLabel, HasRun, Generic[OutputsType], ABC): """ A mixin for classes that provide data and signal IO. @@ -303,7 +319,7 @@ class HasIO(HasStateDisplay, HasLabel, HasRun, ABC): interface. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._signals = Signals() self._signals.input.run = InputSignal("run", self, self.run) @@ -329,7 +345,7 @@ def data_input_locked(self) -> bool: @property @abstractmethod - def outputs(self) -> Outputs: + def outputs(self) -> OutputsType: pass @property @@ -362,17 +378,17 @@ def disconnect(self) -> list[tuple[Channel, Channel]]: destroyed_connections.extend(self.signals.disconnect()) return destroyed_connections - def activate_strict_hints(self): + def activate_strict_hints(self) -> None: """Enable type hint checks for all data IO""" self.inputs.activate_strict_hints() self.outputs.activate_strict_hints() - def deactivate_strict_hints(self): + def deactivate_strict_hints(self) -> None: """Disable type hint checks for all data IO""" self.inputs.deactivate_strict_hints() self.outputs.deactivate_strict_hints() - def _connect_output_signal(self, signal: OutputSignal): + def _connect_output_signal(self, signal: OutputSignal) -> None: self.signals.input.run.connect(signal) def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO: @@ -382,10 +398,12 @@ def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO: other._connect_output_signal(self.signals.output.ran) return other - def _connect_accumulating_input_signal(self, signal: AccumulatingInputSignal): + def _connect_accumulating_input_signal( + self, signal: AccumulatingInputSignal + ) -> None: self.signals.output.ran.connect(signal) - def __lshift__(self, others): + def __lshift__(self, others: tuple[OutputSignal | HasIO, ...]): """ Connect one or more `ran` signals to `accumulate_and_run` signals like: `this << some_object, another_object, or_by_channel.signals.output.ran` diff --git a/pyiron_workflow/mixin/display_state.py b/pyiron_workflow/mixin/display_state.py index 48309e21..fb9856a8 100644 --- a/pyiron_workflow/mixin/display_state.py +++ b/pyiron_workflow/mixin/display_state.py @@ -4,6 +4,7 @@ from abc import ABC from json import dumps +from typing import Any from pyiron_workflow.mixin.has_interface_mixins import UsesState @@ -24,7 +25,7 @@ class HasStateDisplay(UsesState, ABC): def display_state( self, state: dict | None = None, ignore_private: bool = True - ) -> dict: + ) -> dict[str, Any]: """ A dictionary of JSON-compatible objects based on the object state (plus whatever modifications to the state the class designer has chosen to make). diff --git a/pyiron_workflow/mixin/injection.py b/pyiron_workflow/mixin/injection.py index a0cf13e4..2bbf4e48 100644 --- a/pyiron_workflow/mixin/injection.py +++ b/pyiron_workflow/mixin/injection.py @@ -8,11 +8,10 @@ from __future__ import annotations -from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any from pyiron_workflow.channels import NOT_DATA, OutputData -from pyiron_workflow.io import HasIO, Outputs +from pyiron_workflow.io import GenericOutputs from pyiron_workflow.mixin.has_interface_mixins import HasChannel if TYPE_CHECKING: @@ -275,14 +274,7 @@ def __round__(self): return self._node_injection(Round) -class OutputsWithInjection(Outputs): +class OutputsWithInjection(GenericOutputs[OutputDataWithInjection]): @property def _channel_class(self) -> type[OutputDataWithInjection]: return OutputDataWithInjection - - -class HasIOWithInjection(HasIO, ABC): - @property - @abstractmethod - def outputs(self) -> OutputsWithInjection: - pass diff --git a/pyiron_workflow/mixin/run.py b/pyiron_workflow/mixin/run.py index b7b11c7b..ac7aae86 100644 --- a/pyiron_workflow/mixin/run.py +++ b/pyiron_workflow/mixin/run.py @@ -74,7 +74,7 @@ def run_args(self) -> tuple[tuple, dict]: Any data needed for :meth:`on_run`, will be passed as (*args, **kwargs). """ - def process_run_result(self, run_output): + def process_run_result(self, run_output: Any) -> Any: """ What to _do_ with the results of :meth:`on_run` once you have them. @@ -165,7 +165,9 @@ def _none_to_dict(inp: dict | None) -> dict: **run_kwargs, ) - def _before_run(self, /, check_readiness, **kwargs) -> tuple[bool, Any]: + def _before_run( + self, /, check_readiness: bool, *args, **kwargs + ) -> tuple[bool, Any]: """ Things to do _before_ running. @@ -194,6 +196,7 @@ def _run( run_exception_kwargs: dict, run_finally_kwargs: dict, finish_run_kwargs: dict, + *args, **kwargs, ) -> Any | tuple | Future: """ @@ -254,7 +257,7 @@ def _run( ) return self.future - def _run_exception(self, /, **kwargs): + def _run_exception(self, /, *args, **kwargs): """ What to do if an exception is encountered inside :meth:`_run` or :meth:`_finish_run. @@ -262,7 +265,7 @@ def _run_exception(self, /, **kwargs): self.running = False self.failed = True - def _run_finally(self, /, **kwargs): + def _run_finally(self, /, *args, **kwargs): """ What to do after :meth:`_finish_run` (whether an exception is encountered or not), or in :meth:`_run` after an exception is encountered. diff --git a/pyiron_workflow/mixin/single_output.py b/pyiron_workflow/mixin/single_output.py index 948a36bb..1e6dacfc 100644 --- a/pyiron_workflow/mixin/single_output.py +++ b/pyiron_workflow/mixin/single_output.py @@ -7,7 +7,8 @@ from abc import ABC, abstractmethod -from pyiron_workflow.mixin.has_interface_mixins import HasGenericChannel, HasLabel +from pyiron_workflow.io import HasIO +from pyiron_workflow.mixin.has_interface_mixins import HasGenericChannel from pyiron_workflow.mixin.injection import ( OutputDataWithInjection, OutputsWithInjection, @@ -18,7 +19,9 @@ class AmbiguousOutputError(ValueError): """Raised when searching for exactly one output, but multiple are found.""" -class ExploitsSingleOutput(HasLabel, HasGenericChannel[OutputDataWithInjection], ABC): +class ExploitsSingleOutput( + HasIO[OutputsWithInjection], HasGenericChannel[OutputDataWithInjection], ABC +): @property @abstractmethod def outputs(self) -> OutputsWithInjection: diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index ead473f7..27c033e0 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -19,7 +19,6 @@ from pyiron_workflow.draw import Node as GraphvizNode from pyiron_workflow.logging import logger -from pyiron_workflow.mixin.injection import HasIOWithInjection from pyiron_workflow.mixin.run import ReadinessError, Runnable from pyiron_workflow.mixin.semantics import Semantic from pyiron_workflow.mixin.single_output import ExploitsSingleOutput @@ -40,7 +39,6 @@ class Node( - HasIOWithInjection, Semantic["Composite"], Runnable, ExploitsSingleOutput, @@ -179,8 +177,9 @@ class Node( inputs (pyiron_workflow.io.Inputs): **Abstract.** Children must define a property returning an :class:`Inputs` object. label (str): A name for the node. - outputs (pyiron_workflow.io.Outputs): **Abstract.** Children must define - a property returning an :class:`Outputs` object. + outputs (pyiron_workflow.mixin.injection.OutputsWithInjection): **Abstract.** + Children must define a property returning an :class:`OutputsWithInjection` + object. parent (pyiron_workflow.composite.Composite | None): The parent object owning this, if any. ready (bool): Whether the inputs are all ready and the node is neither @@ -305,7 +304,9 @@ def __init__( self._do_clean: bool = False # Power-user override for cleaning up temporary # serialized results and empty directories (or not). self._cached_inputs = None - self._user_data = {} # A place for power-users to bypass node-injection + + self._user_data: dict[str, Any] = {} + # A place for power-users to bypass node-injection self._setup_node() self._after_node_setup( @@ -630,7 +631,7 @@ def run_data_tree(self, run_parent_trees_too=False) -> None: try: parent_starting_nodes = ( - self.parent.starting_nodes if self.parent is not None else None + self.parent.starting_nodes if self.parent is not None else [] ) # We need these for state recovery later, even if we crash if len(data_tree_starters) == 1 and data_tree_starters[0] is self: diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index 527bd5de..d7a3fe53 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -13,8 +13,9 @@ from pyiron_snippets.factory import classfactory -from pyiron_workflow.io import Inputs, Outputs +from pyiron_workflow.io import Inputs from pyiron_workflow.mixin.has_interface_mixins import HasChannel +from pyiron_workflow.mixin.injection import OutputsWithInjection from pyiron_workflow.mixin.preview import ScrapesIO from pyiron_workflow.nodes.composite import Composite from pyiron_workflow.nodes.multiple_distpatch import dispatch_output_labels @@ -342,7 +343,7 @@ def inputs(self) -> Inputs: return self._inputs @property - def outputs(self) -> Outputs: + def outputs(self) -> OutputsWithInjection: return self._outputs def _parse_remotely_executed_self(self, other_self): diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 8852b426..3ae3218b 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -40,10 +40,6 @@ class FromManyInputs(Transformer, ABC): # Inputs convert to `run_args` as a value dictionary # This must be commensurate with the internal expectations of _on_run - @abstractmethod - def _on_run(self, **inputs_to_value_dict) -> Any: - """Must take inputs kwargs""" - @property def _run_args(self) -> tuple[tuple, dict]: return (), self.inputs.to_value_dict() diff --git a/pyiron_workflow/storage.py b/pyiron_workflow/storage.py index 679f8151..dcdb626c 100644 --- a/pyiron_workflow/storage.py +++ b/pyiron_workflow/storage.py @@ -36,7 +36,7 @@ class StorageInterface(ABC): """ @abstractmethod - def _save(self, node: Node, filename: Path, /, **kwargs): + def _save(self, node: Node, filename: Path, /, *args, **kwargs): """ Save a node to file. @@ -48,7 +48,7 @@ def _save(self, node: Node, filename: Path, /, **kwargs): """ @abstractmethod - def _load(self, filename: Path, /, **kwargs) -> Node: + def _load(self, filename: Path, /, *args, **kwargs) -> Node: """ Instantiate a node from file. @@ -61,7 +61,7 @@ def _load(self, filename: Path, /, **kwargs) -> Node: """ @abstractmethod - def _has_saved_content(self, filename: Path, /, **kwargs) -> bool: + def _has_saved_content(self, filename: Path, /, *args, **kwargs) -> bool: """ Check for a save file matching this storage interface. @@ -74,7 +74,7 @@ def _has_saved_content(self, filename: Path, /, **kwargs) -> bool: """ @abstractmethod - def _delete(self, filename: Path, /, **kwargs): + def _delete(self, filename: Path, /, *args, **kwargs): """ Remove an existing save-file for this backend. @@ -132,7 +132,7 @@ def has_saved_content( node: Node | None = None, filename: str | Path | None = None, **kwargs, - ): + ) -> bool: """ Check if a file has contents related to a node. @@ -168,7 +168,9 @@ def delete( if filename.parent.exists() and not any(filename.parent.iterdir()): filename.parent.rmdir() - def _parse_filename(self, node: Node | None, filename: str | Path | None = None): + def _parse_filename( + self, node: Node | None, filename: str | Path | None = None + ) -> Path: """ Make sure the node xor filename was provided, and if it's the node, convert it into a canonical filename by exploiting the node's semantic path. @@ -195,6 +197,11 @@ def _parse_filename(self, node: Node | None, filename: str | Path | None = None) f"Both the node ({node.full_label}) and filename ({filename}) were " f"specified for loading -- please only specify one or the other." ) + else: + raise AssertionError( + "This is an unreachable state -- we have covered all four cases of the " + "boolean `is (not) None` square." + ) class PickleStorage(StorageInterface): @@ -204,11 +211,11 @@ class PickleStorage(StorageInterface): def __init__(self, cloudpickle_fallback: bool = True): self.cloudpickle_fallback = cloudpickle_fallback - def _fallback(self, cpf: bool | None): + def _fallback(self, cpf: bool | None) -> bool: return self.cloudpickle_fallback if cpf is None else cpf def _save( - self, node: Node, filename: Path, cloudpickle_fallback: bool | None = None + self, node: Node, filename: Path, /, cloudpickle_fallback: bool | None = None ): if not self._fallback(cloudpickle_fallback) and not node.import_ready: raise TypeNotFoundError( @@ -236,19 +243,22 @@ def _save( if e is not None: raise e - def _load(self, filename: Path, cloudpickle_fallback: bool | None = None) -> Node: + def _load( + self, filename: Path, /, cloudpickle_fallback: bool | None = None + ) -> Node: attacks = [(self._PICKLE, pickle.load)] if self._fallback(cloudpickle_fallback): attacks += [(self._CLOUDPICKLE, cloudpickle.load)] for suffix, load_method in attacks: p = filename.with_suffix(suffix) - if p.exists(): + if p.is_file(): with open(p, "rb") as filehandle: inst = load_method(filehandle) return inst + raise FileNotFoundError(f"Could not load {filename}, no such file found.") - def _delete(self, filename: Path, cloudpickle_fallback: bool | None = None): + def _delete(self, filename: Path, /, cloudpickle_fallback: bool | None = None): suffixes = ( [self._PICKLE, self._CLOUDPICKLE] if self._fallback(cloudpickle_fallback) @@ -258,7 +268,7 @@ def _delete(self, filename: Path, cloudpickle_fallback: bool | None = None): filename.with_suffix(suffix).unlink(missing_ok=True) def _has_saved_content( - self, filename: Path, cloudpickle_fallback: bool | None = None + self, filename: Path, /, cloudpickle_fallback: bool | None = None ) -> bool: suffixes = ( [self._PICKLE, self._CLOUDPICKLE] diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 8a5ddb29..8b0a707b 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -10,7 +10,8 @@ from bidict import bidict -from pyiron_workflow.io import Inputs, Outputs +from pyiron_workflow.io import Inputs +from pyiron_workflow.mixin.injection import OutputsWithInjection from pyiron_workflow.nodes.composite import Composite if TYPE_CHECKING: @@ -25,6 +26,12 @@ class ParentMostError(TypeError): """ +class NoArgsError(TypeError): + """ + To be raised when *args can't be processed but are received + """ + + class Workflow(Composite): """ Workflows are a dynamic composite node -- i.e. they hold and run a collection of @@ -224,7 +231,7 @@ def __init__( self.outputs_map = outputs_map self._inputs = None self._outputs = None - self.automate_execution = automate_execution + self.automate_execution: bool = automate_execution super().__init__( *nodes, @@ -294,7 +301,7 @@ def _build_inputs(self): return self._build_io("inputs", self.inputs_map) @property - def outputs(self) -> Outputs: + def outputs(self) -> OutputsWithInjection: return self._build_outputs() def _build_outputs(self): @@ -304,7 +311,7 @@ def _build_io( self, i_or_o: Literal["inputs", "outputs"], key_map: dict[str, str | None] | None, - ) -> Inputs | Outputs: + ) -> Inputs | OutputsWithInjection: """ Build an IO panel for exposing child node IO to the outside world at the level of the composite node's IO. @@ -320,10 +327,10 @@ def _build_io( (which normally would be exposed) by providing a string-None map. Returns: - (Inputs|Outputs): The populated panel. + (Inputs|OutputsWithInjection): The populated panel. """ key_map = {} if key_map is None else key_map - io = Inputs() if i_or_o == "inputs" else Outputs() + io = Inputs() if i_or_o == "inputs" else OutputsWithInjection() for node in self.children.values(): panel = getattr(node, i_or_o) for channel in panel: @@ -360,12 +367,18 @@ def _before_run( def run( self, + *args, check_readiness: bool = True, **kwargs, ): # Note: Workflows may have neither parents nor siblings, so we don't need to # worry about running their data trees first, fetching their input, nor firing # their `ran` signal, hence the change in signature from Node.run + if len(args) > 0: + raise NoArgsError( + f"{self.__class__} does not know how to process *args on run, but " + f"received {args}" + ) return super().run( run_data_tree=False, diff --git a/tests/unit/test_io.py b/tests/unit/test_io.py index 00586444..3f9cb193 100644 --- a/tests/unit/test_io.py +++ b/tests/unit/test_io.py @@ -17,7 +17,7 @@ ) -class Dummy(HasIO): +class Dummy(HasIO[Outputs]): def __init__(self, label: str | None = "has_io"): super().__init__() self._label = label diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index f19032b7..bb7fd5c0 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -10,7 +10,7 @@ from pyiron_workflow._tests import ensure_tests_in_python_path from pyiron_workflow.channels import NOT_DATA from pyiron_workflow.storage import TypeNotFoundError, available_backends -from pyiron_workflow.workflow import ParentMostError, Workflow +from pyiron_workflow.workflow import NoArgsError, ParentMostError, Workflow ensure_tests_in_python_path() @@ -258,6 +258,12 @@ def sum_(a, b): return a + b wf.sum = sum_(wf.a, wf.b) + with self.assertRaises( + NoArgsError, + msg="Workflows don't know what to do with raw args, since their input " + "has no intrinsic order", + ): + wf.run(1, 2) wf.run() self.assertEqual( wf.a.outputs.y.value + wf.b.outputs.y.value, From 2b9f55022b691bd3b76ce9935c5e0629ea7ae9e5 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 10:32:51 -0800 Subject: [PATCH 22/43] `mypy` draw (#555) * Write 3-tuple explicitly Signed-off-by: liamhuber * Always hint channels having parents Unlike nodes, the IO and Channel objects it the draw module always wind up having a parent Signed-off-by: liamhuber * Make draw._Channel generic On the underlying workflow channel type, so the data channel can later access its value. Signed-off-by: liamhuber * Ruff fix imports Signed-off-by: liamhuber * Stringify TYPE_CHECKING class uses Oops. Signed-off-by: liamhuber * Silence ruff The imports are indeed used, but only in string form for the sake of the static type checker. Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/draw.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/pyiron_workflow/draw.py b/pyiron_workflow/draw.py index 80e930c2..d6e47931 100644 --- a/pyiron_workflow/draw.py +++ b/pyiron_workflow/draw.py @@ -5,7 +5,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Generic, Literal, TypeVar import graphviz from pyiron_snippets.colors import SeabornColors @@ -13,7 +13,13 @@ from pyiron_workflow.channels import NotData if TYPE_CHECKING: - from pyiron_workflow.channels import Channel as WorkflowChannel + from pyiron_workflow.channels import Channel as WorkflowChannel # noqa: F401 + from pyiron_workflow.channels import ( + DataChannel as WorkflowDataChannel, # noqa: F401 + ) + from pyiron_workflow.channels import ( + SignalChannel as WorkflowSignalChannel, # noqa: F401 + ) from pyiron_workflow.io import DataIO, SignalIO from pyiron_workflow.node import Node as WorkflowNode @@ -67,7 +73,11 @@ def _to_hex(rgb: tuple[int, int, int]) -> str: def _to_rgb(hex_: str) -> tuple[int, int, int]: """Hex to RGB color codes; no alpha values.""" hex_ = hex_.lstrip("#") - return tuple(int(hex_[i : i + 2], 16) for i in (0, 2, 4)) + return ( + int(hex_[0 : 0 + 2], 16), + int(hex_[2 : 2 + 2], 16), + int(hex_[4 : 4 + 2], 16), + ) # mypy isn't smart enough to parse this as a 3-tuple from an iterator def blend_colours(color_a, color_b, fraction_a=0.5): @@ -117,14 +127,17 @@ def color(self) -> str: pass -class _Channel(WorkflowGraphvizMap, ABC): +WorkflowChannelType = TypeVar("WorkflowChannelType", bound="WorkflowChannel") + + +class _Channel(WorkflowGraphvizMap, Generic[WorkflowChannelType], ABC): """ An abstract representation for channel objects, which are "nodes" in graphviz parlance. """ - def __init__(self, parent: _IO, channel: WorkflowChannel, local_name: str): - self.channel = channel + def __init__(self, parent: _IO, channel: WorkflowChannelType, local_name: str): + self.channel: WorkflowChannelType = channel self._parent = parent self._name = self.parent.name + local_name self._label = local_name + self._build_label_suffix() @@ -153,7 +166,7 @@ def _build_label_suffix(self): return suffix @property - def parent(self) -> _IO | None: + def parent(self) -> _IO: return self._parent @property @@ -173,7 +186,7 @@ def style(self) -> str: return "filled" -class DataChannel(_Channel): +class DataChannel(_Channel["WorkflowDataChannel"]): @property def color(self) -> str: orange = "#EDB22C" @@ -190,7 +203,7 @@ def style(self) -> str: return "filled" -class SignalChannel(_Channel): +class SignalChannel(_Channel["WorkflowSignalChannel"]): @property def color(self) -> str: blue = "#21BFD8" From eeeb65d134264b3f772f74524800aaac5f218e4d Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 10:55:33 -0800 Subject: [PATCH 23/43] Generic value receiver (#556) * :bug: Re-parent abstract method It belongs with the owners of channels, and these are HasIO, not the IO panels (which are merely dumb containers that give shortcuts to certain functionality) Signed-off-by: liamhuber * Make DataChannel generic on value_receiver Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 11 ++++++----- pyiron_workflow/io.py | 10 +++------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 153932fd..310e4df1 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -38,6 +38,7 @@ class ChannelConnectionError(ChannelError): InputType = typing.TypeVar("InputType", bound="InputChannel") OutputType = typing.TypeVar("OutputType", bound="OutputChannel") FlavorType = typing.TypeVar("FlavorType", bound="FlavorChannel") +ReceiverType = typing.TypeVar("ReceiverType", bound="DataChannel") class Channel( @@ -277,7 +278,7 @@ def __bool__(self): NOT_DATA = NotData() -class DataChannel(FlavorChannel["DataChannel"], ABC): +class DataChannel(FlavorChannel["DataChannel"], typing.Generic[ReceiverType], ABC): """ Data channels control the flow of data on the graph. @@ -362,7 +363,7 @@ def __init__( default: typing.Any | None = NOT_DATA, type_hint: typing.Any | None = None, strict_hints: bool = True, - value_receiver: Self | None = None, + value_receiver: ReceiverType | None = None, ): super().__init__(label=label, owner=owner) self._value = NOT_DATA @@ -371,7 +372,7 @@ def __init__( self.strict_hints = strict_hints self.default = default self.value = default # Implicitly type check your default by assignment - self.value_receiver: Self = value_receiver + self.value_receiver: ReceiverType = value_receiver @property def value(self): @@ -526,7 +527,7 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class InputData(DataChannel, InputChannel["OutputData"]): +class InputData(DataChannel["InputData"], InputChannel["OutputData"]): @classmethod def connection_conjugate(cls) -> type[OutputData]: @@ -566,7 +567,7 @@ def value(self, new_value): self._value = new_value -class OutputData(DataChannel, OutputChannel["InputData"]): +class OutputData(DataChannel["OutputData"], OutputChannel["InputData"]): @classmethod def connection_conjugate(cls) -> type[InputData]: return InputData diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py index e5890884..ecef56ce 100644 --- a/pyiron_workflow/io.py +++ b/pyiron_workflow/io.py @@ -334,14 +334,10 @@ def __init__(self, *args, **kwargs) -> None: def inputs(self) -> Inputs: pass + @property @abstractmethod - def data_input_locked(self) -> bool: - """ - Indicates whether data input channels should consider this owner locked to - change. - """ - # Practically, this gives a well-named interface between HasIO and everything - # to do with run status + def data_input_locked(self): + """Prevents Inputs from updating when True""" @property @abstractmethod From c2269b885ea412801348cd8b24a53301456812c2 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 10:56:49 -0800 Subject: [PATCH 24/43] Hint attribute Signed-off-by: liamhuber --- pyiron_workflow/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 27c033e0..165977ad 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -303,7 +303,7 @@ def __init__( # under-development status -- API may change to be more user-friendly self._do_clean: bool = False # Power-user override for cleaning up temporary # serialized results and empty directories (or not). - self._cached_inputs = None + self._cached_inputs: dict[str, Any] | None = None self._user_data: dict[str, Any] = {} # A place for power-users to bypass node-injection From 1e2ab80c71c361ae82f97a7d1f1421d584b73ae1 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 10:58:21 -0800 Subject: [PATCH 25/43] Don't reuse variable Signed-off-by: liamhuber --- pyiron_workflow/node.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 165977ad..86a3091f 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -830,10 +830,8 @@ def draw( (graphviz.graphs.Digraph): The resulting graph object. """ - - if size is not None: - size = f"{size[0]},{size[1]}" - graph = GraphvizNode(self, depth=depth, rankdir=rankdir, size=size).graph + size_str = f"{size[0]},{size[1]}" if size is not None else None + graph = GraphvizNode(self, depth=depth, rankdir=rankdir, size=size_str).graph if save or view or filename is not None: directory = self.as_path() if directory is None else Path(directory) filename = self.label + "_graph" if filename is None else filename From 309433b1dd78c806e56a1880f1ff7d40006eebc7 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 11:01:16 -0800 Subject: [PATCH 26/43] Only hint available backends Not just any old string Signed-off-by: liamhuber --- pyiron_workflow/node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 86a3091f..139e3ff5 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -894,7 +894,7 @@ def replace_with(self, other: Node | type[Node]): def save( self, - backend: str | StorageInterface = "pickle", + backend: Literal["pickle"] | StorageInterface = "pickle", filename: str | Path | None = None, **kwargs, ): @@ -916,7 +916,7 @@ def save( save.__doc__ += _save_load_warnings - def save_checkpoint(self, backend: str | StorageInterface = "pickle"): + def save_checkpoint(self, backend: Literal["pickle"] | StorageInterface = "pickle"): """ Triggers a save on the parent-most node. @@ -928,7 +928,7 @@ def save_checkpoint(self, backend: str | StorageInterface = "pickle"): def load( self, - backend: str | StorageInterface = "pickle", + backend: Literal["pickle"] | StorageInterface = "pickle", only_requested=False, filename: str | Path | None = None, **kwargs, From 9c46954c14cd2db5cd9d36f22d8afddd166cd56d Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 11:05:42 -0800 Subject: [PATCH 27/43] Cast docstrings to string Signed-off-by: liamhuber --- pyiron_workflow/node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 139e3ff5..881142b2 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -11,7 +11,7 @@ from abc import ABC, abstractmethod from concurrent.futures import Future from importlib import import_module -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import cloudpickle from pyiron_snippets.colors import SeabornColors @@ -914,7 +914,7 @@ def save( ): selected_backend.save(node=self, filename=filename, **kwargs) - save.__doc__ += _save_load_warnings + save.__doc__ = cast(str, save.__doc__) + _save_load_warnings def save_checkpoint(self, backend: Literal["pickle"] | StorageInterface = "pickle"): """ @@ -973,7 +973,7 @@ def load( ) self.__setstate__(inst.__getstate__()) - load.__doc__ += _save_load_warnings + load.__doc__ = cast(str, load.__doc__) + _save_load_warnings def delete_storage( self, From 70295ec6076c68e774f333f30c6c3386fb855d26 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 11:11:46 -0800 Subject: [PATCH 28/43] Hint NotData class Instead of NOT_DATA instance Signed-off-by: liamhuber --- pyiron_workflow/nodes/transform.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 3ae3218b..3aece6c2 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -15,7 +15,7 @@ from pyiron_snippets.colors import SeabornColors from pyiron_snippets.factory import classfactory -from pyiron_workflow.channels import NOT_DATA +from pyiron_workflow.channels import NotData, NOT_DATA from pyiron_workflow.mixin.preview import builds_class_io from pyiron_workflow.nodes.static_io import StaticNode @@ -56,7 +56,7 @@ def process_run_result(self, run_output: Any | tuple) -> Any | tuple: class ToManyOutputs(Transformer, ABC): _input_name: ClassVar[str] # Mandatory attribute for non-abstract subclasses _input_type_hint: ClassVar[Any] = None - _input_default: ClassVar[Any | NOT_DATA] = NOT_DATA + _input_default: ClassVar[Any | NotData] = NOT_DATA # _build_outputs_preview still required from parent class # Must be commensurate with the dictionary returned by transform_to_output @@ -179,14 +179,14 @@ class InputsToDict(FromManyInputs, ABC): _output_name: ClassVar[str] = "dict" _output_type_hint: ClassVar[Any] = dict _input_specification: ClassVar[ - list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]] + list[str] | dict[str, tuple[Any | None, Any | NotData]] ] def _on_run(self, **inputs_to_value_dict): return inputs_to_value_dict @classmethod - def _build_inputs_preview(cls) -> dict[str, tuple[Any | None, Any | NOT_DATA]]: + def _build_inputs_preview(cls) -> dict[str, tuple[Any | None, Any | NotData]]: if isinstance(cls._input_specification, list): return {key: (None, NOT_DATA) for key in cls._input_specification} else: @@ -194,7 +194,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any | None, Any | NOT_DATA]]: @staticmethod def hash_specification( - input_specification: list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]], + input_specification: list[str] | dict[str, tuple[Any | None, Any | NotData]], ): """For generating unique subclass names.""" @@ -220,7 +220,7 @@ def hash_specification( @classfactory def inputs_to_dict_factory( - input_specification: list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]], + input_specification: list[str] | dict[str, tuple[Any | None, Any | NotData]], class_name_suffix: str | None, use_cache: bool = True, /, @@ -241,7 +241,7 @@ def inputs_to_dict_factory( def inputs_to_dict( - input_specification: list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]], + input_specification: list[str] | dict[str, tuple[Any | None, Any | NotData]], *node_args, class_name_suffix: str | None = None, use_cache: bool = True, From 12458e30e52506d5c2f90a171d21e52cb255c82e Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 11:21:07 -0800 Subject: [PATCH 29/43] Just use classmethod It doesn't combine with property Signed-off-by: liamhuber --- pyiron_workflow/nodes/transform.py | 7 +++---- tests/unit/nodes/test_transform.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 3aece6c2..0ea35ac6 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -347,7 +347,6 @@ class DataclassNode(FromManyInputs, ABC): _output_name: ClassVar[str] = "dataclass" @classmethod - @property def _dataclass_fields(cls): return cls.dataclass.__dataclass_fields__ @@ -357,9 +356,9 @@ def _setup_node(self) -> None: for name, channel in self.inputs.items(): if ( channel.value is NOT_DATA - and self._dataclass_fields[name].default_factory is not MISSING + and self._dataclass_fields()[name].default_factory is not MISSING ): - self.inputs[name] = self._dataclass_fields[name].default_factory() + self.inputs[name] = self._dataclass_fields()[name].default_factory() def _on_run(self, **inputs_to_value_dict): return self.dataclass(**inputs_to_value_dict) @@ -373,7 +372,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: # Make a channel for each field return { name: (f.type, NOT_DATA if f.default is MISSING else f.default) - for name, f in cls._dataclass_fields.items() + for name, f in cls._dataclass_fields().items() } @classmethod diff --git a/tests/unit/nodes/test_transform.py b/tests/unit/nodes/test_transform.py index 569fc799..e9bef9d6 100644 --- a/tests/unit/nodes/test_transform.py +++ b/tests/unit/nodes/test_transform.py @@ -230,7 +230,7 @@ class DecoratedDCLike: prev = n_cls.preview_inputs() key = random.choice(list(prev.keys())) self.assertIs( - n_cls._dataclass_fields[key].type, + n_cls._dataclass_fields()[key].type, prev[key][0], msg="Spot-check input type hints are pulled from dataclass fields", ) @@ -238,7 +238,7 @@ class DecoratedDCLike: prev["necessary"][1], NOT_DATA, msg="Field has no default" ) self.assertEqual( - n_cls._dataclass_fields["with_default"].default, + n_cls._dataclass_fields()["with_default"].default, prev["with_default"][1], msg="Fields with default should get scraped", ) From de767fac71a7e2e05ed5b9be5f676bbb1d9cf05f Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 11:21:38 -0800 Subject: [PATCH 30/43] Ruff fix imports Signed-off-by: liamhuber --- pyiron_workflow/nodes/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 0ea35ac6..ed38302c 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -15,7 +15,7 @@ from pyiron_snippets.colors import SeabornColors from pyiron_snippets.factory import classfactory -from pyiron_workflow.channels import NotData, NOT_DATA +from pyiron_workflow.channels import NOT_DATA, NotData from pyiron_workflow.mixin.preview import builds_class_io from pyiron_workflow.nodes.static_io import StaticNode From ba29c55f903f102afa296bda183913f5e49e4e7b Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 12:19:28 -0800 Subject: [PATCH 31/43] Remove preview helper method (#557) * Remove preview helper method The potential usage is limited to transform and for-loop modules, and the extra layer of misdirection does not feel worth the very minimal reduction in code duplication Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/mixin/preview.py | 18 +----------------- pyiron_workflow/nodes/transform.py | 12 +++++++----- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/pyiron_workflow/mixin/preview.py b/pyiron_workflow/mixin/preview.py index ec74f3e7..7ae1a1c1 100644 --- a/pyiron_workflow/mixin/preview.py +++ b/pyiron_workflow/mixin/preview.py @@ -15,7 +15,7 @@ import inspect from abc import ABC, abstractmethod from collections.abc import Callable -from functools import lru_cache, wraps +from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -82,22 +82,6 @@ def preview_io(cls) -> DotDict[str, dict]: ) -def builds_class_io(subclass_factory: Callable[..., type[HasIOPreview]]): - """ - A decorator for factories producing subclasses of `HasIOPreview` to invoke - :meth:`preview_io` after the class is created, thus ensuring the IO has been - constructed at the class level. - """ - - @wraps(subclass_factory) - def wrapped(*args, **kwargs): - node_class = subclass_factory(*args, **kwargs) - node_class.preview_io() - return node_class - - return wrapped - - class ScrapesIO(HasIOPreview, ABC): """ A mixin class for scraping IO channel information from a specific class method's diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index ed38302c..9d0780d2 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -16,7 +16,6 @@ from pyiron_snippets.factory import classfactory from pyiron_workflow.channels import NOT_DATA, NotData -from pyiron_workflow.mixin.preview import builds_class_io from pyiron_workflow.nodes.static_io import StaticNode @@ -107,7 +106,6 @@ def _build_outputs_preview(cls) -> dict[str, Any]: return {f"item_{i}": None for i in range(cls._length)} -@builds_class_io @classfactory def inputs_to_list_factory(n: int, use_cache: bool = True, /) -> type[InputsToList]: return ( @@ -137,10 +135,11 @@ def inputs_to_list(n: int, /, *node_args, use_cache: bool = True, **node_kwargs) InputsToList: An instance of the dynamically created :class:`InputsToList` subclass. """ - return inputs_to_list_factory(n, use_cache)(*node_args, **node_kwargs) + cls = inputs_to_list_factory(n, use_cache) + cls.preview_io() + return cls(*node_args, **node_kwargs) -@builds_class_io @classfactory def list_to_outputs_factory(n: int, use_cache: bool = True, /) -> type[ListToOutputs]: return ( @@ -172,7 +171,10 @@ def list_to_outputs( ListToOutputs: An instance of the dynamically created :class:`ListToOutputs` subclass. """ - return list_to_outputs_factory(n, use_cache)(*node_args, **node_kwargs) + + cls = list_to_outputs_factory(n, use_cache) + cls.preview_io() + return cls(*node_args, **node_kwargs) class InputsToDict(FromManyInputs, ABC): From 4f05f6e1a4108b48810ac5ff21de777324d40114 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 12:21:28 -0800 Subject: [PATCH 32/43] Relax Node.emitting_channels hint To allow a tuple of variable length Signed-off-by: liamhuber --- pyiron_workflow/node.py | 2 +- pyiron_workflow/nodes/standard.py | 2 +- tests/integration/test_workflow.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 881142b2..07a2d7d3 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -707,7 +707,7 @@ def _outputs_to_run_return(self): return DotDict(self.outputs.to_value_dict()) @property - def emitting_channels(self) -> tuple[OutputSignal]: + def emitting_channels(self) -> tuple[OutputSignal, ...]: if self.failed: return (self.signals.output.failed,) else: diff --git a/pyiron_workflow/nodes/standard.py b/pyiron_workflow/nodes/standard.py index e9b4c683..6635a4c9 100644 --- a/pyiron_workflow/nodes/standard.py +++ b/pyiron_workflow/nodes/standard.py @@ -63,7 +63,7 @@ def node_function(condition): return truth @property - def emitting_channels(self) -> tuple[OutputSignal]: + def emitting_channels(self) -> tuple[OutputSignal, ...]: if self.outputs.truth.value is NOT_DATA: return super().emitting_channels elif self.outputs.truth.value: diff --git a/tests/integration/test_workflow.py b/tests/integration/test_workflow.py index f1686be1..09425401 100644 --- a/tests/integration/test_workflow.py +++ b/tests/integration/test_workflow.py @@ -68,7 +68,7 @@ def node_function(value, limit=10): return value_gt_limit @property - def emitting_channels(self) -> tuple[OutputSignal]: + def emitting_channels(self) -> tuple[OutputSignal, ...]: if self.outputs.value_gt_limit.value: print(f"{self.inputs.value.value} > {self.inputs.limit.value}") return (*super().emitting_channels, self.signals.output.true) From f6ea024dc6f3e08bd410656da386202d3e73d244 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 12:24:46 -0800 Subject: [PATCH 33/43] Account for __doc__ possibly being None Signed-off-by: liamhuber --- pyiron_workflow/nodes/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 9d0780d2..aed09074 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -379,7 +379,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: @classmethod def _extra_info(cls) -> str: - return cls.dataclass.__doc__ + return cls.dataclass.__doc__ or "" @classfactory From 2240e986020f00bcaa28018f7ababeaff36c1ed3 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 12:29:21 -0800 Subject: [PATCH 34/43] Be more specific in return hint Signed-off-by: liamhuber --- pyiron_workflow/nodes/composite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 1689da92..eabaa8e9 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -277,7 +277,7 @@ def _get_state_from_remote_other(self, other_self): state.pop("_parent") # Got overridden to None for __getstate__, so keep local return state - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: """ Disconnect all `signals.input.run` connections on all child nodes. From 44d68b2b474eb1527432e5cc0f7b04d98f6d1c50 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 12:51:10 -0800 Subject: [PATCH 35/43] Change return on Composite.remove_child (#558) * Change return on Composite.remove_child To match return in parent class. Disconnections were only ever used in the test case, and users are always free to disconnect and _then_ remove if they want to capture the broken connections explicitly. Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Fix docstring types Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/nodes/composite.py | 11 +++++------ tests/unit/nodes/test_composite.py | 19 ++++--------------- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index eabaa8e9..a8c29890 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from pyiron_workflow.channels import ( - Channel, InputSignal, OutputSignal, ) @@ -282,7 +281,7 @@ def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: Disconnect all `signals.input.run` connections on all child nodes. Returns: - list[tuple[Channel, Channel]]: Any disconnected pairs. + list[tuple[InputSignal, OutputSignal]]: Any disconnected pairs. """ disconnected_pairs = [] for node in self.children.values(): @@ -307,7 +306,7 @@ def add_child( self._cached_inputs = None # Reset cache after graph change return super().add_child(child, label=label, strict_naming=strict_naming) - def remove_child(self, child: Node | str) -> list[tuple[Channel, Channel]]: + def remove_child(self, child: Node | str) -> Node: """ Remove a child from the :attr:`children` collection, disconnecting it and setting its :attr:`parent` to None. @@ -316,14 +315,14 @@ def remove_child(self, child: Node | str) -> list[tuple[Channel, Channel]]: child (Node|str): The child (or its label) to remove. Returns: - (list[tuple[Channel, Channel]]): Any connections that node had. + (Node): The (now disconnected and de-parented) (former) child node. """ child = super().remove_child(child) - disconnected = child.disconnect() + child.disconnect() if child in self.starting_nodes: self.starting_nodes.remove(child) self._cached_inputs = None # Reset cache after graph change - return disconnected + return child def replace_child( self, owned_node: Node | str, replacement: Node | type[Node] diff --git a/tests/unit/nodes/test_composite.py b/tests/unit/nodes/test_composite.py index af4ed8ee..c4487924 100644 --- a/tests/unit/nodes/test_composite.py +++ b/tests/unit/nodes/test_composite.py @@ -133,29 +133,18 @@ def test_node_removal(self): # Connect it inside the composite self.comp.foo.inputs.x = self.comp.owned.outputs.y - disconnected = self.comp.remove_child(node) + self.comp.remove_child(node) self.assertIsNone(node.parent, msg="Removal should de-parent") self.assertFalse(node.connected, msg="Removal should disconnect") - self.assertListEqual( - [(node.inputs.x, self.comp.owned.outputs.y)], - disconnected, - msg="Removal should return destroyed connections", - ) self.assertListEqual( self.comp.starting_nodes, [], msg="Removal should also remove from starting nodes", ) - - node_owned = self.comp.owned - disconnections = self.comp.remove_child(node_owned.label) - self.assertEqual( - node_owned.parent, - None, - msg="Should be able to remove nodes by label as well as by object", - ) self.assertListEqual( - [], disconnections, msg="node1 should have no connections left" + [], + self.comp.owned.connections, + msg="Remaining node should have no connections left", ) def test_label_uniqueness(self): From b09d0efeffd78f03b5f0c4404833a92b011c88c5 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 13:07:28 -0800 Subject: [PATCH 36/43] IO maps (#559) * Refactor Workflow map setter Signed-off-by: liamhuber * Remove return hint The method is in-place Signed-off-by: liamhuber * Move the None check around Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/workflow.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 8b0a707b..24c2032b 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -225,10 +225,8 @@ def __init__( automate_execution: bool = True, **kwargs, ): - self._inputs_map = None - self._outputs_map = None - self.inputs_map = inputs_map - self.outputs_map = outputs_map + self._inputs_map = self._sanitize_map(inputs_map) + self._outputs_map = self._sanitize_map(outputs_map) self._inputs = None self._outputs = None self.automate_execution: bool = automate_execution @@ -264,34 +262,36 @@ def _after_node_setup( @property def inputs_map(self) -> bidict | None: - self._deduplicate_nones(self._inputs_map) + if self._inputs_map is not None: + self._deduplicate_nones(self._inputs_map) return self._inputs_map @inputs_map.setter def inputs_map(self, new_map: dict | bidict | None): - self._deduplicate_nones(new_map) - if new_map is not None: - new_map = bidict(new_map) - self._inputs_map = new_map + self._inputs_map = self._sanitize_map(new_map) @property def outputs_map(self) -> bidict | None: - self._deduplicate_nones(self._outputs_map) + if self._outputs_map is not None: + self._deduplicate_nones(self._outputs_map) return self._outputs_map @outputs_map.setter def outputs_map(self, new_map: dict | bidict | None): - self._deduplicate_nones(new_map) + self._outputs_map = self._sanitize_map(new_map) + + def _sanitize_map(self, new_map: dict | bidict | None) -> bidict | None: if new_map is not None: + if isinstance(new_map, dict): + self._deduplicate_nones(new_map) new_map = bidict(new_map) - self._outputs_map = new_map + return new_map @staticmethod - def _deduplicate_nones(some_map: dict | bidict | None) -> dict | bidict | None: - if some_map is not None: - for k, v in some_map.items(): - if v is None: - some_map[k] = (None, f"{k} disabled") + def _deduplicate_nones(some_map: dict | bidict): + for k, v in some_map.items(): + if v is None: + some_map[k] = (None, f"{k} disabled") @property def inputs(self) -> Inputs: From e61a9db04a84eaab7e3175d4b5c49912433a8207 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 13:10:00 -0800 Subject: [PATCH 37/43] Reverse instance check Signed-off-by: liamhuber --- pyiron_workflow/workflow.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 24c2032b..3094683c 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -336,8 +336,9 @@ def _build_io( for channel in panel: try: io_panel_key = key_map[channel.scoped_label] - if not isinstance(io_panel_key, tuple): - # Tuples indicate that the channel has been deactivated + if isinstance(io_panel_key, str): + # Otherwise it's a None-str tuple, indicaticating that the + # channel has been deactivated # This is a necessary misdirection to keep the bidict working, # as we can't simply map _multiple_ keys to `None` io[io_panel_key] = channel From 89078427b39b152f023571903d06b8930a4230d1 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 13:23:26 -0800 Subject: [PATCH 38/43] Return both nodes on replacement (#560) * Return both nodes on replacement Instead of only returning the replaced node. Signed-off-by: liamhuber * :bug: only use new variable Missed a spot. Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/nodes/composite.py | 34 ++++++++++++++++++++---------- pyiron_workflow/nodes/macro.py | 2 +- pyiron_workflow/workflow.py | 10 +++++---- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index a8c29890..29a0975d 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -326,7 +326,7 @@ def remove_child(self, child: Node | str) -> Node: def replace_child( self, owned_node: Node | str, replacement: Node | type[Node] - ) -> Node: + ) -> tuple[Node, Node]: """ Replaces a node currently owned with a new node instance. The replacement must not belong to any other parent or have any connections. @@ -348,7 +348,7 @@ def replace_child( and simply gets instantiated.) Returns: - (Node): The node that got removed + (Node, Node): The node that got removed and the new one that replaced it. """ if isinstance(owned_node, str): owned_node = self.children[owned_node] @@ -367,15 +367,18 @@ def replace_child( ) if replacement.connected: raise ValueError("Replacement node must not have any connections") + replacement_node = replacement elif issubclass(replacement, Node): - replacement = replacement(label=owned_node.label) + replacement_node = replacement(label=owned_node.label) else: raise TypeError( f"Expected replacement node to be a node instance or node subclass, but " f"got {replacement}" ) - replacement.copy_io(owned_node) # If the replacement is incompatible, we'll + replacement_node.copy_io( + owned_node + ) # If the replacement is incompatible, we'll # fail here before we've changed the parent at all. Since the replacement was # first guaranteed to be an unconnected orphan, there is not yet any permanent # damage @@ -383,28 +386,37 @@ def replace_child( # In case the replaced node interfaces with the composite's IO, catch value # links inbound_links = [ - (sending_channel, replacement.inputs[sending_channel.value_receiver.label]) + ( + sending_channel, + replacement_node.inputs[sending_channel.value_receiver.label], + ) for sending_channel in self.inputs if sending_channel.value_receiver in owned_node.inputs ] outbound_links = [ - (replacement.outputs[sending_channel.label], sending_channel.value_receiver) + ( + replacement_node.outputs[sending_channel.label], + sending_channel.value_receiver, + ) for sending_channel in owned_node.outputs if sending_channel.value_receiver in self.outputs ] self.remove_child(owned_node) - replacement.label, owned_node.label = owned_node.label, replacement.label - self.add_child(replacement) + replacement_node.label, owned_node.label = ( + owned_node.label, + replacement_node.label, + ) + self.add_child(replacement_node) if is_starting_node: - self.starting_nodes.append(replacement) + self.starting_nodes.append(replacement_node) for sending_channel, receiving_channel in inbound_links + outbound_links: sending_channel.value_receiver = receiving_channel # Clear caches self._cached_inputs = None - replacement._cached_inputs = None + replacement_node._cached_inputs = None - return owned_node + return owned_node, replacement_node def executor_shutdown(self, wait=True, *, cancel_futures=False): """ diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index d7a3fe53..ced1e533 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -219,7 +219,7 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC): >>> # With the replace method >>> # (replacement target can be specified by label or instance, >>> # the replacing node can be specified by instance or class) - >>> replaced = adds_six_macro.replace_child(adds_six_macro.one, add_two()) + >>> replaced, _ = adds_six_macro.replace_child(adds_six_macro.one, add_two()) >>> # With the replace_with method >>> adds_six_macro.two.replace_with(add_two()) >>> # And by assignment of a compatible class to an occupied node label diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 3094683c..aa39caff 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -497,8 +497,10 @@ def _owned_io_panels(self) -> list[IO]: def replace_child( self, owned_node: Node | str, replacement: Node | type[Node] - ) -> Node: - super().replace_child(owned_node=owned_node, replacement=replacement) + ) -> tuple[Node, Node]: + replaced, replacement_node = super().replace_child( + owned_node=owned_node, replacement=replacement + ) # Finally, make sure the IO is constructible with this new node, which will # catch things like incompatible IO maps @@ -509,11 +511,11 @@ def replace_child( except Exception as e: # If IO can't be successfully rebuilt using this node, revert changes and # raise the exception - self.replace_child(replacement, owned_node) # Guaranteed to work since + self.replace_child(replacement_node, replaced) # Guaranteed to work since # replacement in the other direction was already a success raise e - return owned_node + return replaced, replacement_node @property def parent(self) -> None: From fcb341ae3ea004572728061b81a51a91f58807b3 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 13:52:54 -0800 Subject: [PATCH 39/43] Hint graph creator Incompletely for the creator, but I'm having trouble getting mypy happy with hinting args and kwargs here. Signed-off-by: liamhuber --- pyiron_workflow/nodes/macro.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index ced1e533..fde7ab66 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -273,7 +273,9 @@ def _setup_node(self) -> None: @staticmethod @abstractmethod - def graph_creator(self, *args, **kwargs) -> Callable: + def graph_creator( + self: Macro, *args, **kwargs + ) -> HasChannel | tuple[HasChannel, ...] | None: """Build the graph the node will run.""" @classmethod @@ -538,7 +540,7 @@ def decorator(graph_creator): def macro_node( - graph_creator: Callable, + graph_creator: Callable[..., HasChannel | tuple[HasChannel, ...] | None], *node_args, output_labels: str | tuple[str, ...] | None = None, validate_output_labels: bool = True, From e3d1ada5efb6af07b01ae280c81387c4dc383f09 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 13:58:37 -0800 Subject: [PATCH 40/43] Don't reuse variable Locally mypy doesn't care about this, but somehow on the CI it whines, even though the mypy version is allegedly the same. Signed-off-by: liamhuber --- pyiron_workflow/mixin/semantics.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index 1ecfd8fb..5b4358e1 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -368,18 +368,19 @@ def _add_suffix_to_label(self, label: str) -> str: def remove_child(self, child: ChildType | str) -> ChildType: if isinstance(child, str): - child = self.children.pop(child) + child_instance = self.children.pop(child) elif isinstance(child, self.child_type()): self.children.inv.pop(child) + child_instance = child else: raise TypeError( f"{self.label} expected to remove a child of type str or " f"{self.child_type()} but got {child}" ) - child.parent = None + child_instance.parent = None - return child + return child_instance def __getstate__(self): state = super().__getstate__() From 1b3baaeadf476479b1815bc3aa05dfe6fca49601 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 17 Jan 2025 14:03:05 -0800 Subject: [PATCH 41/43] Don't reuse variable here either Signed-off-by: liamhuber --- pyiron_workflow/nodes/composite.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 29a0975d..38f988aa 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -350,13 +350,14 @@ def replace_child( Returns: (Node, Node): The node that got removed and the new one that replaced it. """ - if isinstance(owned_node, str): - owned_node = self.children[owned_node] + owned_node_instance = ( + self.children[owned_node] if isinstance(owned_node, str) else owned_node + ) - if owned_node.parent is not self: + if owned_node_instance.parent is not self: raise ValueError( f"The node being replaced should be a child of this composite, but " - f"another parent was found: {owned_node.parent}" + f"another parent was found: {owned_node_instance.parent}" ) if isinstance(replacement, Node): @@ -369,7 +370,7 @@ def replace_child( raise ValueError("Replacement node must not have any connections") replacement_node = replacement elif issubclass(replacement, Node): - replacement_node = replacement(label=owned_node.label) + replacement_node = replacement(label=owned_node_instance.label) else: raise TypeError( f"Expected replacement node to be a node instance or node subclass, but " @@ -377,12 +378,12 @@ def replace_child( ) replacement_node.copy_io( - owned_node + owned_node_instance ) # If the replacement is incompatible, we'll # fail here before we've changed the parent at all. Since the replacement was # first guaranteed to be an unconnected orphan, there is not yet any permanent # damage - is_starting_node = owned_node in self.starting_nodes + is_starting_node = owned_node_instance in self.starting_nodes # In case the replaced node interfaces with the composite's IO, catch value # links inbound_links = [ @@ -391,19 +392,19 @@ def replace_child( replacement_node.inputs[sending_channel.value_receiver.label], ) for sending_channel in self.inputs - if sending_channel.value_receiver in owned_node.inputs + if sending_channel.value_receiver in owned_node_instance.inputs ] outbound_links = [ ( replacement_node.outputs[sending_channel.label], sending_channel.value_receiver, ) - for sending_channel in owned_node.outputs + for sending_channel in owned_node_instance.outputs if sending_channel.value_receiver in self.outputs ] - self.remove_child(owned_node) - replacement_node.label, owned_node.label = ( - owned_node.label, + self.remove_child(owned_node_instance) + replacement_node.label, owned_node_instance.label = ( + owned_node_instance.label, replacement_node.label, ) self.add_child(replacement_node) @@ -416,7 +417,7 @@ def replace_child( self._cached_inputs = None replacement_node._cached_inputs = None - return owned_node, replacement_node + return owned_node_instance, replacement_node def executor_shutdown(self, wait=True, *, cancel_futures=False): """ From 3d1143e98bd1ef82b15c8f5829a368424cb325af Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 16:26:35 -0800 Subject: [PATCH 42/43] `mypy` for_loop (#561) * Make class "property" a plain method Signed-off-by: liamhuber * Refactor to non-None classvar Signed-off-by: liamhuber * Explicitly cast to tuple For the sake of the name generator Signed-off-by: liamhuber * Silence mypy It is upset about the hinting `list[hint]`, because `hint` is variable. We would be able to verify that it is a type at static analysis time, but since it comes from the body node class -- which is unknown until runtime -- it is impossible to say _which_ type. Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/nodes/for_loop.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/pyiron_workflow/nodes/for_loop.py b/pyiron_workflow/nodes/for_loop.py index af23faba..6d1d4150 100644 --- a/pyiron_workflow/nodes/for_loop.py +++ b/pyiron_workflow/nodes/for_loop.py @@ -154,6 +154,7 @@ class For(Composite, StaticNode, ABC): _iter_on: ClassVar[tuple[str, ...]] = () _zip_on: ClassVar[tuple[str, ...]] = () _output_as_dataframe: ClassVar[bool] = True + _output_column_map: ClassVar[dict[str, str]] = {} def __init_subclass__(cls, output_column_map=None, **kwargs): super().__init_subclass__(**kwargs) @@ -182,18 +183,16 @@ def __init_subclass__(cls, output_column_map=None, **kwargs): f"{cls._body_node_class.__name__} has no such outputs." ) - cls._output_column_map = output_column_map + cls._output_column_map = {} if output_column_map is None else output_column_map @classmethod - @property @lru_cache(maxsize=1) def output_column_map(cls) -> dict[str, str]: """ How to transform body node output labels to dataframe column names. """ map_ = {k: k for k in cls._body_node_class.preview_outputs()} - overrides = {} if cls._output_column_map is None else cls._output_column_map - for body_label, column_name in overrides.items(): + for body_label, column_name in cls._output_column_map.items(): map_[body_label] = column_name return map_ @@ -311,7 +310,7 @@ def _collect_output_as_dataframe(self, iter_maps): row_collector.inputs[label] = self.children[label][i] for label, body_out in self[self._body_name(n)].outputs.items(): - row_collector.inputs[self.output_column_map[label]] = body_out + row_collector.inputs[self.output_column_map()[label]] = body_out self.dataframe.inputs[f"row_{n}"] = row_collector @@ -324,7 +323,7 @@ def _build_row_collector_node(self, row_number) -> InputsToDict: # Outputs row_specification.update( { - self.output_column_map[key]: (hint, NOT_DATA) + self.output_column_map()[key]: (hint, NOT_DATA) for key, hint in self._body_node_class.preview_outputs().items() } ) @@ -339,7 +338,7 @@ def column_collector_name(s: str): return f"column_collector_{s}" for label, hint in self._body_node_class.preview_outputs().items(): - mapped_label = self.output_column_map[label] + mapped_label = self.output_column_map()[label] column_collector = inputs_to_list( n_rows, label=column_collector_name(mapped_label), @@ -377,7 +376,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: for label, (hint, default) in cls._body_node_class.preview_inputs().items(): # TODO: Leverage hint and default, listing if it's looped on if label in cls._zip_on + cls._iter_on: - hint = list if hint is None else list[hint] + hint = list if hint is None else list[hint] # type: ignore[valid-type] default = NOT_DATA # TODO: Figure out a generator pattern to get lists preview[label] = (hint, default) return preview @@ -393,11 +392,11 @@ def _build_outputs_preview(cls) -> dict[str, Any]: _default, ) in cls._body_node_class.preview_inputs().items(): if label in cls._zip_on + cls._iter_on: - hint = list if hint is None else list[hint] + hint = list if hint is None else list[hint] # type: ignore[valid-type] preview[label] = hint for label, hint in cls._body_node_class.preview_outputs().items(): - preview[cls.output_column_map[label]] = ( - list if hint is None else list[hint] + preview[cls.output_column_map()[label]] = ( + list if hint is None else list[hint] # type: ignore[valid-type] ) return preview @@ -648,6 +647,8 @@ def for_node( Index(['a', 'b', 'c', 'd', 'out_a', 'out_b', 'out_c', 'out_d', 'e'], dtype='object') """ + iter_on = (iter_on,) if isinstance(iter_on, str) else iter_on + zip_on = (zip_on,) if isinstance(zip_on, str) else zip_on for_node_factory.clear( _for_node_class_name(body_node_class, iter_on, zip_on, output_as_dataframe) ) From e36409b0e0180077a2642bfbba716a1cf75b765b Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 17 Jan 2025 17:14:32 -0800 Subject: [PATCH 43/43] `mypy` finishing touches (#562) * Make class "property" a plain method Signed-off-by: liamhuber * Refactor to non-None classvar Signed-off-by: liamhuber * Explicitly cast to tuple For the sake of the name generator Signed-off-by: liamhuber * Silence mypy It is upset about the hinting `list[hint]`, because `hint` is variable. We would be able to verify that it is a type at static analysis time, but since it comes from the body node class -- which is unknown until runtime -- it is impossible to say _which_ type. Signed-off-by: liamhuber * Uniformly give and ignore classfactory hints At a minimum, getting mypy to parse these correctly requires more rigorous hinting in pyiron_snippets.factory. But actually, since the classfactory allows the parent class to be specified with _multiple bases_, I'm not even 100% sure we'd ever be able to get a single type variable that could do the trick universally. In any case, for now kick the can don't the road and always hint what you know is true, then tell mypy to not worry about it. Signed-off-by: liamhuber * Add some hints to preview Albeit pretty relaxed ones Signed-off-by: liamhuber * Add a return hint to Runnable.__init__ To get mypy to parse the body of the function Signed-off-by: liamhuber * Break loop into a method Mypy didn't like parsing the zip variable when it could be inputs or outputs (even though both inherit from the relevant DataIO in this case), but using a separate method is functionally equivalent and mypy can get a better grasp of the type values. Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/io.py | 55 +++++++++++++++++------------- pyiron_workflow/mixin/preview.py | 6 ++-- pyiron_workflow/mixin/run.py | 2 +- pyiron_workflow/nodes/for_loop.py | 4 +-- pyiron_workflow/nodes/function.py | 4 +-- pyiron_workflow/nodes/macro.py | 4 +-- pyiron_workflow/nodes/transform.py | 10 +++--- 7 files changed, 47 insertions(+), 38 deletions(-) diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py index ecef56ce..4d983996 100644 --- a/pyiron_workflow/io.py +++ b/pyiron_workflow/io.py @@ -572,30 +572,39 @@ def _copy_values( list[tuple[Channel, Any]]: A list of tuples giving channels whose value has been updated and what it used to be (for reverting changes). """ + # Leverage a separate function because mypy has trouble parsing types + # if we loop over inputs and outputs at the same time + return self._copy_panel( + other, self.inputs, other.inputs, fail_hard=fail_hard + ) + self._copy_panel(other, self.outputs, other.outputs, fail_hard=fail_hard) + + def _copy_panel( + self, + other: HasIO, + my_panel: DataIO, + other_panel: DataIO, + fail_hard: bool = False, + ) -> list[tuple[DataChannel, Any]]: old_values = [] - for my_panel, other_panel in [ - (self.inputs, other.inputs), - (self.outputs, other.outputs), - ]: - for key, to_copy in other_panel.items(): - if to_copy.value is not NOT_DATA: - try: - old_value = my_panel[key].value - my_panel[key].value = to_copy.value # Gets hint-checked - old_values.append((my_panel[key], old_value)) - except Exception as e: - if fail_hard: - # If you run into trouble, unwind what you've done - for channel, value in old_values: - channel.value = value - raise ValueCopyError( - f"{self.label} could not copy values from " - f"{other.label} due to the channel {key} on " - f"{other_panel.__class__.__name__}, which holds value " - f"{to_copy.value}" - ) from e - else: - continue + for key, to_copy in other_panel.items(): + if to_copy.value is not NOT_DATA: + try: + old_value = my_panel[key].value + my_panel[key].value = to_copy.value # Gets hint-checked + old_values.append((my_panel[key], old_value)) + except Exception as e: + if fail_hard: + # If you run into trouble, unwind what you've done + for channel, value in old_values: + channel.value = value + raise ValueCopyError( + f"{self.label} could not copy values from " + f"{other.label} due to the channel {key} on " + f"{other_panel.__class__.__name__}, which holds value " + f"{to_copy.value}" + ) from e + else: + continue return old_values @property diff --git a/pyiron_workflow/mixin/preview.py b/pyiron_workflow/mixin/preview.py index 7ae1a1c1..97695bdb 100644 --- a/pyiron_workflow/mixin/preview.py +++ b/pyiron_workflow/mixin/preview.py @@ -76,7 +76,7 @@ def preview_outputs(cls) -> dict[str, Any]: return cls._build_outputs_preview() @classmethod - def preview_io(cls) -> DotDict[str, dict]: + def preview_io(cls) -> DotDict[str, dict[str, Any | tuple[Any, Any]]]: return DotDict( {"inputs": cls.preview_inputs(), "outputs": cls.preview_outputs()} ) @@ -124,7 +124,7 @@ def _io_defining_function(cls) -> Callable: ) @classmethod - def _build_inputs_preview(cls): + def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: type_hints = cls._get_type_hints() scraped: dict[str, tuple[Any, Any]] = {} for i, (label, value) in enumerate(cls._get_input_args().items()): @@ -152,7 +152,7 @@ def _build_inputs_preview(cls): return scraped @classmethod - def _build_outputs_preview(cls): + def _build_outputs_preview(cls) -> dict[str, Any]: if cls._validate_output_labels: cls._validate() # Validate output on first call diff --git a/pyiron_workflow/mixin/run.py b/pyiron_workflow/mixin/run.py index ac7aae86..aa40bd67 100644 --- a/pyiron_workflow/mixin/run.py +++ b/pyiron_workflow/mixin/run.py @@ -50,7 +50,7 @@ class Runnable(UsesState, HasLabel, HasRun, ABC): new keyword arguments. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.running: bool = False self.failed: bool = False diff --git a/pyiron_workflow/nodes/for_loop.py b/pyiron_workflow/nodes/for_loop.py index 6d1d4150..bc408821 100644 --- a/pyiron_workflow/nodes/for_loop.py +++ b/pyiron_workflow/nodes/for_loop.py @@ -509,7 +509,7 @@ def for_node_factory( output_column_map: dict | None = None, use_cache: bool = True, /, -): +) -> type[For]: combined_docstring = ( "For node docstring:\n" + (For.__doc__ if For.__doc__ is not None else "") @@ -520,7 +520,7 @@ def for_node_factory( iter_on = (iter_on,) if isinstance(iter_on, str) else iter_on zip_on = (zip_on,) if isinstance(zip_on, str) else zip_on - return ( + return ( # type: ignore[return-value] _for_node_class_name(body_node_class, iter_on, zip_on, output_as_dataframe), (For,), { diff --git a/pyiron_workflow/nodes/function.py b/pyiron_workflow/nodes/function.py index 8000a6d9..877f3b28 100644 --- a/pyiron_workflow/nodes/function.py +++ b/pyiron_workflow/nodes/function.py @@ -357,7 +357,7 @@ def function_node_factory( use_cache: bool = True, /, *output_labels, -): +) -> type[Function]: """ Create a new :class:`Function` node class based on the given node function. This function gets executed on each :meth:`run` of the resulting function. @@ -373,7 +373,7 @@ def function_node_factory( Returns: type[Node]: A new node class. """ - return ( + return ( # type: ignore[return-value] node_function.__name__, (Function,), # Define parentage { diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index fde7ab66..ec09f8e8 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -475,7 +475,7 @@ def macro_node_factory( use_cache: bool = True, /, *output_labels: str, -): +) -> type[Macro]: """ Create a new :class:`Macro` subclass using the given graph creator function. @@ -491,7 +491,7 @@ def macro_node_factory( Returns: type[Macro]: A new :class:`Macro` subclass. """ - return ( + return ( # type: ignore[return-value] graph_creator.__name__, (Macro,), # Define parentage { diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index aed09074..6a4371be 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -108,7 +108,7 @@ def _build_outputs_preview(cls) -> dict[str, Any]: @classfactory def inputs_to_list_factory(n: int, use_cache: bool = True, /) -> type[InputsToList]: - return ( + return ( # type: ignore[return-value] f"{InputsToList.__name__}{n}", (InputsToList,), { @@ -142,7 +142,7 @@ def inputs_to_list(n: int, /, *node_args, use_cache: bool = True, **node_kwargs) @classfactory def list_to_outputs_factory(n: int, use_cache: bool = True, /) -> type[ListToOutputs]: - return ( + return ( # type: ignore[return-value] f"{ListToOutputs.__name__}{n}", (ListToOutputs,), { @@ -231,7 +231,7 @@ def inputs_to_dict_factory( class_name_suffix = str( InputsToDict.hash_specification(input_specification) ).replace("-", "m") - return ( + return ( # type: ignore[return-value] f"{InputsToDict.__name__}{class_name_suffix}", (InputsToDict,), { @@ -307,7 +307,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: def inputs_to_dataframe_factory( n: int, use_cache: bool = True, / ) -> type[InputsToDataframe]: - return ( + return ( # type: ignore[return-value] f"{InputsToDataframe.__name__}{n}", (InputsToDataframe,), { @@ -403,7 +403,7 @@ def dataclass_node_factory( # Composition is preferable over inheritance, but we want inheritance to be possible module, qualname = dataclass.__module__, dataclass.__qualname__ dataclass.__qualname__ += ".dataclass" # So output type hints know where to find it - return ( + return ( # type: ignore[return-value] dataclass.__name__, (DataclassNode,), {