diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py
index 0e15f64f..e5890884 100644
--- a/pyiron_workflow/io.py
+++ b/pyiron_workflow/io.py
@@ -9,6 +9,7 @@
 import contextlib
 from abc import ABC, abstractmethod
+from collections.abc import ItemsView, Iterator
 from typing import Any, Generic, TypeVar
 
 from pyiron_snippets.dotdict import DotDict
@@ -59,7 +60,7 @@ class IO(HasStateDisplay, Generic[OwnedType, OwnedConjugate], ABC):
 
     channel_dict: DotDict[str, OwnedType]
 
-    def __init__(self, *channels: OwnedType):
+    def __init__(self, *channels: OwnedType) -> None:
         self.__dict__["channel_dict"] = DotDict(
             {
                 channel.label: channel
@@ -74,11 +75,11 @@ def _channel_class(self) -> type[OwnedType]:
         pass
 
     @abstractmethod
-    def _assign_a_non_channel_value(self, channel: OwnedType, value) -> None:
+    def _assign_a_non_channel_value(self, channel: OwnedType, value: Any) -> None:
         """What to do when some non-channel value gets assigned to a channel"""
         pass
 
-    def __getattr__(self, item) -> OwnedType:
+    def __getattr__(self, item: str) -> OwnedType:
         try:
             return self.channel_dict[item]
         except KeyError as key_error:
@@ -88,7 +89,7 @@ def __getattr__(self, item) -> OwnedType:
                 f"nor in its channels ({self.labels})"
             ) from key_error
 
-    def __setattr__(self, key, value):
+    def __setattr__(self, key: str, value: Any) -> None:
         if key in self.channel_dict:
             self._assign_value_to_existing_channel(self.channel_dict[key], value)
         elif isinstance(value, self._channel_class):
@@ -104,16 +105,16 @@ def __setattr__(self, key, value):
                 f"attribute {key} got assigned {value} of type {type(value)}"
             )
 
-    def _assign_value_to_existing_channel(self, channel: OwnedType, value) -> None:
+    def _assign_value_to_existing_channel(self, channel: OwnedType, value: Any) -> None:
         if isinstance(value, HasChannel):
             channel.connect(value.channel)
         else:
             self._assign_a_non_channel_value(channel, value)
 
-    def __getitem__(self, item) -> OwnedType:
+    def __getitem__(self, item: str) -> OwnedType:
         return self.__getattr__(item)
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Any) -> None:
         self.__setattr__(key, value)
 
     @property
@@ -124,11 +125,11 @@ def connections(self) -> list[OwnedConjugate]:
         )
 
     @property
-    def connected(self):
+    def connected(self) -> bool:
         return any(c.connected for c in self)
 
     @property
-    def fully_connected(self):
+    def fully_connected(self) -> bool:
         return all(c.connected for c in self)
 
     def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]:
@@ -145,34 +146,36 @@ def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]:
         return destroyed_connections
 
     @property
-    def labels(self):
+    def labels(self) -> list[str]:
         return list(self.channel_dict.keys())
 
-    def items(self):
+    def items(self) -> ItemsView[str, OwnedType]:
         return self.channel_dict.items()
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[OwnedType]:
         return self.channel_dict.values().__iter__()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.channel_dict)
 
     def __dir__(self):
-        return set(super().__dir__() + self.labels)
+        return list(set(super().__dir__() + self.labels))
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"{self.__class__.__name__} {self.labels}"
 
-    def __getstate__(self):
+    def __getstate__(self) -> dict[str, Any]:
         # Compatibility with python <3.11
         return dict(self.__dict__)
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: dict[str, Any]) -> None:
         # Because we override getattr, we need to use __dict__ assignment directly in
         # __setstate__ the same way we need it in __init__
         self.__dict__["channel_dict"] = state["channel_dict"]
 
-    def display_state(self, state=None, ignore_private=True):
+    def display_state(
+        self, state: dict[str, Any] | None = None, ignore_private: bool = True
+    ) -> dict[str, Any]:
         state = dict(self.__getstate__()) if state is None else state
         for k, v in state["channel_dict"].items():
             state[k] = v
@@ -192,15 +195,15 @@ class DataIO(IO[DataChannel, DataChannel], ABC):
     def _assign_a_non_channel_value(self, channel: DataChannel, value) -> None:
         channel.value = value
 
-    def to_value_dict(self):
+    def to_value_dict(self) -> dict[str, Any]:
         return {label: channel.value for label, channel in self.channel_dict.items()}
 
-    def to_list(self):
+    def to_list(self) -> list[Any]:
         """A list of channel values (order not guaranteed)"""
         return [channel.value for channel in self.channel_dict.values()]
 
     @property
-    def ready(self):
+    def ready(self) -> bool:
         return all(c.ready for c in self)
 
     def activate_strict_hints(self):
@@ -215,7 +218,7 @@ class Inputs(InputsIO, DataIO):
     def _channel_class(self) -> type[InputData]:
         return InputData
 
-    def fetch(self):
+    def fetch(self) -> None:
         for c in self:
             c.fetch()
 
@@ -237,7 +240,7 @@ def _channel_class(self) -> type[OutputData]:
 
 
 class SignalIO(IO[SignalChannel, SignalChannel], ABC):
-    def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None:
+    def _assign_a_non_channel_value(self, channel: SignalChannel, value: Any) -> None:
         raise TypeError(
             f"Tried to assign {value} ({type(value)} to the {channel.full_label}, "
             f"which is already a {type(channel)}. Only other signal channels may be "
@@ -275,9 +278,9 @@ class Signals(HasStateDisplay):
         output (OutputSignals): An empty input signals IO container.
     """
 
-    def __init__(self):
-        self.input = InputSignals()
-        self.output = OutputSignals()
+    def __init__(self) -> None:
+        self.input: InputSignals = InputSignals()
+        self.output: OutputSignals = OutputSignals()
 
     def disconnect(self) -> list[tuple[SignalChannel, SignalChannel]]:
         """
@@ -293,14 +296,14 @@ def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]:
         return self.input.disconnect_run()
 
     @property
-    def connected(self):
+    def connected(self) -> bool:
         return self.input.connected or self.output.connected
 
     @property
-    def fully_connected(self):
+    def fully_connected(self) -> bool:
         return self.input.fully_connected and self.output.fully_connected
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"{str(self.input)}\n{str(self.output)}"
 
 
@@ -316,7 +319,7 @@ class HasIO(HasStateDisplay, HasLabel, HasRun, Generic[OutputsType], ABC):
     interface.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self._signals = Signals()
         self._signals.input.run = InputSignal("run", self, self.run)
@@ -375,17 +378,17 @@ def disconnect(self) -> list[tuple[Channel, Channel]]:
         destroyed_connections.extend(self.signals.disconnect())
         return destroyed_connections
 
-    def activate_strict_hints(self):
+    def activate_strict_hints(self) -> None:
         """Enable type hint checks for all data IO"""
         self.inputs.activate_strict_hints()
         self.outputs.activate_strict_hints()
 
-    def deactivate_strict_hints(self):
+    def deactivate_strict_hints(self) -> None:
         """Disable type hint checks for all data IO"""
         self.inputs.deactivate_strict_hints()
         self.outputs.deactivate_strict_hints()
 
-    def _connect_output_signal(self, signal: OutputSignal):
+    def _connect_output_signal(self, signal: OutputSignal) -> None:
         self.signals.input.run.connect(signal)
 
     def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO:
@@ -395,10 +398,12 @@ def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO:
         other._connect_output_signal(self.signals.output.ran)
         return other
 
-    def _connect_accumulating_input_signal(self, signal: AccumulatingInputSignal):
+    def _connect_accumulating_input_signal(
+        self, signal: AccumulatingInputSignal
+    ) -> None:
         self.signals.output.ran.connect(signal)
 
-    def __lshift__(self, others):
+    def __lshift__(self, others: tuple[OutputSignal | HasIO, ...]):
         """
         Connect one or more `ran` signals to `accumulate_and_run` signals like:
         `this << some_object, another_object, or_by_channel.signals.output.ran`
diff --git a/pyiron_workflow/mixin/display_state.py b/pyiron_workflow/mixin/display_state.py
index 48309e21..fb9856a8 100644
--- a/pyiron_workflow/mixin/display_state.py
+++ b/pyiron_workflow/mixin/display_state.py
@@ -4,6 +4,7 @@
 
 from abc import ABC
 from json import dumps
+from typing import Any
 
 from pyiron_workflow.mixin.has_interface_mixins import UsesState
 
@@ -24,7 +25,7 @@ class HasStateDisplay(UsesState, ABC):
 
     def display_state(
         self, state: dict | None = None, ignore_private: bool = True
-    ) -> dict:
+    ) -> dict[str, Any]:
         """
         A dictionary of JSON-compatible objects based on the object state (plus
         whatever modifications to the state the class designer has chosen to make).
diff --git a/pyiron_workflow/mixin/run.py b/pyiron_workflow/mixin/run.py
index b7b11c7b..ac7aae86 100644
--- a/pyiron_workflow/mixin/run.py
+++ b/pyiron_workflow/mixin/run.py
@@ -74,7 +74,7 @@ def run_args(self) -> tuple[tuple, dict]:
         Any data needed for :meth:`on_run`, will be passed as (*args, **kwargs).
         """
 
-    def process_run_result(self, run_output):
+    def process_run_result(self, run_output: Any) -> Any:
         """
         What to _do_ with the results of :meth:`on_run` once you have them.
 
@@ -165,7 +165,9 @@ def _none_to_dict(inp: dict | None) -> dict:
             **run_kwargs,
         )
 
-    def _before_run(self, /, check_readiness, **kwargs) -> tuple[bool, Any]:
+    def _before_run(
+        self, /, check_readiness: bool, *args, **kwargs
+    ) -> tuple[bool, Any]:
         """
         Things to do _before_ running.
 
@@ -194,6 +196,7 @@ def _run(
         run_exception_kwargs: dict,
         run_finally_kwargs: dict,
         finish_run_kwargs: dict,
+        *args,
         **kwargs,
     ) -> Any | tuple | Future:
         """
@@ -254,7 +257,7 @@ def _run(
             )
             return self.future
 
-    def _run_exception(self, /, **kwargs):
+    def _run_exception(self, /, *args, **kwargs):
         """
         What to do if an exception is encountered inside :meth:`_run` or
         :meth:`_finish_run.
@@ -262,7 +265,7 @@ def _run_exception(self, /, **kwargs):
         self.running = False
         self.failed = True
 
-    def _run_finally(self, /, **kwargs):
+    def _run_finally(self, /, *args, **kwargs):
         """
         What to do after :meth:`_finish_run` (whether an exception is encountered or
         not), or in :meth:`_run` after an exception is encountered.
diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py
index caf44c81..27c033e0 100644
--- a/pyiron_workflow/node.py
+++ b/pyiron_workflow/node.py
@@ -304,7 +304,9 @@ def __init__(
         self._do_clean: bool = False  # Power-user override for cleaning up temporary
         # serialized results and empty directories (or not).
         self._cached_inputs = None
-        self._user_data = {}  # A place for power-users to bypass node-injection
+
+        self._user_data: dict[str, Any] = {}
+        # A place for power-users to bypass node-injection
 
         self._setup_node()
         self._after_node_setup(
@@ -629,7 +631,7 @@ def run_data_tree(self, run_parent_trees_too=False) -> None:
 
         try:
             parent_starting_nodes = (
-                self.parent.starting_nodes if self.parent is not None else None
+                self.parent.starting_nodes if self.parent is not None else []
             )  # We need these for state recovery later, even if we crash
 
             if len(data_tree_starters) == 1 and data_tree_starters[0] is self:
diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py
index 8852b426..3ae3218b 100644
--- a/pyiron_workflow/nodes/transform.py
+++ b/pyiron_workflow/nodes/transform.py
@@ -40,10 +40,6 @@ class FromManyInputs(Transformer, ABC):
     # Inputs convert to `run_args` as a value dictionary
     # This must be commensurate with the internal expectations of _on_run
 
-    @abstractmethod
-    def _on_run(self, **inputs_to_value_dict) -> Any:
-        """Must take inputs kwargs"""
-
     @property
     def _run_args(self) -> tuple[tuple, dict]:
         return (), self.inputs.to_value_dict()
diff --git a/pyiron_workflow/storage.py b/pyiron_workflow/storage.py
index 679f8151..dcdb626c 100644
--- a/pyiron_workflow/storage.py
+++ b/pyiron_workflow/storage.py
@@ -36,7 +36,7 @@ class StorageInterface(ABC):
     """
 
     @abstractmethod
-    def _save(self, node: Node, filename: Path, /, **kwargs):
+    def _save(self, node: Node, filename: Path, /, *args, **kwargs):
         """
         Save a node to file.
 
@@ -48,7 +48,7 @@ def _save(self, node: Node, filename: Path, /, **kwargs):
         """
 
     @abstractmethod
-    def _load(self, filename: Path, /, **kwargs) -> Node:
+    def _load(self, filename: Path, /, *args, **kwargs) -> Node:
         """
         Instantiate a node from file.
 
@@ -61,7 +61,7 @@ def _load(self, filename: Path, /, **kwargs) -> Node:
         """
 
     @abstractmethod
-    def _has_saved_content(self, filename: Path, /, **kwargs) -> bool:
+    def _has_saved_content(self, filename: Path, /, *args, **kwargs) -> bool:
         """
         Check for a save file matching this storage interface.
 
@@ -74,7 +74,7 @@ def _has_saved_content(self, filename: Path, /, **kwargs) -> bool:
         """
 
     @abstractmethod
-    def _delete(self, filename: Path, /, **kwargs):
+    def _delete(self, filename: Path, /, *args, **kwargs):
         """
         Remove an existing save-file for this backend.
 
@@ -132,7 +132,7 @@ def has_saved_content(
         node: Node | None = None,
         filename: str | Path | None = None,
         **kwargs,
-    ):
+    ) -> bool:
         """
         Check if a file has contents related to a node.
 
@@ -168,7 +168,9 @@ def delete(
         if filename.parent.exists() and not any(filename.parent.iterdir()):
             filename.parent.rmdir()
 
-    def _parse_filename(self, node: Node | None, filename: str | Path | None = None):
+    def _parse_filename(
+        self, node: Node | None, filename: str | Path | None = None
+    ) -> Path:
         """
         Make sure the node xor filename was provided, and if it's the node, convert
         it into a canonical filename by exploiting the node's semantic path.
@@ -195,6 +197,11 @@ def _parse_filename(self, node: Node | None, filename: str | Path | None = None)
                 f"Both the node ({node.full_label}) and filename ({filename}) were "
                 f"specified for loading -- please only specify one or the other."
             )
+        else:
+            raise AssertionError(
+                "This is an unreachable state -- we have covered all four cases of the "
+                "boolean `is (not) None` square."
+            )
 
 
 class PickleStorage(StorageInterface):
@@ -204,11 +211,11 @@ class PickleStorage(StorageInterface):
     def __init__(self, cloudpickle_fallback: bool = True):
         self.cloudpickle_fallback = cloudpickle_fallback
 
-    def _fallback(self, cpf: bool | None):
+    def _fallback(self, cpf: bool | None) -> bool:
         return self.cloudpickle_fallback if cpf is None else cpf
 
     def _save(
-        self, node: Node, filename: Path, cloudpickle_fallback: bool | None = None
+        self, node: Node, filename: Path, /, cloudpickle_fallback: bool | None = None
     ):
         if not self._fallback(cloudpickle_fallback) and not node.import_ready:
             raise TypeNotFoundError(
@@ -236,19 +243,22 @@ def _save(
         if e is not None:
             raise e
 
-    def _load(self, filename: Path, cloudpickle_fallback: bool | None = None) -> Node:
+    def _load(
+        self, filename: Path, /, cloudpickle_fallback: bool | None = None
+    ) -> Node:
         attacks = [(self._PICKLE, pickle.load)]
         if self._fallback(cloudpickle_fallback):
             attacks += [(self._CLOUDPICKLE, cloudpickle.load)]
         for suffix, load_method in attacks:
             p = filename.with_suffix(suffix)
-            if p.exists():
+            if p.is_file():
                 with open(p, "rb") as filehandle:
                     inst = load_method(filehandle)
                 return inst
+        raise FileNotFoundError(f"Could not load {filename}, no such file found.")
 
-    def _delete(self, filename: Path, cloudpickle_fallback: bool | None = None):
+    def _delete(self, filename: Path, /, cloudpickle_fallback: bool | None = None):
         suffixes = (
             [self._PICKLE, self._CLOUDPICKLE]
             if self._fallback(cloudpickle_fallback)
@@ -258,7 +268,7 @@ def _delete(self, filename: Path, cloudpickle_fallback: bool | None = None):
             filename.with_suffix(suffix).unlink(missing_ok=True)
 
     def _has_saved_content(
-        self, filename: Path, cloudpickle_fallback: bool | None = None
+        self, filename: Path, /, cloudpickle_fallback: bool | None = None
     ) -> bool:
         suffixes = (
             [self._PICKLE, self._CLOUDPICKLE]
diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py
index 112b9dd1..8b0a707b 100644
--- a/pyiron_workflow/workflow.py
+++ b/pyiron_workflow/workflow.py
@@ -26,6 +26,12 @@ class ParentMostError(TypeError):
     """
 
 
+class NoArgsError(TypeError):
+    """
+    To be raised when *args can't be processed but are received
+    """
+
+
 class Workflow(Composite):
     """
     Workflows are a dynamic composite node -- i.e. they hold and run a collection of
@@ -225,7 +231,7 @@ def __init__(
         self.outputs_map = outputs_map
         self._inputs = None
         self._outputs = None
-        self.automate_execution = automate_execution
+        self.automate_execution: bool = automate_execution
 
         super().__init__(
             *nodes,
@@ -361,12 +367,18 @@ def _before_run(
 
     def run(
         self,
+        *args,
         check_readiness: bool = True,
         **kwargs,
     ):
         # Note: Workflows may have neither parents nor siblings, so we don't need to
         # worry about running their data trees first, fetching their input, nor firing
         # their `ran` signal, hence the change in signature from Node.run
+        if len(args) > 0:
+            raise NoArgsError(
+                f"{self.__class__} does not know how to process *args on run, but "
+                f"received {args}"
+            )
 
         return super().run(
             run_data_tree=False,
diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py
index f19032b7..bb7fd5c0 100644
--- a/tests/unit/test_workflow.py
+++ b/tests/unit/test_workflow.py
@@ -10,7 +10,7 @@
 from pyiron_workflow._tests import ensure_tests_in_python_path
 from pyiron_workflow.channels import NOT_DATA
 from pyiron_workflow.storage import TypeNotFoundError, available_backends
-from pyiron_workflow.workflow import ParentMostError, Workflow
+from pyiron_workflow.workflow import NoArgsError, ParentMostError, Workflow
 
 ensure_tests_in_python_path()
 
@@ -258,6 +258,12 @@ def sum_(a, b):
             return a + b
 
         wf.sum = sum_(wf.a, wf.b)
+        with self.assertRaises(
+            NoArgsError,
+            msg="Workflows don't know what to do with raw args, since their input "
+            "has no intrinsic order",
+        ):
+            wf.run(1, 2)
         wf.run()
         self.assertEqual(
             wf.a.outputs.y.value + wf.b.outputs.y.value,
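
---
Notes on selected changes (illustrative sketches; none of the code below is
part of the patch):

`Workflow.run` now fails fast on positional arguments. A workflow's inputs
are an unordered mapping of labelled channels, so raw *args have no
well-defined binding -- only keyword routing makes sense, as the added unit
test exercises. A minimal sketch of the guard, using a hypothetical
`TinyWorkflow` stand-in rather than the real class:

    class NoArgsError(TypeError):
        """To be raised when *args can't be processed but are received"""


    class TinyWorkflow:
        def run(self, *args, **kwargs):
            # Mirrors the guard added to Workflow.run in this patch
            if len(args) > 0:
                raise NoArgsError(
                    f"{self.__class__} does not know how to process *args on "
                    f"run, but received {args}"
                )
            return kwargs  # stand-in for routing keyword inputs to channels


    print(TinyWorkflow().run(x=1, y=2))  # {'x': 1, 'y': 2}
    # TinyWorkflow().run(1, 2)  # raises NoArgsError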
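`StorageInterface._parse_filename` gains a `-> Path` annotation and a trailing
`else` that raises `AssertionError`. The method branches on the `is (not)
None` status of exactly two arguments, so all four cases are covered and the
`else` is provably unreachable; raising there records the invariant and gives
type checkers a function that returns on every live path. A sketch of that
dispatch -- the two non-error branch bodies are assumptions, not the real
path logic:

    from pathlib import Path


    def parse_filename(node, filename):
        if node is not None and filename is None:
            return Path(node.label)  # assumed: real code builds a semantic path
        elif node is None and filename is not None:
            return Path(filename)
        elif node is not None and filename is not None:
            raise ValueError("Specify the node or the filename, not both")
        elif node is None and filename is None:
            raise ValueError("Specify one of node or filename")
        else:
            # All four cells of the boolean square are handled above
            raise AssertionError("This is an unreachable state")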
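`PickleStorage._load` is hardened in two ways: `p.is_file()` replaces
`p.exists()`, so a directory that happens to carry a matching name no longer
counts as saved content, and exhausting every candidate suffix now raises
`FileNotFoundError` instead of silently returning `None`. A self-contained
sketch of the same loop shape (the suffix strings here are hypothetical, not
the real `_PICKLE`/`_CLOUDPICKLE` values):

    import pickle
    from pathlib import Path


    def load(filename: Path):
        # Try each candidate serialization format in order
        for suffix, load_method in [(".pckl", pickle.load), (".cpckl", pickle.load)]:
            p = filename.with_suffix(suffix)
            if p.is_file():  # unlike exists(), this rejects directories
                with open(p, "rb") as filehandle:
                    return load_method(filehandle)
        raise FileNotFoundError(f"Could not load {filename}, no such file found.")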
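The abstract `StorageInterface` methods gain both a positional-only marker
(`/`) and `*args`. Read together, `/` keeps the leading parameters positional
(so a backend is free to rename them), while the permissive `*args, **kwargs`
in the abstract signature leaves room for backend-specific options such as
`PickleStorage`'s `cloudpickle_fallback`. A sketch with hypothetical class
names:

    from abc import ABC, abstractmethod
    from pathlib import Path


    class StorageSketch(ABC):
        @abstractmethod
        def _load(self, filename: Path, /, *args, **kwargs):
            """Backends may accept extra, backend-specific options."""


    class PickleSketch(StorageSketch):
        # `filename` stays positional-only; the extra option is this backend's own
        def _load(self, filename: Path, /, cloudpickle_fallback: bool | None = None):
            return filename, cloudpickle_fallback


    print(PickleSketch()._load(Path("node"), cloudpickle_fallback=False))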