diff --git a/.github/workflows/push-pull.yml b/.github/workflows/push-pull.yml index 9178f1da..6418e835 100644 --- a/.github/workflows/push-pull.yml +++ b/.github/workflows/push-pull.yml @@ -19,3 +19,34 @@ jobs: alternate-tests-env-files: .ci_support/lower_bound.yml alternate-tests-python-version: '3.10' alternate-tests-dir: tests/unit + + mypy: + runs-on: ubuntu-latest + steps: + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + architecture: x64 + - name: Checkout + uses: actions/checkout@v4 + - name: Install mypy + run: pip install mypy + - name: Test + run: mypy --ignore-missing-imports ${{ github.event.repository.name }} + + ruff-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + with: + args: check + + ruff-sort-imports: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + with: + args: check --select I --fix --diff \ No newline at end of file diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml deleted file mode 100644 index 68c0ec0d..00000000 --- a/.github/workflows/ruff.yml +++ /dev/null @@ -1,17 +0,0 @@ -name: Ruff -on: [ push, pull_request ] -jobs: - ruff-check: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: check - ruff-sort-imports: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: check --select I --fix --diff diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 64084eb6..310e4df1 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -14,6 +14,7 @@ from pyiron_snippets.singleton import Singleton +from pyiron_workflow.compatibility import Self from pyiron_workflow.mixin.display_state import HasStateDisplay from pyiron_workflow.mixin.has_interface_mixins import HasChannel, HasLabel from pyiron_workflow.type_hinting import ( @@ -25,11 +26,24 @@ from pyiron_workflow.io import HasIO -class ChannelConnectionError(Exception): +class ChannelError(Exception): pass -class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): +class ChannelConnectionError(ChannelError): + pass + + +ConjugateType = typing.TypeVar("ConjugateType", bound="Channel") +InputType = typing.TypeVar("InputType", bound="InputChannel") +OutputType = typing.TypeVar("OutputType", bound="OutputChannel") +FlavorType = typing.TypeVar("FlavorType", bound="FlavorChannel") +ReceiverType = typing.TypeVar("ReceiverType", bound="DataChannel") + + +class Channel( + HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConjugateType], ABC +): """ Channels facilitate the flow of information (data or control signals) into and out of :class:`HasIO` objects (namely nodes). @@ -37,12 +51,9 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): They must have an identifier (`label: str`) and belong to an `owner: pyiron_workflow.io.HasIO`. - Non-abstract channel classes should come in input/output pairs and specify the - a necessary ancestor for instances they can connect to - (`connection_partner_type: type[Channel]`). Channels may form (:meth:`connect`/:meth:`disconnect`) and store - (:attr:`connections: list[Channel]`) connections with other channels. + (:attr:`connections`) connections with other channels. This connection information is reflexive, and is duplicated to be stored on _both_ channels in the form of a reference to their counterpart in the connection. @@ -51,10 +62,10 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): these (dis)connections is guaranteed to be handled, and new connections are subjected to a validity test. - In this abstract class the only requirement is that the connecting channels form a - "conjugate pair" of classes, i.e. they are children of each other's partner class - (:attr:`connection_partner_type: type[Channel]`) -- input/output connects to - output/input. + Child classes must specify a conjugate class in order to enforce connection + conjugate pairs which have the same "flavor" (e.g. "data" or "signal"), and + opposite "direction" ("input" vs "output"). And they must define a string + representation. Iterating over channels yields their connections. @@ -80,22 +91,18 @@ def __init__( """ self._label = label self.owner: HasIO = owner - self.connections: list[Channel] = [] - - @property - def label(self) -> str: - return self._label + self.connections: list[ConjugateType] = [] @abstractmethod def __str__(self): pass - @property + @classmethod @abstractmethod - def connection_partner_type(self) -> type[Channel]: + def connection_conjugate(cls) -> type[ConjugateType]: """ - Input and output class pairs must specify a parent class for their valid - connection partners. + The class forming a conjugate pair with this channel class -- i.e. the same + "flavor" of channel, but opposite in I/O. """ @property @@ -108,21 +115,12 @@ def full_label(self) -> str: """A label combining the channel's usual label and its owner's semantic path""" return f"{self.owner.full_label}.{self.label}" - def _valid_connection(self, other: Channel) -> bool: - """ - Logic for determining if a connection is valid. - - Connections only allowed to instances with the right parent type -- i.e. - connection pairs should be an input/output. - """ - return isinstance(other, self.connection_partner_type) - - def connect(self, *others: Channel) -> None: + def connect(self, *others: ConjugateType) -> None: """ Form a connection between this and one or more other channels. Connections are reflexive, and should only occur between input and output channels, i.e. they are instances of each others - :attr:`connection_partner_type`. + :meth:`connection_conjugate()`. New connections get _prepended_ to the connection lists, so they appear first when searching over connections. @@ -139,30 +137,40 @@ def connect(self, *others: Channel) -> None: for other in others: if other in self.connections: continue - elif self._valid_connection(other): - # Prepend new connections - # so that connection searches run newest to oldest - self.connections.insert(0, other) - other.connections.insert(0, self) - else: - if isinstance(other, self.connection_partner_type): + elif isinstance(other, self.connection_conjugate()): + if self._valid_connection(other): + # Prepend new connections + # so that connection searches run newest to oldest + self.connections.insert(0, other) + other.connections.insert(0, self) + else: raise ChannelConnectionError( - f"The channel {other.full_label} ({other.__class__.__name__}" - f") has the correct type " - f"({self.connection_partner_type.__name__}) to connect with " - f"{self.full_label} ({self.__class__.__name__}), but is not " - f"a valid connection. Please check type hints, etc." - f"{other.full_label}.type_hint = {other.type_hint}; " - f"{self.full_label}.type_hint = {self.type_hint}" + self._connection_conjugate_failure_message(other) ) from None - else: - raise TypeError( - f"Can only connect to {self.connection_partner_type.__name__} " - f"objects, but {self.full_label} ({self.__class__.__name__}) " - f"got {other} ({type(other)})" - ) + else: + raise TypeError( + f"Can only connect to {self.connection_conjugate()} " + f"objects, but {self.full_label} ({self.__class__}) " + f"got {other} ({type(other)})" + ) + + def _valid_connection(self, other: ConjugateType) -> bool: + """ + Logic for determining if a connection to a conjugate partner is valid. - def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: + Override in child classes as necessary. + """ + return True + + def _connection_conjugate_failure_message(self, other: ConjugateType) -> str: + return ( + f"The channel {other.full_label} ({other.__class__}) has the " + f"correct type ({self.connection_conjugate()}) to connect with " + f"{self.full_label} ({self.__class__}), but is not a valid " + f"connection." + ) + + def disconnect(self, *others: ConjugateType) -> list[tuple[Self, ConjugateType]]: """ If currently connected to any others, removes this and the other from eachothers respective connections lists. @@ -182,7 +190,7 @@ def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all(self) -> list[tuple[Channel, Channel]]: + def disconnect_all(self) -> list[tuple[Self, ConjugateType]]: """ Disconnect from all other channels currently in the connections list. """ @@ -199,10 +207,10 @@ def __iter__(self): return self.connections.__iter__() @property - def channel(self) -> Channel: + def channel(self) -> Self: return self - def copy_connections(self, other: Channel) -> None: + def copy_connections(self, other: Self) -> None: """ Adds all the connections in another channel to this channel's connections. @@ -235,6 +243,18 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) +class FlavorChannel(Channel[FlavorType], ABC): + """Abstract base for all flavor-specific channels.""" + + +class InputChannel(Channel[OutputType], ABC): + """Mixin for input channels.""" + + +class OutputChannel(Channel[InputType], ABC): + """Mixin for output channels.""" + + class NotData(metaclass=Singleton): """ This class exists purely to initialize data channel values where no default value @@ -258,7 +278,7 @@ def __bool__(self): NOT_DATA = NotData() -class DataChannel(Channel, ABC): +class DataChannel(FlavorChannel["DataChannel"], typing.Generic[ReceiverType], ABC): """ Data channels control the flow of data on the graph. @@ -331,7 +351,7 @@ class DataChannel(Channel, ABC): when this channel is a value receiver. This can potentially be expensive, so consider deactivating strict hints everywhere for production runs. (Default is True, raise exceptions when type hints get violated.) - value_receiver (pyiron_workflow.channel.DataChannel|None): Another channel of + value_receiver (pyiron_workflow.compatibility.Self|None): Another channel of the same class whose value will always get updated when this channel's value gets updated. """ @@ -343,7 +363,7 @@ def __init__( default: typing.Any | None = NOT_DATA, type_hint: typing.Any | None = None, strict_hints: bool = True, - value_receiver: InputData | None = None, + value_receiver: ReceiverType | None = None, ): super().__init__(label=label, owner=owner) self._value = NOT_DATA @@ -352,7 +372,7 @@ def __init__( self.strict_hints = strict_hints self.default = default self.value = default # Implicitly type check your default by assignment - self.value_receiver = value_receiver + self.value_receiver: ReceiverType = value_receiver @property def value(self): @@ -379,7 +399,7 @@ def _type_check_new_value(self, new_value): ) @property - def value_receiver(self) -> InputData | OutputData | None: + def value_receiver(self) -> Self | None: """ Another data channel of the same type to whom new values are always pushed (without type checking of any sort, not even when forming the couple!) @@ -390,7 +410,7 @@ def value_receiver(self) -> InputData | OutputData | None: return self._value_receiver @value_receiver.setter - def value_receiver(self, new_partner: InputData | OutputData | None): + def value_receiver(self, new_partner: Self | None): if new_partner is not None: if not isinstance(new_partner, self.__class__): raise TypeError( @@ -446,26 +466,44 @@ def _has_hint(self) -> bool: return self.type_hint is not None def _valid_connection(self, other: DataChannel) -> bool: - if super()._valid_connection(other): - if self._both_typed(other): - out, inp = self._figure_out_who_is_who(other) - if not inp.strict_hints: - return True - else: - return type_hint_is_as_or_more_specific_than( - out.type_hint, inp.type_hint - ) - else: - # If either is untyped, don't do type checking + if self._both_typed(other): + out, inp = self._figure_out_who_is_who(other) + if not inp.strict_hints: return True + else: + return type_hint_is_as_or_more_specific_than( + out.type_hint, inp.type_hint + ) else: - return False + # If either is untyped, don't do type checking + return True + + def _connection_conjugate_failure_message(self, other: DataChannel) -> str: + msg = super()._connection_conjugate_failure_message(other) + msg += ( + f"Please check type hints, etc. {other.full_label}.type_hint = " + f"{other.type_hint}; {self.full_label}.type_hint = {self.type_hint}" + ) + return msg def _both_typed(self, other: DataChannel) -> bool: return self._has_hint and other._has_hint - def _figure_out_who_is_who(self, other: DataChannel) -> (OutputData, InputData): - return (self, other) if isinstance(self, OutputData) else (other, self) + def _figure_out_who_is_who( + self, other: DataChannel + ) -> tuple[OutputData, InputData]: + if isinstance(self, InputData) and isinstance(other, OutputData): + return other, self + elif isinstance(self, OutputData) and isinstance(other, InputData): + return self, other + else: + raise ChannelError( + f"This should be unreachable; data channel conjugate pairs should " + f"always be input/output, but got {type(self)} for {self.full_label} " + f"and {type(other)} for {other.full_label}. If you don't believe you " + f"are responsible for this error, please contact the maintainers via " + f"GitHub." + ) def __str__(self): return str(self.value) @@ -489,9 +527,10 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class InputData(DataChannel): - @property - def connection_partner_type(self): +class InputData(DataChannel["InputData"], InputChannel["OutputData"]): + + @classmethod + def connection_conjugate(cls) -> type[OutputData]: return OutputData def fetch(self) -> None: @@ -528,13 +567,16 @@ def value(self, new_value): self._value = new_value -class OutputData(DataChannel): - @property - def connection_partner_type(self): +class OutputData(DataChannel["OutputData"], OutputChannel["InputData"]): + @classmethod + def connection_conjugate(cls) -> type[InputData]: return InputData -class SignalChannel(Channel, ABC): +SignalType = typing.TypeVar("SignalType", bound="SignalChannel") + + +class SignalChannel(FlavorChannel[SignalType], ABC): """ Signal channels give the option control execution flow by triggering callback functions when the channel is called. @@ -558,16 +600,13 @@ class BadCallbackError(ValueError): pass -class InputSignal(SignalChannel): - @property - def connection_partner_type(self): - return OutputSignal +class InputSignal(SignalChannel["OutputSignal"], InputChannel["OutputSignal"]): def __init__( self, label: str, owner: HasIO, - callback: callable, + callback: typing.Callable, ): """ Make a new input signal channel. @@ -589,6 +628,10 @@ def __init__( f"all args are optional: {self._all_args_arg_optional(callback)} " ) + @classmethod + def connection_conjugate(cls) -> type[OutputSignal]: + return OutputSignal + def _is_method_on_owner(self, callback): try: return callback == getattr(self.owner, callback.__name__) @@ -614,7 +657,7 @@ def _has_required_args(func): ) @property - def callback(self) -> callable: + def callback(self) -> typing.Callable: return getattr(self.owner, self._callback) def __call__(self, other: OutputSignal | None = None) -> None: @@ -637,19 +680,20 @@ def __init__( self, label: str, owner: HasIO, - callback: callable, + callback: typing.Callable, ): super().__init__(label=label, owner=owner, callback=callback) self.received_signals: set[str] = set() - def __call__(self, other: OutputSignal) -> None: + def __call__(self, other: OutputSignal | None = None) -> None: """ Fire callback iff you have received at least one signal from each of your current connections. Resets the collection of received signals when firing. """ - self.received_signals.update([other.scoped_label]) + if isinstance(other, OutputSignal): + self.received_signals.update([other.scoped_label]) if ( len( {c.scoped_label for c in self.connections}.difference( @@ -673,9 +717,10 @@ def __lshift__(self, others): other._connect_accumulating_input_signal(self) -class OutputSignal(SignalChannel): - @property - def connection_partner_type(self): +class OutputSignal(SignalChannel["InputSignal"], OutputChannel["InputSignal"]): + + @classmethod + def connection_conjugate(cls) -> type[InputSignal]: return InputSignal def __call__(self) -> None: diff --git a/pyiron_workflow/compatibility.py b/pyiron_workflow/compatibility.py new file mode 100644 index 00000000..28b7f773 --- /dev/null +++ b/pyiron_workflow/compatibility.py @@ -0,0 +1,6 @@ +from sys import version_info + +if version_info.minor < 11: + from typing_extensions import Self as Self +else: + from typing import Self as Self diff --git a/pyiron_workflow/draw.py b/pyiron_workflow/draw.py index 80e930c2..d6e47931 100644 --- a/pyiron_workflow/draw.py +++ b/pyiron_workflow/draw.py @@ -5,7 +5,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Generic, Literal, TypeVar import graphviz from pyiron_snippets.colors import SeabornColors @@ -13,7 +13,13 @@ from pyiron_workflow.channels import NotData if TYPE_CHECKING: - from pyiron_workflow.channels import Channel as WorkflowChannel + from pyiron_workflow.channels import Channel as WorkflowChannel # noqa: F401 + from pyiron_workflow.channels import ( + DataChannel as WorkflowDataChannel, # noqa: F401 + ) + from pyiron_workflow.channels import ( + SignalChannel as WorkflowSignalChannel, # noqa: F401 + ) from pyiron_workflow.io import DataIO, SignalIO from pyiron_workflow.node import Node as WorkflowNode @@ -67,7 +73,11 @@ def _to_hex(rgb: tuple[int, int, int]) -> str: def _to_rgb(hex_: str) -> tuple[int, int, int]: """Hex to RGB color codes; no alpha values.""" hex_ = hex_.lstrip("#") - return tuple(int(hex_[i : i + 2], 16) for i in (0, 2, 4)) + return ( + int(hex_[0 : 0 + 2], 16), + int(hex_[2 : 2 + 2], 16), + int(hex_[4 : 4 + 2], 16), + ) # mypy isn't smart enough to parse this as a 3-tuple from an iterator def blend_colours(color_a, color_b, fraction_a=0.5): @@ -117,14 +127,17 @@ def color(self) -> str: pass -class _Channel(WorkflowGraphvizMap, ABC): +WorkflowChannelType = TypeVar("WorkflowChannelType", bound="WorkflowChannel") + + +class _Channel(WorkflowGraphvizMap, Generic[WorkflowChannelType], ABC): """ An abstract representation for channel objects, which are "nodes" in graphviz parlance. """ - def __init__(self, parent: _IO, channel: WorkflowChannel, local_name: str): - self.channel = channel + def __init__(self, parent: _IO, channel: WorkflowChannelType, local_name: str): + self.channel: WorkflowChannelType = channel self._parent = parent self._name = self.parent.name + local_name self._label = local_name + self._build_label_suffix() @@ -153,7 +166,7 @@ def _build_label_suffix(self): return suffix @property - def parent(self) -> _IO | None: + def parent(self) -> _IO: return self._parent @property @@ -173,7 +186,7 @@ def style(self) -> str: return "filled" -class DataChannel(_Channel): +class DataChannel(_Channel["WorkflowDataChannel"]): @property def color(self) -> str: orange = "#EDB22C" @@ -190,7 +203,7 @@ def style(self) -> str: return "filled" -class SignalChannel(_Channel): +class SignalChannel(_Channel["WorkflowSignalChannel"]): @property def color(self) -> str: blue = "#21BFD8" diff --git a/pyiron_workflow/executors/cloudpickleprocesspool.py b/pyiron_workflow/executors/cloudpickleprocesspool.py index 038c4c45..983dc525 100644 --- a/pyiron_workflow/executors/cloudpickleprocesspool.py +++ b/pyiron_workflow/executors/cloudpickleprocesspool.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool, _global_shutdown, _WorkItem from sys import version_info @@ -14,7 +15,7 @@ def result(self, timeout=None): class _CloudPickledCallable: - def __init__(self, fnc: callable): + def __init__(self, fnc: Callable): self.fnc_serial = cloudpickle.dumps(fnc) def __call__(self, /, dumped_args, dumped_kwargs): diff --git a/pyiron_workflow/find.py b/pyiron_workflow/find.py index eea058a3..7111ef47 100644 --- a/pyiron_workflow/find.py +++ b/pyiron_workflow/find.py @@ -1,3 +1,9 @@ +""" +A utility for finding public `pyiron_workflow.node.Node` objects. + +Supports the idea of node developers writing independent node packages. +""" + from __future__ import annotations import importlib.util @@ -5,23 +11,28 @@ import sys from pathlib import Path from types import ModuleType +from typing import TypeVar, cast from pyiron_workflow.node import Node +NodeType = TypeVar("NodeType", bound=Node) + def _get_subclasses( source: str | Path | ModuleType, - base_class: type, + base_class: type[NodeType], get_private: bool = False, get_abstract: bool = False, get_imports_too: bool = False, -): +) -> list[type[NodeType]]: if isinstance(source, str | Path): source = Path(source) if source.is_file(): # Load the module from the file module_name = source.stem spec = importlib.util.spec_from_file_location(module_name, str(source)) + if spec is None or spec.loader is None: + raise ImportError(f"Could not create a ModuleSpec for {source}") module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module spec.loader.exec_module(module) @@ -54,4 +65,4 @@ def find_nodes(source: str | Path | ModuleType) -> list[type[Node]]: """ Get a list of all public, non-abstract nodes defined in the source. """ - return _get_subclasses(source, Node) + return cast(list[type[Node]], _get_subclasses(source, Node)) diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py index 5bb6f170..4d983996 100644 --- a/pyiron_workflow/io.py +++ b/pyiron_workflow/io.py @@ -9,7 +9,8 @@ import contextlib from abc import ABC, abstractmethod -from typing import Any +from collections.abc import ItemsView, Iterator +from typing import Any, Generic, TypeVar from pyiron_snippets.dotdict import DotDict @@ -20,8 +21,10 @@ DataChannel, InputData, InputSignal, + InputType, OutputData, OutputSignal, + OutputType, SignalChannel, ) from pyiron_workflow.logging import logger @@ -32,8 +35,11 @@ HasRun, ) +OwnedType = TypeVar("OwnedType", bound=Channel) +OwnedConjugate = TypeVar("OwnedConjugate", bound=Channel) -class IO(HasStateDisplay, ABC): + +class IO(HasStateDisplay, Generic[OwnedType, OwnedConjugate], ABC): """ IO is a convenience layer for holding and accessing multiple input/output channels. It allows key and dot-based access to the underlying channels. @@ -52,7 +58,9 @@ class IO(HasStateDisplay, ABC): be assigned with a simple `=`. """ - def __init__(self, *channels: Channel): + channel_dict: DotDict[str, OwnedType] + + def __init__(self, *channels: OwnedType) -> None: self.__dict__["channel_dict"] = DotDict( { channel.label: channel @@ -63,15 +71,15 @@ def __init__(self, *channels: Channel): @property @abstractmethod - def _channel_class(self) -> type(Channel): + def _channel_class(self) -> type[OwnedType]: pass @abstractmethod - def _assign_a_non_channel_value(self, channel: Channel, value) -> None: + def _assign_a_non_channel_value(self, channel: OwnedType, value: Any) -> None: """What to do when some non-channel value gets assigned to a channel""" pass - def __getattr__(self, item) -> Channel: + def __getattr__(self, item: str) -> OwnedType: try: return self.channel_dict[item] except KeyError as key_error: @@ -81,7 +89,7 @@ def __getattr__(self, item) -> Channel: f"nor in its channels ({self.labels})" ) from key_error - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: if key in self.channel_dict: self._assign_value_to_existing_channel(self.channel_dict[key], value) elif isinstance(value, self._channel_class): @@ -97,34 +105,34 @@ def __setattr__(self, key, value): f"attribute {key} got assigned {value} of type {type(value)}" ) - def _assign_value_to_existing_channel(self, channel: Channel, value) -> None: + def _assign_value_to_existing_channel(self, channel: OwnedType, value: Any) -> None: if isinstance(value, HasChannel): channel.connect(value.channel) else: self._assign_a_non_channel_value(channel, value) - def __getitem__(self, item) -> Channel: + def __getitem__(self, item: str) -> OwnedType: return self.__getattr__(item) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: self.__setattr__(key, value) @property - def connections(self) -> list[Channel]: + def connections(self) -> list[OwnedConjugate]: """All the unique connections across all channels""" return list( {connection for channel in self for connection in channel.connections} ) @property - def connected(self): + def connected(self) -> bool: return any(c.connected for c in self) @property - def fully_connected(self): + def fully_connected(self) -> bool: return all(c.connected for c in self) - def disconnect(self) -> list[tuple[Channel, Channel]]: + def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]: """ Disconnect all connections that owned channels have. @@ -138,34 +146,36 @@ def disconnect(self) -> list[tuple[Channel, Channel]]: return destroyed_connections @property - def labels(self): + def labels(self) -> list[str]: return list(self.channel_dict.keys()) - def items(self): + def items(self) -> ItemsView[str, OwnedType]: return self.channel_dict.items() - def __iter__(self): + def __iter__(self) -> Iterator[OwnedType]: return self.channel_dict.values().__iter__() - def __len__(self): + def __len__(self) -> int: return len(self.channel_dict) def __dir__(self): - return set(super().__dir__() + self.labels) + return list(set(super().__dir__() + self.labels)) - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__} {self.labels}" - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: # Compatibility with python <3.11 return dict(self.__dict__) - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: # Because we override getattr, we need to use __dict__ assignment directly in # __setstate__ the same way we need it in __init__ self.__dict__["channel_dict"] = state["channel_dict"] - def display_state(self, state=None, ignore_private=True): + def display_state( + self, state: dict[str, Any] | None = None, ignore_private: bool = True + ) -> dict[str, Any]: state = dict(self.__getstate__()) if state is None else state for k, v in state["channel_dict"].items(): state[k] = v @@ -173,19 +183,27 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class DataIO(IO, ABC): +class InputsIO(IO[InputType, OutputType], ABC): + pass + + +class OutputsIO(IO[OutputType, InputType], ABC): + pass + + +class DataIO(IO[DataChannel, DataChannel], ABC): def _assign_a_non_channel_value(self, channel: DataChannel, value) -> None: channel.value = value - def to_value_dict(self): + def to_value_dict(self) -> dict[str, Any]: return {label: channel.value for label, channel in self.channel_dict.items()} - def to_list(self): + def to_list(self) -> list[Any]: """A list of channel values (order not guaranteed)""" return [channel.value for channel in self.channel_dict.values()] @property - def ready(self): + def ready(self) -> bool: return all(c.ready for c in self) def activate_strict_hints(self): @@ -195,24 +213,34 @@ def deactivate_strict_hints(self): [c.deactivate_strict_hints() for c in self] -class Inputs(DataIO): +class Inputs(InputsIO, DataIO): @property - def _channel_class(self) -> type(InputData): + def _channel_class(self) -> type[InputData]: return InputData - def fetch(self): + def fetch(self) -> None: for c in self: c.fetch() -class Outputs(DataIO): +OutputDataType = TypeVar("OutputDataType", bound=OutputData) + + +class GenericOutputs(OutputsIO, DataIO, Generic[OutputDataType], ABC): @property - def _channel_class(self) -> type(OutputData): + @abstractmethod + def _channel_class(self) -> type[OutputDataType]: + pass + + +class Outputs(GenericOutputs[OutputData]): + @property + def _channel_class(self) -> type[OutputData]: return OutputData -class SignalIO(IO, ABC): - def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None: +class SignalIO(IO[SignalChannel, SignalChannel], ABC): + def _assign_a_non_channel_value(self, channel: SignalChannel, value: Any) -> None: raise TypeError( f"Tried to assign {value} ({type(value)} to the {channel.full_label}, " f"which is already a {type(channel)}. Only other signal channels may be " @@ -220,12 +248,12 @@ def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None: ) -class InputSignals(SignalIO): +class InputSignals(InputsIO, SignalIO): @property - def _channel_class(self) -> type(InputSignal): + def _channel_class(self) -> type[InputSignal]: return InputSignal - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: """Disconnect all `run` and `accumulate_and_run` signals, if they exist.""" disconnected = [] with contextlib.suppress(AttributeError): @@ -235,9 +263,9 @@ def disconnect_run(self) -> list[tuple[Channel, Channel]]: return disconnected -class OutputSignals(SignalIO): +class OutputSignals(OutputsIO, SignalIO): @property - def _channel_class(self) -> type(OutputSignal): + def _channel_class(self) -> type[OutputSignal]: return OutputSignal @@ -250,11 +278,11 @@ class Signals(HasStateDisplay): output (OutputSignals): An empty input signals IO container. """ - def __init__(self): - self.input = InputSignals() - self.output = OutputSignals() + def __init__(self) -> None: + self.input: InputSignals = InputSignals() + self.output: OutputSignals = OutputSignals() - def disconnect(self) -> list[tuple[Channel, Channel]]: + def disconnect(self) -> list[tuple[SignalChannel, SignalChannel]]: """ Disconnect all connections in input and output signals. @@ -264,22 +292,25 @@ def disconnect(self) -> list[tuple[Channel, Channel]]: """ return self.input.disconnect() + self.output.disconnect() - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: return self.input.disconnect_run() @property - def connected(self): + def connected(self) -> bool: return self.input.connected or self.output.connected @property - def fully_connected(self): + def fully_connected(self) -> bool: return self.input.fully_connected and self.output.fully_connected - def __str__(self): + def __str__(self) -> str: return f"{str(self.input)}\n{str(self.output)}" -class HasIO(HasStateDisplay, HasLabel, HasRun, ABC): +OutputsType = TypeVar("OutputsType", bound=GenericOutputs) + + +class HasIO(HasStateDisplay, HasLabel, HasRun, Generic[OutputsType], ABC): """ A mixin for classes that provide data and signal IO. @@ -288,7 +319,7 @@ class HasIO(HasStateDisplay, HasLabel, HasRun, ABC): interface. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._signals = Signals() self._signals.input.run = InputSignal("run", self, self.run) @@ -303,18 +334,14 @@ def __init__(self, *args, **kwargs): def inputs(self) -> Inputs: pass + @property @abstractmethod - def data_input_locked(self) -> bool: - """ - Indicates whether data input channels should consider this owner locked to - change. - """ - # Practically, this gives a well-named interface between HasIO and everything - # to do with run status + def data_input_locked(self): + """Prevents Inputs from updating when True""" @property @abstractmethod - def outputs(self) -> Outputs: + def outputs(self) -> OutputsType: pass @property @@ -326,14 +353,14 @@ def connected(self) -> bool: return self.inputs.connected or self.outputs.connected or self.signals.connected @property - def fully_connected(self): + def fully_connected(self) -> bool: return ( self.inputs.fully_connected and self.outputs.fully_connected and self.signals.fully_connected ) - def disconnect(self): + def disconnect(self) -> list[tuple[Channel, Channel]]: """ Disconnect all connections belonging to inputs, outputs, and signals channels. @@ -347,30 +374,32 @@ def disconnect(self): destroyed_connections.extend(self.signals.disconnect()) return destroyed_connections - def activate_strict_hints(self): + def activate_strict_hints(self) -> None: """Enable type hint checks for all data IO""" self.inputs.activate_strict_hints() self.outputs.activate_strict_hints() - def deactivate_strict_hints(self): + def deactivate_strict_hints(self) -> None: """Disable type hint checks for all data IO""" self.inputs.deactivate_strict_hints() self.outputs.deactivate_strict_hints() - def _connect_output_signal(self, signal: OutputSignal): + def _connect_output_signal(self, signal: OutputSignal) -> None: self.signals.input.run.connect(signal) - def __rshift__(self, other: InputSignal | HasIO): + def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO: """ Allows users to connect run and ran signals like: `first >> second`. """ other._connect_output_signal(self.signals.output.ran) return other - def _connect_accumulating_input_signal(self, signal: AccumulatingInputSignal): + def _connect_accumulating_input_signal( + self, signal: AccumulatingInputSignal + ) -> None: self.signals.output.ran.connect(signal) - def __lshift__(self, others): + def __lshift__(self, others: tuple[OutputSignal | HasIO, ...]): """ Connect one or more `ran` signals to `accumulate_and_run` signals like: `this << some_object, another_object, or_by_channel.signals.output.ran` @@ -456,8 +485,8 @@ def copy_io( try: self._copy_values(other, fail_hard=values_fail_hard) except Exception as e: - for this, other in new_connections: - this.disconnect(other) + for owned, conjugate in new_connections: + owned.disconnect(conjugate) raise e def _copy_connections( @@ -520,7 +549,7 @@ def _copy_values( self, other: HasIO, fail_hard: bool = False, - ) -> list[tuple[Channel, Any]]: + ) -> list[tuple[DataChannel, Any]]: """ Copies all data from input and output channels in the other object onto this one. @@ -543,30 +572,39 @@ def _copy_values( list[tuple[Channel, Any]]: A list of tuples giving channels whose value has been updated and what it used to be (for reverting changes). """ + # Leverage a separate function because mypy has trouble parsing types + # if we loop over inputs and outputs at the same time + return self._copy_panel( + other, self.inputs, other.inputs, fail_hard=fail_hard + ) + self._copy_panel(other, self.outputs, other.outputs, fail_hard=fail_hard) + + def _copy_panel( + self, + other: HasIO, + my_panel: DataIO, + other_panel: DataIO, + fail_hard: bool = False, + ) -> list[tuple[DataChannel, Any]]: old_values = [] - for my_panel, other_panel in [ - (self.inputs, other.inputs), - (self.outputs, other.outputs), - ]: - for key, to_copy in other_panel.items(): - if to_copy.value is not NOT_DATA: - try: - old_value = my_panel[key].value - my_panel[key].value = to_copy.value # Gets hint-checked - old_values.append((my_panel[key], old_value)) - except Exception as e: - if fail_hard: - # If you run into trouble, unwind what you've done - for channel, value in old_values: - channel.value = value - raise ValueCopyError( - f"{self.label} could not copy values from " - f"{other.label} due to the channel {key} on " - f"{other_panel.__class__.__name__}, which holds value " - f"{to_copy.value}" - ) from e - else: - continue + for key, to_copy in other_panel.items(): + if to_copy.value is not NOT_DATA: + try: + old_value = my_panel[key].value + my_panel[key].value = to_copy.value # Gets hint-checked + old_values.append((my_panel[key], old_value)) + except Exception as e: + if fail_hard: + # If you run into trouble, unwind what you've done + for channel, value in old_values: + channel.value = value + raise ValueCopyError( + f"{self.label} could not copy values from " + f"{other.label} due to the channel {key} on " + f"{other_panel.__class__.__name__}, which holds value " + f"{to_copy.value}" + ) from e + else: + continue return old_values @property diff --git a/pyiron_workflow/mixin/display_state.py b/pyiron_workflow/mixin/display_state.py index 48309e21..fb9856a8 100644 --- a/pyiron_workflow/mixin/display_state.py +++ b/pyiron_workflow/mixin/display_state.py @@ -4,6 +4,7 @@ from abc import ABC from json import dumps +from typing import Any from pyiron_workflow.mixin.has_interface_mixins import UsesState @@ -24,7 +25,7 @@ class HasStateDisplay(UsesState, ABC): def display_state( self, state: dict | None = None, ignore_private: bool = True - ) -> dict: + ) -> dict[str, Any]: """ A dictionary of JSON-compatible objects based on the object state (plus whatever modifications to the state the class designer has chosen to make). diff --git a/pyiron_workflow/mixin/has_interface_mixins.py b/pyiron_workflow/mixin/has_interface_mixins.py index 2828ce7e..049159ee 100644 --- a/pyiron_workflow/mixin/has_interface_mixins.py +++ b/pyiron_workflow/mixin/has_interface_mixins.py @@ -11,7 +11,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Generic, TypeVar if TYPE_CHECKING: from pyiron_workflow.channels import Channel @@ -39,10 +39,24 @@ class HasLabel(ABC): A mixin to guarantee the label interface exists. """ + _label: str + @property - @abstractmethod def label(self) -> str: """A label for the object.""" + return self._label + + @label.setter + def label(self, new_label: str): + self._check_label(new_label) + self._label = new_label + + def _check_label(self, new_label: str) -> None: + """ + Extensible checking routine for label validity. + """ + if not isinstance(new_label, str): + raise TypeError(f"Expected a string label but got {new_label}") @property def full_label(self) -> str: @@ -53,30 +67,27 @@ def full_label(self) -> str: return self.label -class HasParent(ABC): +class HasChannel(ABC): """ - A mixin to guarantee the parent interface exists. + A mix-in class for use with the :class:`Channel` class and its children. + + This is useful for letting channels attempt to connect to non-channel objects + directly by pointing them to some channel that object holds. """ @property @abstractmethod - def parent(self) -> Any: - """A parent for the object.""" + def channel(self) -> Channel: + pass -class HasChannel(ABC): - """ - A mix-in class for use with the :class:`Channel` class. - A :class:`Channel` is able to (attempt to) connect to any child instance of :class:`HasConnection` - by looking at its :attr:`connection` attribute. +ChannelType = TypeVar("ChannelType", bound="Channel") - This is useful for letting channels attempt to connect to non-channel objects - directly by pointing them to some channel that object holds. - """ +class HasGenericChannel(HasChannel, Generic[ChannelType], ABC): @property @abstractmethod - def channel(self) -> Channel: + def channel(self) -> ChannelType: pass diff --git a/pyiron_workflow/mixin/injection.py b/pyiron_workflow/mixin/injection.py index 3b22c589..2bbf4e48 100644 --- a/pyiron_workflow/mixin/injection.py +++ b/pyiron_workflow/mixin/injection.py @@ -8,11 +8,10 @@ from __future__ import annotations -from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any from pyiron_workflow.channels import NOT_DATA, OutputData -from pyiron_workflow.io import HasIO, Outputs +from pyiron_workflow.io import GenericOutputs from pyiron_workflow.mixin.has_interface_mixins import HasChannel if TYPE_CHECKING: @@ -275,14 +274,7 @@ def __round__(self): return self._node_injection(Round) -class OutputsWithInjection(Outputs): +class OutputsWithInjection(GenericOutputs[OutputDataWithInjection]): @property - def _channel_class(self) -> type(OutputDataWithInjection): + def _channel_class(self) -> type[OutputDataWithInjection]: return OutputDataWithInjection - - -class HasIOWithInjection(HasIO, ABC): - @property - @abstractmethod - def outputs(self) -> OutputsWithInjection: - pass diff --git a/pyiron_workflow/mixin/preview.py b/pyiron_workflow/mixin/preview.py index 08c06c47..97695bdb 100644 --- a/pyiron_workflow/mixin/preview.py +++ b/pyiron_workflow/mixin/preview.py @@ -14,7 +14,8 @@ import inspect from abc import ABC, abstractmethod -from functools import lru_cache, wraps +from collections.abc import Callable +from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -75,28 +76,12 @@ def preview_outputs(cls) -> dict[str, Any]: return cls._build_outputs_preview() @classmethod - def preview_io(cls) -> DotDict[str, dict]: + def preview_io(cls) -> DotDict[str, dict[str, Any | tuple[Any, Any]]]: return DotDict( {"inputs": cls.preview_inputs(), "outputs": cls.preview_outputs()} ) -def builds_class_io(subclass_factory: callable[..., type[HasIOPreview]]): - """ - A decorator for factories producing subclasses of `HasIOPreview` to invoke - :meth:`preview_io` after the class is created, thus ensuring the IO has been - constructed at the class level. - """ - - @wraps(subclass_factory) - def wrapped(*args, **kwargs): - node_class = subclass_factory(*args, **kwargs) - node_class.preview_io() - return node_class - - return wrapped - - class ScrapesIO(HasIOPreview, ABC): """ A mixin class for scraping IO channel information from a specific class method's @@ -129,7 +114,7 @@ class ScrapesIO(HasIOPreview, ABC): @classmethod @abstractmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: """Must return a static method.""" _output_labels: ClassVar[tuple[str] | None] = None # None: scrape them @@ -139,7 +124,7 @@ def _io_defining_function(cls) -> callable: ) @classmethod - def _build_inputs_preview(cls): + def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: type_hints = cls._get_type_hints() scraped: dict[str, tuple[Any, Any]] = {} for i, (label, value) in enumerate(cls._get_input_args().items()): @@ -167,7 +152,7 @@ def _build_inputs_preview(cls): return scraped @classmethod - def _build_outputs_preview(cls): + def _build_outputs_preview(cls) -> dict[str, Any]: if cls._validate_output_labels: cls._validate() # Validate output on first call @@ -287,7 +272,7 @@ def _validate_return_count(cls): ) from type_error @staticmethod - def _io_defining_documentation(io_defining_function: callable, title: str): + def _io_defining_documentation(io_defining_function: Callable, title: str): """ A helper method for building a docstring for classes that have their IO defined by some function. diff --git a/pyiron_workflow/mixin/run.py b/pyiron_workflow/mixin/run.py index b704abc7..aa40bd67 100644 --- a/pyiron_workflow/mixin/run.py +++ b/pyiron_workflow/mixin/run.py @@ -7,6 +7,7 @@ import contextlib from abc import ABC, abstractmethod +from collections.abc import Callable from concurrent.futures import Executor as StdLibExecutor from concurrent.futures import Future, ThreadPoolExecutor from functools import partial @@ -49,16 +50,16 @@ class Runnable(UsesState, HasLabel, HasRun, ABC): new keyword arguments. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.running = False - self.failed = False - self.executor = None - # We call it an executor, but it's just whether to use one. - # This is a simply stop-gap as we work out more sophisticated ways to reference - # (or create) an executor process without ever trying to pickle a `_thread.lock` + self.running: bool = False + self.failed: bool = False + self.executor: ( + StdLibExecutor | tuple[Callable[..., StdLibExecutor], tuple, dict] | None + ) = None + # We call it an executor, but it can also be instructions on making one self.future: None | Future = None - self._thread_pool_sleep_time = 1e-6 + self._thread_pool_sleep_time: float = 1e-6 @abstractmethod def on_run(self, *args, **kwargs) -> Any: # callable[..., Any | tuple]: @@ -73,7 +74,7 @@ def run_args(self) -> tuple[tuple, dict]: Any data needed for :meth:`on_run`, will be passed as (*args, **kwargs). """ - def process_run_result(self, run_output): + def process_run_result(self, run_output: Any) -> Any: """ What to _do_ with the results of :meth:`on_run` once you have them. @@ -135,7 +136,7 @@ def run( :attr:`running`. (Default is True.) """ - def _none_to_dict(inp): + def _none_to_dict(inp: dict | None) -> dict: return {} if inp is None else inp before_run_kwargs = _none_to_dict(before_run_kwargs) @@ -164,7 +165,9 @@ def _none_to_dict(inp): **run_kwargs, ) - def _before_run(self, /, check_readiness, **kwargs) -> tuple[bool, Any]: + def _before_run( + self, /, check_readiness: bool, *args, **kwargs + ) -> tuple[bool, Any]: """ Things to do _before_ running. @@ -193,6 +196,7 @@ def _run( run_exception_kwargs: dict, run_finally_kwargs: dict, finish_run_kwargs: dict, + *args, **kwargs, ) -> Any | tuple | Future: """ @@ -253,7 +257,7 @@ def _run( ) return self.future - def _run_exception(self, /, **kwargs): + def _run_exception(self, /, *args, **kwargs): """ What to do if an exception is encountered inside :meth:`_run` or :meth:`_finish_run. @@ -261,7 +265,7 @@ def _run_exception(self, /, **kwargs): self.running = False self.failed = True - def _run_finally(self, /, **kwargs): + def _run_finally(self, /, *args, **kwargs): """ What to do after :meth:`_finish_run` (whether an exception is encountered or not), or in :meth:`_run` after an exception is encountered. @@ -275,7 +279,7 @@ def _finish_run( run_exception_kwargs: dict, run_finally_kwargs: dict, **kwargs, - ) -> Any | tuple: + ) -> Any | tuple | None: """ Switch the status, then process and return the run result. """ @@ -288,6 +292,7 @@ def _finish_run( self._run_exception(**run_exception_kwargs) if raise_run_exceptions: raise e + return None finally: self._run_finally(**run_finally_kwargs) @@ -308,7 +313,7 @@ def _readiness_error_message(self) -> str: @staticmethod def _parse_executor( - executor: StdLibExecutor | (callable[..., StdLibExecutor], tuple, dict), + executor: StdLibExecutor | tuple[Callable[..., StdLibExecutor], tuple, dict], ) -> StdLibExecutor: """ If you've already got an executor, you're done. But if you get callable and diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index de083b87..5b4358e1 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -2,10 +2,11 @@ Classes for "semantic" reasoning. The motivation here is to be able to provide the object with a unique identifier -in the context of other semantic objects. Each object may have exactly one parent -and an arbitrary number of children, and each child's name must be unique in the -scope of that parent. In this way, the path from the parent-most object to any -child is completely unique. The typical filesystem on a computer is an excellent +in the context of other semantic objects. Each object may have at most one parent, +while semantic parents may have an arbitrary number of children, and each child's name +must be unique in the scope of that parent. In this way, when semantic parents are also +themselves semantic, we can build a path from the parent-most object to any child that +is completely unique. The typical filesystem on a computer is an excellent example and fulfills our requirements, the only reason we depart from it is so that we are free to have objects stored in different locations (possibly even on totally different drives or machines) belong to the same semantic group. @@ -13,17 +14,20 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from difflib import get_close_matches from pathlib import Path +from typing import ClassVar, Generic, TypeVar from bidict import bidict from pyiron_workflow.logging import logger -from pyiron_workflow.mixin.has_interface_mixins import HasLabel, HasParent, UsesState +from pyiron_workflow.mixin.has_interface_mixins import HasLabel, UsesState +ParentType = TypeVar("ParentType", bound="SemanticParent") -class Semantic(UsesState, HasLabel, HasParent, ABC): + +class Semantic(UsesState, HasLabel, Generic[ParentType], ABC): """ An object with a unique semantic path. @@ -31,46 +35,60 @@ class Semantic(UsesState, HasLabel, HasParent, ABC): accessible. """ - semantic_delimiter = "/" + semantic_delimiter: ClassVar[str] = "/" def __init__( - self, label: str, *args, parent: SemanticParent | None = None, **kwargs + self, + *args, + label: str | None = None, + parent: ParentType | None = None, + **kwargs, ): - self._label = None + self._label = "" self._parent = None self._detached_parent_path = None - self.label = label + self.label = self.__class__.__name__ if label is None else label self.parent = parent super().__init__(*args, **kwargs) - @property - def label(self) -> str: - return self._label + @classmethod + @abstractmethod + def parent_type(cls) -> type[ParentType]: + pass - @label.setter - def label(self, new_label: str) -> None: - if not isinstance(new_label, str): - raise TypeError(f"Expected a string label but got {new_label}") + def _check_label(self, new_label: str) -> None: + super()._check_label(new_label) if self.semantic_delimiter in new_label: - raise ValueError(f"{self.semantic_delimiter} cannot be in the label") - self._label = new_label + raise ValueError( + f"Semantic delimiter {self.semantic_delimiter} cannot be in new label " + f"{new_label}" + ) @property - def parent(self) -> SemanticParent | None: + def parent(self) -> ParentType | None: return self._parent @parent.setter - def parent(self, new_parent: SemanticParent | None) -> None: + def parent(self, new_parent: ParentType | None) -> None: + self._set_parent(new_parent) + + def _set_parent(self, new_parent: ParentType | None): + """ + mypy is uncooperative with super calls for setters, so we pull the behaviour + out. + """ if new_parent is self._parent: # Exit early if nothing is changing return - if new_parent is not None and not isinstance(new_parent, SemanticParent): + if new_parent is not None and not isinstance(new_parent, self.parent_type()): raise ValueError( - f"Expected None or a {SemanticParent.__name__} for the parent of " + f"Expected None or a {self.parent_type()} for the parent of " f"{self.label}, but got {new_parent}" ) + _ensure_path_is_not_cyclic(new_parent, self) + if ( self._parent is not None and new_parent is not self._parent @@ -88,18 +106,22 @@ def semantic_path(self) -> str: The path of node labels from the graph root (parent-most node) down to this node. """ + prefix: str if self.parent is None and self.detached_parent_path is None: prefix = "" elif self.parent is None and self.detached_parent_path is not None: prefix = self.detached_parent_path elif self.parent is not None and self.detached_parent_path is None: - prefix = self.parent.semantic_path + if isinstance(self.parent, Semantic): + prefix = self.parent.semantic_path + else: + prefix = self.semantic_delimiter + self.parent.label else: raise ValueError( f"The parent and detached path should not be able to take non-None " f"values simultaneously, but got {self.parent} and " - f"{self.detached_parent_path}, respectively. Please raise an issue on GitHub " - f"outlining how your reached this state." + f"{self.detached_parent_path}, respectively. Please raise an issue on " + f"GitHub outlining how your reached this state." ) return prefix + self.semantic_delimiter + self.label @@ -126,7 +148,10 @@ def full_label(self) -> str: @property def semantic_root(self) -> Semantic: """The parent-most object in this semantic path; may be self.""" - return self.parent.semantic_root if isinstance(self.parent, Semantic) else self + if isinstance(self.parent, Semantic): + return self.parent.semantic_root + else: + return self def as_path(self, root: Path | str | None = None) -> Path: """ @@ -157,9 +182,12 @@ class CyclicPathError(ValueError): """ -class SemanticParent(Semantic, ABC): +ChildType = TypeVar("ChildType", bound=Semantic) + + +class SemanticParent(HasLabel, Generic[ChildType], ABC): """ - A semantic object with a collection of uniquely-named semantic children. + A labeled object with a collection of uniquely-named semantic children. Children should be added or removed via the :meth:`add_child` and :meth:`remove_child` methods and _not_ by direct manipulation of the @@ -176,25 +204,42 @@ class SemanticParent(Semantic, ABC): def __init__( self, - label: str, *args, - parent: SemanticParent | None = None, strict_naming: bool = True, **kwargs, ): - self._children = bidict() + self._children: bidict[str, ChildType] = bidict() self.strict_naming = strict_naming - super().__init__(*args, label=label, parent=parent, **kwargs) + super().__init__(*args, **kwargs) + + @classmethod + @abstractmethod + def child_type(cls) -> type[ChildType]: + # Dev note: In principle, this could be a regular attribute + # However, in other situations this is precluded (e.g. in channels) + # since it would result in circular references. + # Here we favour consistency over brevity, + # and maintain the X_type() class method pattern + pass @property - def children(self) -> bidict[str, Semantic]: + def children(self) -> bidict[str, ChildType]: return self._children @property def child_labels(self) -> tuple[str]: return tuple(child.label for child in self) - def __getattr__(self, key): + def _check_label(self, new_label: str) -> None: + super()._check_label(new_label) + if self.child_type().semantic_delimiter in new_label: + raise ValueError( + f"Child type ({self.child_type()}) semantic delimiter " + f"{self.child_type().semantic_delimiter} cannot be in new label " + f"{new_label}" + ) + + def __getattr__(self, key) -> ChildType: try: return self._children[key] except KeyError as key_error: @@ -210,7 +255,7 @@ def __getattr__(self, key): def __iter__(self): return self.children.values().__iter__() - def __len__(self): + def __len__(self) -> int: return len(self.children) def __dir__(self): @@ -218,15 +263,15 @@ def __dir__(self): def add_child( self, - child: Semantic, + child: ChildType, label: str | None = None, strict_naming: bool | None = None, - ) -> Semantic: + ) -> ChildType: """ Add a child, optionally assigning it a new label in the process. Args: - child (Semantic): The child to add. + child (ChildType): The child to add. label (str|None): A (potentially) new label to assign the child. (Default is None, leave the child's label alone.) strict_naming (bool|None): Whether to append a suffix to the label if @@ -234,7 +279,7 @@ def add_child( use the class-level flag.) Returns: - (Semantic): The child being added. + (ChildType): The child being added. Raises: TypeError: When the child is not of an allowed class. @@ -244,19 +289,13 @@ def add_child( `strict_naming` is true. """ - if not isinstance(child, Semantic): + if not isinstance(child, self.child_type()): raise TypeError( - f"{self.label} expected a new child of type {Semantic.__name__} " + f"{self.label} expected a new child of type {self.child_type()} " f"but got {child}" ) - if isinstance(child, ParentMost): - raise ParentMostError( - f"{child.label} is {ParentMost.__name__} and may only take None as a " - f"parent but was added as a child to {self.label}" - ) - - self._ensure_path_is_not_cyclic(self, child) + _ensure_path_is_not_cyclic(self, child) self._ensure_child_has_no_other_parent(child) @@ -277,19 +316,7 @@ def add_child( child.parent = self return child - @staticmethod - def _ensure_path_is_not_cyclic(parent: SemanticParent | None, child: Semantic): - if parent is not None and parent.semantic_path.startswith( - child.semantic_path + child.semantic_delimiter - ): - raise CyclicPathError( - f"{parent.label} cannot be the parent of {child.label}, because its " - f"semantic path is already in {child.label}'s path and cyclic paths " - f"are not allowed. (i.e. {child.semantic_path} is in " - f"{parent.semantic_path})" - ) - - def _ensure_child_has_no_other_parent(self, child: Semantic): + def _ensure_child_has_no_other_parent(self, child: Semantic) -> None: if child.parent is not None and child.parent is not self: raise ValueError( f"The child ({child.label}) already belongs to the parent " @@ -297,17 +324,17 @@ def _ensure_child_has_no_other_parent(self, child: Semantic): f"add it to this parent ({self.label})." ) - def _this_child_is_already_at_this_label(self, child: Semantic, label: str): + def _this_child_is_already_at_this_label(self, child: Semantic, label: str) -> bool: return ( label == child.label and label in self.child_labels and self.children[label] is child ) - def _this_child_is_already_at_a_different_label(self, child, label): + def _this_child_is_already_at_a_different_label(self, child, label) -> bool: return child.parent is self and label != child.label - def _get_unique_label(self, label: str, strict_naming: bool): + def _get_unique_label(self, label: str, strict_naming: bool) -> str: if label in self.__dir__(): if label in self.child_labels: if strict_naming: @@ -324,7 +351,7 @@ def _get_unique_label(self, label: str, strict_naming: bool): ) return label - def _add_suffix_to_label(self, label): + def _add_suffix_to_label(self, label: str) -> str: i = 0 new_label = label while new_label in self.__dir__(): @@ -339,29 +366,21 @@ def _add_suffix_to_label(self, label): ) return new_label - def remove_child(self, child: Semantic | str) -> Semantic: + def remove_child(self, child: ChildType | str) -> ChildType: if isinstance(child, str): - child = self.children.pop(child) - elif isinstance(child, Semantic): + child_instance = self.children.pop(child) + elif isinstance(child, self.child_type()): self.children.inv.pop(child) + child_instance = child else: raise TypeError( f"{self.label} expected to remove a child of type str or " - f"{Semantic.__name__} but got {child}" + f"{self.child_type()} but got {child}" ) - child.parent = None + child_instance.parent = None - return child - - @property - def parent(self) -> SemanticParent | None: - return self._parent - - @parent.setter - def parent(self, new_parent: SemanticParent | None) -> None: - self._ensure_path_is_not_cyclic(new_parent, self) - super(SemanticParent, type(self)).parent.__set__(self, new_parent) + return child_instance def __getstate__(self): state = super().__getstate__() @@ -398,25 +417,13 @@ def __setstate__(self, state): child.parent = self -class ParentMostError(TypeError): - """ - To be raised when assigning a parent to a parent-most object - """ - - -class ParentMost(SemanticParent, ABC): - """ - A semantic parent that cannot have any other parent. - """ - - @property - def parent(self) -> None: - return None - - @parent.setter - def parent(self, new_parent: None): - if new_parent is not None: - raise ParentMostError( - f"{self.label} is {ParentMost.__name__} and may only take None as a " - f"parent but got {type(new_parent)}" - ) +def _ensure_path_is_not_cyclic(parent, child: Semantic) -> None: + if isinstance(parent, Semantic) and parent.semantic_path.startswith( + child.semantic_path + child.semantic_delimiter + ): + raise CyclicPathError( + f"{parent.label} cannot be the parent of {child.label}, because its " + f"semantic path is already in {child.label}'s path and cyclic paths " + f"are not allowed. (i.e. {child.semantic_path} is in " + f"{parent.semantic_path})" + ) diff --git a/pyiron_workflow/mixin/single_output.py b/pyiron_workflow/mixin/single_output.py index 4272d870..1e6dacfc 100644 --- a/pyiron_workflow/mixin/single_output.py +++ b/pyiron_workflow/mixin/single_output.py @@ -7,7 +7,8 @@ from abc import ABC, abstractmethod -from pyiron_workflow.mixin.has_interface_mixins import HasChannel, HasLabel +from pyiron_workflow.io import HasIO +from pyiron_workflow.mixin.has_interface_mixins import HasGenericChannel from pyiron_workflow.mixin.injection import ( OutputDataWithInjection, OutputsWithInjection, @@ -18,7 +19,9 @@ class AmbiguousOutputError(ValueError): """Raised when searching for exactly one output, but multiple are found.""" -class ExploitsSingleOutput(HasLabel, HasChannel, ABC): +class ExploitsSingleOutput( + HasIO[OutputsWithInjection], HasGenericChannel[OutputDataWithInjection], ABC +): @property @abstractmethod def outputs(self) -> OutputsWithInjection: diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 3b86a5e4..07a2d7d3 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -11,7 +11,7 @@ from abc import ABC, abstractmethod from concurrent.futures import Future from importlib import import_module -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import cloudpickle from pyiron_snippets.colors import SeabornColors @@ -19,7 +19,6 @@ from pyiron_workflow.draw import Node as GraphvizNode from pyiron_workflow.logging import logger -from pyiron_workflow.mixin.injection import HasIOWithInjection from pyiron_workflow.mixin.run import ReadinessError, Runnable from pyiron_workflow.mixin.semantics import Semantic from pyiron_workflow.mixin.single_output import ExploitsSingleOutput @@ -40,8 +39,7 @@ class Node( - HasIOWithInjection, - Semantic, + Semantic["Composite"], Runnable, ExploitsSingleOutput, ABC, @@ -179,8 +177,9 @@ class Node( inputs (pyiron_workflow.io.Inputs): **Abstract.** Children must define a property returning an :class:`Inputs` object. label (str): A name for the node. - outputs (pyiron_workflow.io.Outputs): **Abstract.** Children must define - a property returning an :class:`Outputs` object. + outputs (pyiron_workflow.mixin.injection.OutputsWithInjection): **Abstract.** + Children must define a property returning an :class:`OutputsWithInjection` + object. parent (pyiron_workflow.composite.Composite | None): The parent object owning this, if any. ready (bool): Whether the inputs are all ready and the node is neither @@ -297,18 +296,17 @@ def __init__( **kwargs: Interpreted as node input data, with keys corresponding to channel labels. """ - super().__init__( - label=self.__class__.__name__ if label is None else label, - parent=parent, - ) + super().__init__(label=label, parent=parent) self.checkpoint = checkpoint self.recovery: Literal["pickle"] | StorageInterface | None = "pickle" self._serialize_result = False # Advertised, but private to indicate # under-development status -- API may change to be more user-friendly self._do_clean: bool = False # Power-user override for cleaning up temporary # serialized results and empty directories (or not). - self._cached_inputs = None - self._user_data = {} # A place for power-users to bypass node-injection + self._cached_inputs: dict[str, Any] | None = None + + self._user_data: dict[str, Any] = {} + # A place for power-users to bypass node-injection self._setup_node() self._after_node_setup( @@ -319,6 +317,12 @@ def __init__( **kwargs, ) + @classmethod + def parent_type(cls) -> type[Composite]: + from pyiron_workflow.nodes.composite import Composite + + return Composite + def _setup_node(self) -> None: """ Called _before_ :meth:`Node.__init__` finishes. @@ -627,7 +631,7 @@ def run_data_tree(self, run_parent_trees_too=False) -> None: try: parent_starting_nodes = ( - self.parent.starting_nodes if self.parent is not None else None + self.parent.starting_nodes if self.parent is not None else [] ) # We need these for state recovery later, even if we crash if len(data_tree_starters) == 1 and data_tree_starters[0] is self: @@ -703,7 +707,7 @@ def _outputs_to_run_return(self): return DotDict(self.outputs.to_value_dict()) @property - def emitting_channels(self) -> tuple[OutputSignal]: + def emitting_channels(self) -> tuple[OutputSignal, ...]: if self.failed: return (self.signals.output.failed,) else: @@ -826,10 +830,8 @@ def draw( (graphviz.graphs.Digraph): The resulting graph object. """ - - if size is not None: - size = f"{size[0]},{size[1]}" - graph = GraphvizNode(self, depth=depth, rankdir=rankdir, size=size).graph + size_str = f"{size[0]},{size[1]}" if size is not None else None + graph = GraphvizNode(self, depth=depth, rankdir=rankdir, size=size_str).graph if save or view or filename is not None: directory = self.as_path() if directory is None else Path(directory) filename = self.label + "_graph" if filename is None else filename @@ -892,7 +894,7 @@ def replace_with(self, other: Node | type[Node]): def save( self, - backend: str | StorageInterface = "pickle", + backend: Literal["pickle"] | StorageInterface = "pickle", filename: str | Path | None = None, **kwargs, ): @@ -912,9 +914,9 @@ def save( ): selected_backend.save(node=self, filename=filename, **kwargs) - save.__doc__ += _save_load_warnings + save.__doc__ = cast(str, save.__doc__) + _save_load_warnings - def save_checkpoint(self, backend: str | StorageInterface = "pickle"): + def save_checkpoint(self, backend: Literal["pickle"] | StorageInterface = "pickle"): """ Triggers a save on the parent-most node. @@ -926,7 +928,7 @@ def save_checkpoint(self, backend: str | StorageInterface = "pickle"): def load( self, - backend: str | StorageInterface = "pickle", + backend: Literal["pickle"] | StorageInterface = "pickle", only_requested=False, filename: str | Path | None = None, **kwargs, @@ -971,7 +973,7 @@ def load( ) self.__setstate__(inst.__getstate__()) - load.__doc__ += _save_load_warnings + load.__doc__ = cast(str, load.__doc__) + _save_load_warnings def delete_storage( self, diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 7f745e9b..38f988aa 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -6,6 +6,7 @@ from __future__ import annotations from abc import ABC +from collections.abc import Callable from time import sleep from typing import TYPE_CHECKING, Literal @@ -19,7 +20,6 @@ if TYPE_CHECKING: from pyiron_workflow.channels import ( - Channel, InputSignal, OutputSignal, ) @@ -53,7 +53,7 @@ class FailedChildError(RuntimeError): """Raise when one or more child nodes raise exceptions.""" -class Composite(SemanticParent, HasCreator, Node, ABC): +class Composite(SemanticParent[Node], HasCreator, Node, ABC): """ A base class for nodes that have internal graph structure -- i.e. they hold a collection of child nodes and their computation is to execute that graph. @@ -142,8 +142,8 @@ def __init__( # empty but the running_children list is not super().__init__( - label, *args, + label=label, parent=parent, delete_existing_savefiles=delete_existing_savefiles, autoload=autoload, @@ -153,6 +153,10 @@ def __init__( **kwargs, ) + @classmethod + def child_type(cls) -> type[Node]: + return Node + def activate_strict_hints(self): super().activate_strict_hints() for node in self: @@ -272,12 +276,12 @@ def _get_state_from_remote_other(self, other_self): state.pop("_parent") # Got overridden to None for __getstate__, so keep local return state - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: """ Disconnect all `signals.input.run` connections on all child nodes. Returns: - list[tuple[Channel, Channel]]: Any disconnected pairs. + list[tuple[InputSignal, OutputSignal]]: Any disconnected pairs. """ disconnected_pairs = [] for node in self.children.values(): @@ -299,15 +303,10 @@ def add_child( label: str | None = None, strict_naming: bool | None = None, ) -> Node: - if not isinstance(child, Node): - raise TypeError( - f"Only new {Node.__name__} instances may be added, but got " - f"{type(child)}." - ) self._cached_inputs = None # Reset cache after graph change return super().add_child(child, label=label, strict_naming=strict_naming) - def remove_child(self, child: Node | str) -> list[tuple[Channel, Channel]]: + def remove_child(self, child: Node | str) -> Node: """ Remove a child from the :attr:`children` collection, disconnecting it and setting its :attr:`parent` to None. @@ -316,18 +315,18 @@ def remove_child(self, child: Node | str) -> list[tuple[Channel, Channel]]: child (Node|str): The child (or its label) to remove. Returns: - (list[tuple[Channel, Channel]]): Any connections that node had. + (Node): The (now disconnected and de-parented) (former) child node. """ child = super().remove_child(child) - disconnected = child.disconnect() + child.disconnect() if child in self.starting_nodes: self.starting_nodes.remove(child) self._cached_inputs = None # Reset cache after graph change - return disconnected + return child def replace_child( self, owned_node: Node | str, replacement: Node | type[Node] - ) -> Node: + ) -> tuple[Node, Node]: """ Replaces a node currently owned with a new node instance. The replacement must not belong to any other parent or have any connections. @@ -349,15 +348,16 @@ def replace_child( and simply gets instantiated.) Returns: - (Node): The node that got removed + (Node, Node): The node that got removed and the new one that replaced it. """ - if isinstance(owned_node, str): - owned_node = self.children[owned_node] + owned_node_instance = ( + self.children[owned_node] if isinstance(owned_node, str) else owned_node + ) - if owned_node.parent is not self: + if owned_node_instance.parent is not self: raise ValueError( f"The node being replaced should be a child of this composite, but " - f"another parent was found: {owned_node.parent}" + f"another parent was found: {owned_node_instance.parent}" ) if isinstance(replacement, Node): @@ -368,44 +368,56 @@ def replace_child( ) if replacement.connected: raise ValueError("Replacement node must not have any connections") + replacement_node = replacement elif issubclass(replacement, Node): - replacement = replacement(label=owned_node.label) + replacement_node = replacement(label=owned_node_instance.label) else: raise TypeError( f"Expected replacement node to be a node instance or node subclass, but " f"got {replacement}" ) - replacement.copy_io(owned_node) # If the replacement is incompatible, we'll + replacement_node.copy_io( + owned_node_instance + ) # If the replacement is incompatible, we'll # fail here before we've changed the parent at all. Since the replacement was # first guaranteed to be an unconnected orphan, there is not yet any permanent # damage - is_starting_node = owned_node in self.starting_nodes + is_starting_node = owned_node_instance in self.starting_nodes # In case the replaced node interfaces with the composite's IO, catch value # links inbound_links = [ - (sending_channel, replacement.inputs[sending_channel.value_receiver.label]) + ( + sending_channel, + replacement_node.inputs[sending_channel.value_receiver.label], + ) for sending_channel in self.inputs - if sending_channel.value_receiver in owned_node.inputs + if sending_channel.value_receiver in owned_node_instance.inputs ] outbound_links = [ - (replacement.outputs[sending_channel.label], sending_channel.value_receiver) - for sending_channel in owned_node.outputs + ( + replacement_node.outputs[sending_channel.label], + sending_channel.value_receiver, + ) + for sending_channel in owned_node_instance.outputs if sending_channel.value_receiver in self.outputs ] - self.remove_child(owned_node) - replacement.label, owned_node.label = owned_node.label, replacement.label - self.add_child(replacement) + self.remove_child(owned_node_instance) + replacement_node.label, owned_node_instance.label = ( + owned_node_instance.label, + replacement_node.label, + ) + self.add_child(replacement_node) if is_starting_node: - self.starting_nodes.append(replacement) + self.starting_nodes.append(replacement_node) for sending_channel, receiving_channel in inbound_links + outbound_links: sending_channel.value_receiver = receiving_channel # Clear caches self._cached_inputs = None - replacement._cached_inputs = None + replacement_node._cached_inputs = None - return owned_node + return owned_node_instance, replacement_node def executor_shutdown(self, wait=True, *, cancel_futures=False): """ @@ -419,8 +431,6 @@ def executor_shutdown(self, wait=True, *, cancel_futures=False): def __setattr__(self, key: str, node: Node): if isinstance(node, Composite) and key in ["_parent", "parent"]: # This is an edge case for assigning a node to an attribute - # We either defer to the setter with super, or directly assign the private - # variable (as requested in the setter) super().__setattr__(key, node) elif isinstance(node, Node): self.add_child(node, label=key) @@ -450,7 +460,7 @@ def graph_as_dict(self) -> dict: return _get_graph_as_dict(self) def _get_connections_as_strings( - self, panel_getter: callable + self, panel_getter: Callable ) -> list[tuple[tuple[str, str], tuple[str, str]]]: """ Connections between children in string representation based on labels. @@ -520,8 +530,8 @@ def __setstate__(self, state): def _restore_connections_from_strings( nodes: dict[str, Node] | DotDict[str, Node], connections: list[tuple[tuple[str, str], tuple[str, str]]], - input_panel_getter: callable, - output_panel_getter: callable, + input_panel_getter: Callable, + output_panel_getter: Callable, ) -> None: """ Set connections among a dictionary of nodes. diff --git a/pyiron_workflow/nodes/for_loop.py b/pyiron_workflow/nodes/for_loop.py index af23faba..bc408821 100644 --- a/pyiron_workflow/nodes/for_loop.py +++ b/pyiron_workflow/nodes/for_loop.py @@ -154,6 +154,7 @@ class For(Composite, StaticNode, ABC): _iter_on: ClassVar[tuple[str, ...]] = () _zip_on: ClassVar[tuple[str, ...]] = () _output_as_dataframe: ClassVar[bool] = True + _output_column_map: ClassVar[dict[str, str]] = {} def __init_subclass__(cls, output_column_map=None, **kwargs): super().__init_subclass__(**kwargs) @@ -182,18 +183,16 @@ def __init_subclass__(cls, output_column_map=None, **kwargs): f"{cls._body_node_class.__name__} has no such outputs." ) - cls._output_column_map = output_column_map + cls._output_column_map = {} if output_column_map is None else output_column_map @classmethod - @property @lru_cache(maxsize=1) def output_column_map(cls) -> dict[str, str]: """ How to transform body node output labels to dataframe column names. """ map_ = {k: k for k in cls._body_node_class.preview_outputs()} - overrides = {} if cls._output_column_map is None else cls._output_column_map - for body_label, column_name in overrides.items(): + for body_label, column_name in cls._output_column_map.items(): map_[body_label] = column_name return map_ @@ -311,7 +310,7 @@ def _collect_output_as_dataframe(self, iter_maps): row_collector.inputs[label] = self.children[label][i] for label, body_out in self[self._body_name(n)].outputs.items(): - row_collector.inputs[self.output_column_map[label]] = body_out + row_collector.inputs[self.output_column_map()[label]] = body_out self.dataframe.inputs[f"row_{n}"] = row_collector @@ -324,7 +323,7 @@ def _build_row_collector_node(self, row_number) -> InputsToDict: # Outputs row_specification.update( { - self.output_column_map[key]: (hint, NOT_DATA) + self.output_column_map()[key]: (hint, NOT_DATA) for key, hint in self._body_node_class.preview_outputs().items() } ) @@ -339,7 +338,7 @@ def column_collector_name(s: str): return f"column_collector_{s}" for label, hint in self._body_node_class.preview_outputs().items(): - mapped_label = self.output_column_map[label] + mapped_label = self.output_column_map()[label] column_collector = inputs_to_list( n_rows, label=column_collector_name(mapped_label), @@ -377,7 +376,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: for label, (hint, default) in cls._body_node_class.preview_inputs().items(): # TODO: Leverage hint and default, listing if it's looped on if label in cls._zip_on + cls._iter_on: - hint = list if hint is None else list[hint] + hint = list if hint is None else list[hint] # type: ignore[valid-type] default = NOT_DATA # TODO: Figure out a generator pattern to get lists preview[label] = (hint, default) return preview @@ -393,11 +392,11 @@ def _build_outputs_preview(cls) -> dict[str, Any]: _default, ) in cls._body_node_class.preview_inputs().items(): if label in cls._zip_on + cls._iter_on: - hint = list if hint is None else list[hint] + hint = list if hint is None else list[hint] # type: ignore[valid-type] preview[label] = hint for label, hint in cls._body_node_class.preview_outputs().items(): - preview[cls.output_column_map[label]] = ( - list if hint is None else list[hint] + preview[cls.output_column_map()[label]] = ( + list if hint is None else list[hint] # type: ignore[valid-type] ) return preview @@ -510,7 +509,7 @@ def for_node_factory( output_column_map: dict | None = None, use_cache: bool = True, /, -): +) -> type[For]: combined_docstring = ( "For node docstring:\n" + (For.__doc__ if For.__doc__ is not None else "") @@ -521,7 +520,7 @@ def for_node_factory( iter_on = (iter_on,) if isinstance(iter_on, str) else iter_on zip_on = (zip_on,) if isinstance(zip_on, str) else zip_on - return ( + return ( # type: ignore[return-value] _for_node_class_name(body_node_class, iter_on, zip_on, output_as_dataframe), (For,), { @@ -648,6 +647,8 @@ def for_node( Index(['a', 'b', 'c', 'd', 'out_a', 'out_b', 'out_c', 'out_d', 'e'], dtype='object') """ + iter_on = (iter_on,) if isinstance(iter_on, str) else iter_on + zip_on = (zip_on,) if isinstance(zip_on, str) else zip_on for_node_factory.clear( _for_node_class_name(body_node_class, iter_on, zip_on, output_as_dataframe) ) diff --git a/pyiron_workflow/nodes/function.py b/pyiron_workflow/nodes/function.py index cd3d9f31..877f3b28 100644 --- a/pyiron_workflow/nodes/function.py +++ b/pyiron_workflow/nodes/function.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import getsource from typing import Any @@ -300,11 +301,11 @@ class Function(StaticNode, ScrapesIO, ABC): @staticmethod @abstractmethod - def node_function(**kwargs) -> callable: + def node_function(**kwargs) -> Callable: """What the node _does_.""" @classmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: return cls.node_function @classmethod @@ -351,12 +352,12 @@ def _extra_info(cls) -> str: @classfactory def function_node_factory( - node_function: callable, + node_function: Callable, validate_output_labels: bool, use_cache: bool = True, /, *output_labels, -): +) -> type[Function]: """ Create a new :class:`Function` node class based on the given node function. This function gets executed on each :meth:`run` of the resulting function. @@ -372,7 +373,7 @@ def function_node_factory( Returns: type[Node]: A new node class. """ - return ( + return ( # type: ignore[return-value] node_function.__name__, (Function,), # Define parentage { @@ -429,7 +430,7 @@ def decorator(node_function): def function_node( - node_function: callable, + node_function: Callable, *node_args, output_labels: str | tuple[str, ...] | None = None, validate_output_labels: bool = True, diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index 566605a5..ec09f8e8 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -7,13 +7,15 @@ import re from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import getsource from typing import TYPE_CHECKING from pyiron_snippets.factory import classfactory -from pyiron_workflow.io import Inputs, Outputs +from pyiron_workflow.io import Inputs from pyiron_workflow.mixin.has_interface_mixins import HasChannel +from pyiron_workflow.mixin.injection import OutputsWithInjection from pyiron_workflow.mixin.preview import ScrapesIO from pyiron_workflow.nodes.composite import Composite from pyiron_workflow.nodes.multiple_distpatch import dispatch_output_labels @@ -217,7 +219,7 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC): >>> # With the replace method >>> # (replacement target can be specified by label or instance, >>> # the replacing node can be specified by instance or class) - >>> replaced = adds_six_macro.replace_child(adds_six_macro.one, add_two()) + >>> replaced, _ = adds_six_macro.replace_child(adds_six_macro.one, add_two()) >>> # With the replace_with method >>> adds_six_macro.two.replace_with(add_two()) >>> # And by assignment of a compatible class to an occupied node label @@ -271,11 +273,13 @@ def _setup_node(self) -> None: @staticmethod @abstractmethod - def graph_creator(self, *args, **kwargs) -> callable: + def graph_creator( + self: Macro, *args, **kwargs + ) -> HasChannel | tuple[HasChannel, ...] | None: """Build the graph the node will run.""" @classmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: return cls.graph_creator _io_defining_function_uses_self = True @@ -341,7 +345,7 @@ def inputs(self) -> Inputs: return self._inputs @property - def outputs(self) -> Outputs: + def outputs(self) -> OutputsWithInjection: return self._outputs def _parse_remotely_executed_self(self, other_self): @@ -466,12 +470,12 @@ def _extra_info(cls) -> str: @classfactory def macro_node_factory( - graph_creator: callable, + graph_creator: Callable, validate_output_labels: bool, use_cache: bool = True, /, *output_labels: str, -): +) -> type[Macro]: """ Create a new :class:`Macro` subclass using the given graph creator function. @@ -487,7 +491,7 @@ def macro_node_factory( Returns: type[Macro]: A new :class:`Macro` subclass. """ - return ( + return ( # type: ignore[return-value] graph_creator.__name__, (Macro,), # Define parentage { @@ -536,7 +540,7 @@ def decorator(graph_creator): def macro_node( - graph_creator: callable, + graph_creator: Callable[..., HasChannel | tuple[HasChannel, ...] | None], *node_args, output_labels: str | tuple[str, ...] | None = None, validate_output_labels: bool = True, diff --git a/pyiron_workflow/nodes/standard.py b/pyiron_workflow/nodes/standard.py index 8c119944..6635a4c9 100644 --- a/pyiron_workflow/nodes/standard.py +++ b/pyiron_workflow/nodes/standard.py @@ -7,6 +7,7 @@ import os import random import shutil +from collections.abc import Callable from pathlib import Path from time import sleep @@ -62,7 +63,7 @@ def node_function(condition): return truth @property - def emitting_channels(self) -> tuple[OutputSignal]: + def emitting_channels(self) -> tuple[OutputSignal, ...]: if self.outputs.truth.value is NOT_DATA: return super().emitting_channels elif self.outputs.truth.value: @@ -167,7 +168,7 @@ def ChangeDirectory( @as_function_node -def PureCall(fnc: callable): +def PureCall(fnc: Callable): """ Return a call without any arguments diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 97befbbb..6a4371be 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -6,6 +6,7 @@ import itertools from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import MISSING from dataclasses import dataclass as as_dataclass from typing import Any, ClassVar @@ -14,8 +15,7 @@ from pyiron_snippets.colors import SeabornColors from pyiron_snippets.factory import classfactory -from pyiron_workflow.channels import NOT_DATA -from pyiron_workflow.mixin.preview import builds_class_io +from pyiron_workflow.channels import NOT_DATA, NotData from pyiron_workflow.nodes.static_io import StaticNode @@ -39,10 +39,6 @@ class FromManyInputs(Transformer, ABC): # Inputs convert to `run_args` as a value dictionary # This must be commensurate with the internal expectations of _on_run - @abstractmethod - def _on_run(self, **inputs_to_value_dict) -> Any: - """Must take inputs kwargs""" - @property def _run_args(self) -> tuple[tuple, dict]: return (), self.inputs.to_value_dict() @@ -59,13 +55,13 @@ def process_run_result(self, run_output: Any | tuple) -> Any | tuple: class ToManyOutputs(Transformer, ABC): _input_name: ClassVar[str] # Mandatory attribute for non-abstract subclasses _input_type_hint: ClassVar[Any] = None - _input_default: ClassVar[Any | NOT_DATA] = NOT_DATA + _input_default: ClassVar[Any | NotData] = NOT_DATA # _build_outputs_preview still required from parent class # Must be commensurate with the dictionary returned by transform_to_output @abstractmethod - def _on_run(self, input_object) -> callable[..., Any | tuple]: + def _on_run(self, input_object) -> Callable[..., Any | tuple]: """Must take the single object to be transformed""" @property @@ -110,10 +106,9 @@ def _build_outputs_preview(cls) -> dict[str, Any]: return {f"item_{i}": None for i in range(cls._length)} -@builds_class_io @classfactory def inputs_to_list_factory(n: int, use_cache: bool = True, /) -> type[InputsToList]: - return ( + return ( # type: ignore[return-value] f"{InputsToList.__name__}{n}", (InputsToList,), { @@ -140,13 +135,14 @@ def inputs_to_list(n: int, /, *node_args, use_cache: bool = True, **node_kwargs) InputsToList: An instance of the dynamically created :class:`InputsToList` subclass. """ - return inputs_to_list_factory(n, use_cache)(*node_args, **node_kwargs) + cls = inputs_to_list_factory(n, use_cache) + cls.preview_io() + return cls(*node_args, **node_kwargs) -@builds_class_io @classfactory def list_to_outputs_factory(n: int, use_cache: bool = True, /) -> type[ListToOutputs]: - return ( + return ( # type: ignore[return-value] f"{ListToOutputs.__name__}{n}", (ListToOutputs,), { @@ -175,21 +171,24 @@ def list_to_outputs( ListToOutputs: An instance of the dynamically created :class:`ListToOutputs` subclass. """ - return list_to_outputs_factory(n, use_cache)(*node_args, **node_kwargs) + + cls = list_to_outputs_factory(n, use_cache) + cls.preview_io() + return cls(*node_args, **node_kwargs) class InputsToDict(FromManyInputs, ABC): _output_name: ClassVar[str] = "dict" _output_type_hint: ClassVar[Any] = dict _input_specification: ClassVar[ - list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]] + list[str] | dict[str, tuple[Any | None, Any | NotData]] ] def _on_run(self, **inputs_to_value_dict): return inputs_to_value_dict @classmethod - def _build_inputs_preview(cls) -> dict[str, tuple[Any | None, Any | NOT_DATA]]: + def _build_inputs_preview(cls) -> dict[str, tuple[Any | None, Any | NotData]]: if isinstance(cls._input_specification, list): return {key: (None, NOT_DATA) for key in cls._input_specification} else: @@ -197,7 +196,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any | None, Any | NOT_DATA]]: @staticmethod def hash_specification( - input_specification: list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]], + input_specification: list[str] | dict[str, tuple[Any | None, Any | NotData]], ): """For generating unique subclass names.""" @@ -223,7 +222,7 @@ def hash_specification( @classfactory def inputs_to_dict_factory( - input_specification: list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]], + input_specification: list[str] | dict[str, tuple[Any | None, Any | NotData]], class_name_suffix: str | None, use_cache: bool = True, /, @@ -232,7 +231,7 @@ def inputs_to_dict_factory( class_name_suffix = str( InputsToDict.hash_specification(input_specification) ).replace("-", "m") - return ( + return ( # type: ignore[return-value] f"{InputsToDict.__name__}{class_name_suffix}", (InputsToDict,), { @@ -244,7 +243,7 @@ def inputs_to_dict_factory( def inputs_to_dict( - input_specification: list[str] | dict[str, tuple[Any | None, Any | NOT_DATA]], + input_specification: list[str] | dict[str, tuple[Any | None, Any | NotData]], *node_args, class_name_suffix: str | None = None, use_cache: bool = True, @@ -308,7 +307,7 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: def inputs_to_dataframe_factory( n: int, use_cache: bool = True, / ) -> type[InputsToDataframe]: - return ( + return ( # type: ignore[return-value] f"{InputsToDataframe.__name__}{n}", (InputsToDataframe,), { @@ -350,7 +349,6 @@ class DataclassNode(FromManyInputs, ABC): _output_name: ClassVar[str] = "dataclass" @classmethod - @property def _dataclass_fields(cls): return cls.dataclass.__dataclass_fields__ @@ -360,9 +358,9 @@ def _setup_node(self) -> None: for name, channel in self.inputs.items(): if ( channel.value is NOT_DATA - and self._dataclass_fields[name].default_factory is not MISSING + and self._dataclass_fields()[name].default_factory is not MISSING ): - self.inputs[name] = self._dataclass_fields[name].default_factory() + self.inputs[name] = self._dataclass_fields()[name].default_factory() def _on_run(self, **inputs_to_value_dict): return self.dataclass(**inputs_to_value_dict) @@ -376,12 +374,12 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: # Make a channel for each field return { name: (f.type, NOT_DATA if f.default is MISSING else f.default) - for name, f in cls._dataclass_fields.items() + for name, f in cls._dataclass_fields().items() } @classmethod def _extra_info(cls) -> str: - return cls.dataclass.__doc__ + return cls.dataclass.__doc__ or "" @classfactory @@ -405,7 +403,7 @@ def dataclass_node_factory( # Composition is preferable over inheritance, but we want inheritance to be possible module, qualname = dataclass.__module__, dataclass.__qualname__ dataclass.__qualname__ += ".dataclass" # So output type hints know where to find it - return ( + return ( # type: ignore[return-value] dataclass.__name__, (DataclassNode,), { diff --git a/pyiron_workflow/storage.py b/pyiron_workflow/storage.py index 679f8151..dcdb626c 100644 --- a/pyiron_workflow/storage.py +++ b/pyiron_workflow/storage.py @@ -36,7 +36,7 @@ class StorageInterface(ABC): """ @abstractmethod - def _save(self, node: Node, filename: Path, /, **kwargs): + def _save(self, node: Node, filename: Path, /, *args, **kwargs): """ Save a node to file. @@ -48,7 +48,7 @@ def _save(self, node: Node, filename: Path, /, **kwargs): """ @abstractmethod - def _load(self, filename: Path, /, **kwargs) -> Node: + def _load(self, filename: Path, /, *args, **kwargs) -> Node: """ Instantiate a node from file. @@ -61,7 +61,7 @@ def _load(self, filename: Path, /, **kwargs) -> Node: """ @abstractmethod - def _has_saved_content(self, filename: Path, /, **kwargs) -> bool: + def _has_saved_content(self, filename: Path, /, *args, **kwargs) -> bool: """ Check for a save file matching this storage interface. @@ -74,7 +74,7 @@ def _has_saved_content(self, filename: Path, /, **kwargs) -> bool: """ @abstractmethod - def _delete(self, filename: Path, /, **kwargs): + def _delete(self, filename: Path, /, *args, **kwargs): """ Remove an existing save-file for this backend. @@ -132,7 +132,7 @@ def has_saved_content( node: Node | None = None, filename: str | Path | None = None, **kwargs, - ): + ) -> bool: """ Check if a file has contents related to a node. @@ -168,7 +168,9 @@ def delete( if filename.parent.exists() and not any(filename.parent.iterdir()): filename.parent.rmdir() - def _parse_filename(self, node: Node | None, filename: str | Path | None = None): + def _parse_filename( + self, node: Node | None, filename: str | Path | None = None + ) -> Path: """ Make sure the node xor filename was provided, and if it's the node, convert it into a canonical filename by exploiting the node's semantic path. @@ -195,6 +197,11 @@ def _parse_filename(self, node: Node | None, filename: str | Path | None = None) f"Both the node ({node.full_label}) and filename ({filename}) were " f"specified for loading -- please only specify one or the other." ) + else: + raise AssertionError( + "This is an unreachable state -- we have covered all four cases of the " + "boolean `is (not) None` square." + ) class PickleStorage(StorageInterface): @@ -204,11 +211,11 @@ class PickleStorage(StorageInterface): def __init__(self, cloudpickle_fallback: bool = True): self.cloudpickle_fallback = cloudpickle_fallback - def _fallback(self, cpf: bool | None): + def _fallback(self, cpf: bool | None) -> bool: return self.cloudpickle_fallback if cpf is None else cpf def _save( - self, node: Node, filename: Path, cloudpickle_fallback: bool | None = None + self, node: Node, filename: Path, /, cloudpickle_fallback: bool | None = None ): if not self._fallback(cloudpickle_fallback) and not node.import_ready: raise TypeNotFoundError( @@ -236,19 +243,22 @@ def _save( if e is not None: raise e - def _load(self, filename: Path, cloudpickle_fallback: bool | None = None) -> Node: + def _load( + self, filename: Path, /, cloudpickle_fallback: bool | None = None + ) -> Node: attacks = [(self._PICKLE, pickle.load)] if self._fallback(cloudpickle_fallback): attacks += [(self._CLOUDPICKLE, cloudpickle.load)] for suffix, load_method in attacks: p = filename.with_suffix(suffix) - if p.exists(): + if p.is_file(): with open(p, "rb") as filehandle: inst = load_method(filehandle) return inst + raise FileNotFoundError(f"Could not load {filename}, no such file found.") - def _delete(self, filename: Path, cloudpickle_fallback: bool | None = None): + def _delete(self, filename: Path, /, cloudpickle_fallback: bool | None = None): suffixes = ( [self._PICKLE, self._CLOUDPICKLE] if self._fallback(cloudpickle_fallback) @@ -258,7 +268,7 @@ def _delete(self, filename: Path, cloudpickle_fallback: bool | None = None): filename.with_suffix(suffix).unlink(missing_ok=True) def _has_saved_content( - self, filename: Path, cloudpickle_fallback: bool | None = None + self, filename: Path, /, cloudpickle_fallback: bool | None = None ) -> bool: suffixes = ( [self._PICKLE, self._CLOUDPICKLE] diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py index e96eeb76..650da2b6 100644 --- a/pyiron_workflow/topology.py +++ b/pyiron_workflow/topology.py @@ -6,12 +6,13 @@ from __future__ import annotations +from collections.abc import Callable from typing import TYPE_CHECKING from toposort import CircularDependencyError, toposort, toposort_flatten if TYPE_CHECKING: - from pyiron_workflow.channels import SignalChannel + from pyiron_workflow.channels import InputSignal, OutputSignal from pyiron_workflow.node import Node @@ -74,8 +75,8 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: ) locally_scoped_dependencies.append(upstream.owner.label) node_dependencies.extend(locally_scoped_dependencies) - node_dependencies = set(node_dependencies) - if node.label in node_dependencies: + node_dependencies_set = set(node_dependencies) + if node.label in node_dependencies_set: # the toposort library has a # [known issue](https://gitlab.com/ericvsmith/toposort/-/issues/3) # That self-dependency isn't caught, so we catch it manually here. @@ -84,14 +85,14 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: f"the execution of non-DAGs: {node.full_label} appears in its own " f"input." ) - digraph[node.label] = node_dependencies + digraph[node.label] = node_dependencies_set return digraph def _set_new_run_connections_with_fallback_recovery( - connection_creator: callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] -): + connection_creator: Callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a function that takes a dictionary of unconnected nodes, connects their execution graph, and returns the new starting nodes, this wrapper makes sure that @@ -143,7 +144,7 @@ def _set_run_connections_according_to_linear_dag(nodes: dict[str, Node]) -> list def set_run_connections_according_to_linear_dag( nodes: dict[str, Node], -) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a set of nodes that all have the same parent, have no upstream data connections outside the nodes provided, and have acyclic data flow, disconnects all @@ -195,7 +196,7 @@ def _set_run_connections_according_to_dag(nodes: dict[str, Node]) -> list[Node]: def set_run_connections_according_to_dag( nodes: dict[str, Node], -) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a set of nodes that all have the same parent, have no upstream data connections outside the nodes provided, and have acyclic data flow, disconnects all diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index f8619df0..567ac7fa 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -28,10 +28,9 @@ def valid_value(value, type_hint) -> bool: def type_hint_to_tuple(type_hint) -> tuple: - if isinstance(type_hint, types.UnionType | typing._UnionGenericAlias): + if isinstance(type_hint, types.UnionType): return typing.get_args(type_hint) - else: - return (type_hint,) + return (type_hint,) def type_hint_is_as_or_more_specific_than(hint, other) -> bool: diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 791e17c8..aa39caff 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -10,8 +10,8 @@ from bidict import bidict -from pyiron_workflow.io import Inputs, Outputs -from pyiron_workflow.mixin.semantics import ParentMost +from pyiron_workflow.io import Inputs +from pyiron_workflow.mixin.injection import OutputsWithInjection from pyiron_workflow.nodes.composite import Composite if TYPE_CHECKING: @@ -20,7 +20,19 @@ from pyiron_workflow.storage import StorageInterface -class Workflow(ParentMost, Composite): +class ParentMostError(TypeError): + """ + To be raised when assigning a parent to a parent-most object + """ + + +class NoArgsError(TypeError): + """ + To be raised when *args can't be processed but are received + """ + + +class Workflow(Composite): """ Workflows are a dynamic composite node -- i.e. they hold and run a collection of nodes (a subgraph) which can be dynamically modified (adding and removing nodes, @@ -213,13 +225,11 @@ def __init__( automate_execution: bool = True, **kwargs, ): - self._inputs_map = None - self._outputs_map = None - self.inputs_map = inputs_map - self.outputs_map = outputs_map + self._inputs_map = self._sanitize_map(inputs_map) + self._outputs_map = self._sanitize_map(outputs_map) self._inputs = None self._outputs = None - self.automate_execution = automate_execution + self.automate_execution: bool = automate_execution super().__init__( *nodes, @@ -252,34 +262,36 @@ def _after_node_setup( @property def inputs_map(self) -> bidict | None: - self._deduplicate_nones(self._inputs_map) + if self._inputs_map is not None: + self._deduplicate_nones(self._inputs_map) return self._inputs_map @inputs_map.setter def inputs_map(self, new_map: dict | bidict | None): - self._deduplicate_nones(new_map) - if new_map is not None: - new_map = bidict(new_map) - self._inputs_map = new_map + self._inputs_map = self._sanitize_map(new_map) @property def outputs_map(self) -> bidict | None: - self._deduplicate_nones(self._outputs_map) + if self._outputs_map is not None: + self._deduplicate_nones(self._outputs_map) return self._outputs_map @outputs_map.setter def outputs_map(self, new_map: dict | bidict | None): - self._deduplicate_nones(new_map) + self._outputs_map = self._sanitize_map(new_map) + + def _sanitize_map(self, new_map: dict | bidict | None) -> bidict | None: if new_map is not None: + if isinstance(new_map, dict): + self._deduplicate_nones(new_map) new_map = bidict(new_map) - self._outputs_map = new_map + return new_map @staticmethod - def _deduplicate_nones(some_map: dict | bidict | None) -> dict | bidict | None: - if some_map is not None: - for k, v in some_map.items(): - if v is None: - some_map[k] = (None, f"{k} disabled") + def _deduplicate_nones(some_map: dict | bidict): + for k, v in some_map.items(): + if v is None: + some_map[k] = (None, f"{k} disabled") @property def inputs(self) -> Inputs: @@ -289,7 +301,7 @@ def _build_inputs(self): return self._build_io("inputs", self.inputs_map) @property - def outputs(self) -> Outputs: + def outputs(self) -> OutputsWithInjection: return self._build_outputs() def _build_outputs(self): @@ -299,7 +311,7 @@ def _build_io( self, i_or_o: Literal["inputs", "outputs"], key_map: dict[str, str | None] | None, - ) -> Inputs | Outputs: + ) -> Inputs | OutputsWithInjection: """ Build an IO panel for exposing child node IO to the outside world at the level of the composite node's IO. @@ -315,17 +327,18 @@ def _build_io( (which normally would be exposed) by providing a string-None map. Returns: - (Inputs|Outputs): The populated panel. + (Inputs|OutputsWithInjection): The populated panel. """ key_map = {} if key_map is None else key_map - io = Inputs() if i_or_o == "inputs" else Outputs() + io = Inputs() if i_or_o == "inputs" else OutputsWithInjection() for node in self.children.values(): panel = getattr(node, i_or_o) for channel in panel: try: io_panel_key = key_map[channel.scoped_label] - if not isinstance(io_panel_key, tuple): - # Tuples indicate that the channel has been deactivated + if isinstance(io_panel_key, str): + # Otherwise it's a None-str tuple, indicaticating that the + # channel has been deactivated # This is a necessary misdirection to keep the bidict working, # as we can't simply map _multiple_ keys to `None` io[io_panel_key] = channel @@ -355,12 +368,18 @@ def _before_run( def run( self, + *args, check_readiness: bool = True, **kwargs, ): # Note: Workflows may have neither parents nor siblings, so we don't need to # worry about running their data trees first, fetching their input, nor firing # their `ran` signal, hence the change in signature from Node.run + if len(args) > 0: + raise NoArgsError( + f"{self.__class__} does not know how to process *args on run, but " + f"received {args}" + ) return super().run( run_data_tree=False, @@ -478,8 +497,10 @@ def _owned_io_panels(self) -> list[IO]: def replace_child( self, owned_node: Node | str, replacement: Node | type[Node] - ) -> Node: - super().replace_child(owned_node=owned_node, replacement=replacement) + ) -> tuple[Node, Node]: + replaced, replacement_node = super().replace_child( + owned_node=owned_node, replacement=replacement + ) # Finally, make sure the IO is constructible with this new node, which will # catch things like incompatible IO maps @@ -490,8 +511,20 @@ def replace_child( except Exception as e: # If IO can't be successfully rebuilt using this node, revert changes and # raise the exception - self.replace_child(replacement, owned_node) # Guaranteed to work since + self.replace_child(replacement_node, replaced) # Guaranteed to work since # replacement in the other direction was already a success raise e - return owned_node + return replaced, replacement_node + + @property + def parent(self) -> None: + return None + + @parent.setter + def parent(self, new_parent: None): + if new_parent is not None: + raise ParentMostError( + f"{self.label} is a {self.__class__} and may only take None as a " + f"parent but got {type(new_parent)}" + ) diff --git a/tests/integration/test_workflow.py b/tests/integration/test_workflow.py index f1686be1..09425401 100644 --- a/tests/integration/test_workflow.py +++ b/tests/integration/test_workflow.py @@ -68,7 +68,7 @@ def node_function(value, limit=10): return value_gt_limit @property - def emitting_channels(self) -> tuple[OutputSignal]: + def emitting_channels(self) -> tuple[OutputSignal, ...]: if self.outputs.value_gt_limit.value: print(f"{self.inputs.value.value} > {self.inputs.limit.value}") return (*super().emitting_channels, self.signals.output.true) diff --git a/tests/unit/mixin/test_run.py b/tests/unit/mixin/test_run.py index 009e03b5..bfa32ece 100644 --- a/tests/unit/mixin/test_run.py +++ b/tests/unit/mixin/test_run.py @@ -8,9 +8,7 @@ class ConcreteRunnable(Runnable): - @property - def label(self) -> str: - return "child_class_with_all_methods_implemented" + _label = "child_class_with_all_methods_implemented" def on_run(self, **kwargs): return kwargs diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index dd40c5f0..fbb9bac4 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -1,21 +1,40 @@ +from __future__ import annotations + import unittest from pathlib import Path from pyiron_workflow.mixin.semantics import ( CyclicPathError, - ParentMost, Semantic, SemanticParent, ) +class ConcreteSemantic(Semantic["ConcreteParent"]): + @classmethod + def parent_type(cls) -> type[ConcreteSemanticParent]: + return ConcreteSemanticParent + + +class ConcreteParent(SemanticParent[ConcreteSemantic]): + _label = "concrete_parent_default_label" + + @classmethod + def child_type(cls) -> type[ConcreteSemantic]: + return ConcreteSemantic + + +class ConcreteSemanticParent(ConcreteParent, ConcreteSemantic): + pass + + class TestSemantics(unittest.TestCase): def setUp(self): - self.root = ParentMost("root") - self.child1 = Semantic("child1", parent=self.root) - self.middle1 = SemanticParent("middle", parent=self.root) - self.middle2 = SemanticParent("middle_sub", parent=self.middle1) - self.child2 = Semantic("child2", parent=self.middle2) + self.root = ConcreteSemanticParent(label="root") + self.child1 = ConcreteSemantic(label="child1", parent=self.root) + self.middle1 = ConcreteSemanticParent(label="middle", parent=self.root) + self.middle2 = ConcreteSemanticParent(label="middle_sub", parent=self.middle1) + self.child2 = ConcreteSemantic(label="child2", parent=self.middle2) def test_getattr(self): with self.assertRaises(AttributeError) as context: @@ -35,18 +54,26 @@ def test_getattr(self): def test_label_validity(self): with self.assertRaises(TypeError, msg="Label must be a string"): - Semantic(label=123) + ConcreteSemantic(label=123) def test_label_delimiter(self): with self.assertRaises( - ValueError, msg=f"Delimiter '{Semantic.semantic_delimiter}' not allowed" + ValueError, + msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", + ): + ConcreteSemantic(label=f"invalid{ConcreteSemantic.semantic_delimiter}label") + + non_semantic_parent = ConcreteParent() + with self.assertRaises( + ValueError, + msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", ): - Semantic(f"invalid{Semantic.semantic_delimiter}label") + non_semantic_parent.label = f"contains_{non_semantic_parent.child_type().semantic_delimiter}_delimiter" def test_semantic_delimiter(self): self.assertEqual( "/", - Semantic.semantic_delimiter, + ConcreteSemantic.semantic_delimiter, msg="This is just a hard-code to the current value, update it freely so " "the test passes; if it fails it's just a reminder that your change is " "not backwards compatible, and the next release number should reflect " @@ -58,18 +85,6 @@ def test_parent(self): self.assertEqual(self.child1.parent, self.root) self.assertEqual(self.root.parent, None) - with self.subTest(f"{ParentMost.__name__} exceptions"): - with self.assertRaises( - TypeError, msg=f"{ParentMost.__name__} instances can't have parent" - ): - self.root.parent = SemanticParent(label="foo") - - with self.assertRaises( - TypeError, msg=f"{ParentMost.__name__} instances can't be children" - ): - some_parent = SemanticParent(label="bar") - some_parent.add_child(self.root) - with self.subTest("Cyclicity exceptions"): with self.assertRaises(CyclicPathError): self.middle1.parent = self.middle2 @@ -112,7 +127,7 @@ def test_as_path(self): ) def test_detached_parent_path(self): - orphan = Semantic("orphan") + orphan = ConcreteSemantic(label="orphan") orphan.__setstate__(self.child2.__getstate__()) self.assertIsNone( orphan.parent, msg="We still should not explicitly have a parent" diff --git a/tests/unit/nodes/test_composite.py b/tests/unit/nodes/test_composite.py index af4ed8ee..c4487924 100644 --- a/tests/unit/nodes/test_composite.py +++ b/tests/unit/nodes/test_composite.py @@ -133,29 +133,18 @@ def test_node_removal(self): # Connect it inside the composite self.comp.foo.inputs.x = self.comp.owned.outputs.y - disconnected = self.comp.remove_child(node) + self.comp.remove_child(node) self.assertIsNone(node.parent, msg="Removal should de-parent") self.assertFalse(node.connected, msg="Removal should disconnect") - self.assertListEqual( - [(node.inputs.x, self.comp.owned.outputs.y)], - disconnected, - msg="Removal should return destroyed connections", - ) self.assertListEqual( self.comp.starting_nodes, [], msg="Removal should also remove from starting nodes", ) - - node_owned = self.comp.owned - disconnections = self.comp.remove_child(node_owned.label) - self.assertEqual( - node_owned.parent, - None, - msg="Should be able to remove nodes by label as well as by object", - ) self.assertListEqual( - [], disconnections, msg="node1 should have no connections left" + [], + self.comp.owned.connections, + msg="Remaining node should have no connections left", ) def test_label_uniqueness(self): diff --git a/tests/unit/nodes/test_transform.py b/tests/unit/nodes/test_transform.py index 569fc799..e9bef9d6 100644 --- a/tests/unit/nodes/test_transform.py +++ b/tests/unit/nodes/test_transform.py @@ -230,7 +230,7 @@ class DecoratedDCLike: prev = n_cls.preview_inputs() key = random.choice(list(prev.keys())) self.assertIs( - n_cls._dataclass_fields[key].type, + n_cls._dataclass_fields()[key].type, prev[key][0], msg="Spot-check input type hints are pulled from dataclass fields", ) @@ -238,7 +238,7 @@ class DecoratedDCLike: prev["necessary"][1], NOT_DATA, msg="Field has no default" ) self.assertEqual( - n_cls._dataclass_fields["with_default"].default, + n_cls._dataclass_fields()["with_default"].default, prev["with_default"][1], msg="Fields with default should get scraped", ) diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index eaeb4a85..bb6c4690 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from pyiron_workflow.channels import ( @@ -6,6 +8,7 @@ BadCallbackError, Channel, ChannelConnectionError, + ConjugateType, InputData, InputSignal, OutputData, @@ -30,25 +33,25 @@ def data_input_locked(self): return self.locked -class InputChannel(Channel): +class DummyChannel(Channel[ConjugateType]): """Just to de-abstract the base class""" def __str__(self): return "non-abstract input" - @property - def connection_partner_type(self) -> type[Channel]: - return OutputChannel + def _valid_connection(self, other: object) -> bool: + return isinstance(other, self.connection_conjugate()) -class OutputChannel(Channel): - """Just to de-abstract the base class""" +class InputChannel(DummyChannel["OutputChannel"]): + @classmethod + def connection_conjugate(cls) -> type[OutputChannel]: + return OutputChannel - def __str__(self): - return "non-abstract output" - @property - def connection_partner_type(self) -> type[Channel]: +class OutputChannel(DummyChannel["InputChannel"]): + @classmethod + def connection_conjugate(cls) -> type[InputChannel]: return InputChannel @@ -389,26 +392,44 @@ def test_aggregating_call(self): owner = DummyOwner() agg = AccumulatingInputSignal(label="agg", owner=owner, callback=owner.update) - with self.assertRaises( - TypeError, - msg="For an aggregating input signal, it _matters_ who called it, so " - "receiving an output signal is not optional", - ): - agg() - out2 = OutputSignal(label="out2", owner=DummyOwner()) agg.connect(self.out, out2) + out_unrelated = OutputSignal(label="out_unrelated", owner=DummyOwner()) + + signals_sent = 0 self.assertEqual( 2, len(agg.connections), msg="Sanity check on initial conditions" ) self.assertEqual( - 0, len(agg.received_signals), msg="Sanity check on initial conditions" + signals_sent, + len(agg.received_signals), + msg="Sanity check on initial conditions", ) self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions") + agg() + signals_sent += 0 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection", + ) + agg(out_unrelated) + signals_sent += 1 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection", + ) + self.out() - self.assertEqual(1, len(agg.received_signals), msg="Signal should be received") + signals_sent += 1 + self.assertEqual( + signals_sent, + len(agg.received_signals), + msg="Signals from other channels should be received", + ) self.assertListEqual( [0], owner.foo, @@ -416,8 +437,9 @@ def test_aggregating_call(self): ) self.out() + signals_sent += 0 self.assertEqual( - 1, + signals_sent, len(agg.received_signals), msg="Repeatedly receiving the same signal should have no effect", ) diff --git a/tests/unit/test_io.py b/tests/unit/test_io.py index 00586444..3f9cb193 100644 --- a/tests/unit/test_io.py +++ b/tests/unit/test_io.py @@ -17,7 +17,7 @@ ) -class Dummy(HasIO): +class Dummy(HasIO[Outputs]): def __init__(self, label: str | None = "has_io"): super().__init__() self._label = label diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index 174013d3..bb7fd5c0 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -9,9 +9,8 @@ from pyiron_workflow._tests import ensure_tests_in_python_path from pyiron_workflow.channels import NOT_DATA -from pyiron_workflow.mixin.semantics import ParentMostError from pyiron_workflow.storage import TypeNotFoundError, available_backends -from pyiron_workflow.workflow import Workflow +from pyiron_workflow.workflow import NoArgsError, ParentMostError, Workflow ensure_tests_in_python_path() @@ -155,7 +154,7 @@ def test_io_map_bijectivity(self): self.assertEqual(3, len(wf.inputs_map), msg="All entries should be stored") self.assertEqual(0, len(wf.inputs), msg="No IO should be left exposed") - def test_is_parentmost(self): + def test_takes_no_parent(self): wf = Workflow("wf") wf2 = Workflow("wf2") @@ -259,6 +258,12 @@ def sum_(a, b): return a + b wf.sum = sum_(wf.a, wf.b) + with self.assertRaises( + NoArgsError, + msg="Workflows don't know what to do with raw args, since their input " + "has no intrinsic order", + ): + wf.run(1, 2) wf.run() self.assertEqual( wf.a.outputs.y.value + wf.b.outputs.y.value,