Skip to content

Commit

Permalink
Generic HasIO classes to specify data output panel types (#551)
Browse files Browse the repository at this point in the history
* 🐛 hint with [] for type args

Signed-off-by: liamhuber <[email protected]>

* Make a generic version of HasChannel

Signed-off-by: liamhuber <[email protected]>

* Make HasIO generic on the output panel

Signed-off-by: liamhuber <[email protected]>

* Refactor: introduce generic data outputs panel

Signed-off-by: liamhuber <[email protected]>

* Remove unnecessary concrete class

To reduce misdirection. We barely use it in the super-class and never need to hint it. In contrast, I kept `OutputsWithInjection` around exactly because it shows up in type hints everywhere, so the shorthand version is nice to have.

Signed-off-by: liamhuber <[email protected]>

* Fix type hints and unused imports

Signed-off-by: liamhuber <[email protected]>

* More return hints (#552)

* Fix returned type of __dir__

Conventionally it returns a list, not a set, of strings

Signed-off-by: liamhuber <[email protected]>

* Add hints to io

Signed-off-by: liamhuber <[email protected]>

* Adjust run_finally signature

Signed-off-by: liamhuber <[email protected]>

* Hint user data

Signed-off-by: liamhuber <[email protected]>

* Hint Workflow.automate_execution

Signed-off-by: liamhuber <[email protected]>

* Provide a type-compliant default

It never actually matters with the current logic, because of all the checks if parent is None and the fact that it is otherwise hinted to be at least a `Composite`, but it shuts mypy up and it does zero harm.

Signed-off-by: liamhuber <[email protected]>

* black

Signed-off-by: liamhuber <[email protected]>

* `mypy` storage (#553)

* Add return hints

Signed-off-by: liamhuber <[email protected]>

* End clause with else

Signed-off-by: liamhuber <[email protected]>

* Explicitly raise an error

After narrowing our search to files, actually throw an error right away if you never found one to load.

Signed-off-by: liamhuber <[email protected]>

* Resolve method extension complaints

Signed-off-by: liamhuber <[email protected]>

* `mypy` signature compliance (#554)

* Extend runnable signatures

Signed-off-by: liamhuber <[email protected]>

* Align Workflow.run with superclass signature

Signed-off-by: liamhuber <[email protected]>

* Relax FromManyInputs._on_run constraint

It was too strict for the DataFrame subclass, so just keep the superclass reference instead of narrowing the constraints.

Signed-off-by: liamhuber <[email protected]>

* black

Signed-off-by: liamhuber <[email protected]>

---------

Signed-off-by: liamhuber <[email protected]>

---------

Signed-off-by: liamhuber <[email protected]>

---------

Signed-off-by: liamhuber <[email protected]>

---------

Signed-off-by: liamhuber <[email protected]>
  • Loading branch information
liamhuber authored Jan 17, 2025
1 parent 52fe191 commit 4c9af5e
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 87 deletions.
94 changes: 56 additions & 38 deletions pyiron_workflow/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import contextlib
from abc import ABC, abstractmethod
from collections.abc import ItemsView, Iterator
from typing import Any, Generic, TypeVar

from pyiron_snippets.dotdict import DotDict
Expand Down Expand Up @@ -59,7 +60,7 @@ class IO(HasStateDisplay, Generic[OwnedType, OwnedConjugate], ABC):

channel_dict: DotDict[str, OwnedType]

def __init__(self, *channels: OwnedType):
def __init__(self, *channels: OwnedType) -> None:
self.__dict__["channel_dict"] = DotDict(
{
channel.label: channel
Expand All @@ -74,11 +75,11 @@ def _channel_class(self) -> type[OwnedType]:
pass

@abstractmethod
def _assign_a_non_channel_value(self, channel: OwnedType, value) -> None:
def _assign_a_non_channel_value(self, channel: OwnedType, value: Any) -> None:
"""What to do when some non-channel value gets assigned to a channel"""
pass

def __getattr__(self, item) -> OwnedType:
def __getattr__(self, item: str) -> OwnedType:
try:
return self.channel_dict[item]
except KeyError as key_error:
Expand All @@ -88,7 +89,7 @@ def __getattr__(self, item) -> OwnedType:
f"nor in its channels ({self.labels})"
) from key_error

def __setattr__(self, key, value):
def __setattr__(self, key: str, value: Any) -> None:
if key in self.channel_dict:
self._assign_value_to_existing_channel(self.channel_dict[key], value)
elif isinstance(value, self._channel_class):
Expand All @@ -104,16 +105,16 @@ def __setattr__(self, key, value):
f"attribute {key} got assigned {value} of type {type(value)}"
)

def _assign_value_to_existing_channel(self, channel: OwnedType, value) -> None:
def _assign_value_to_existing_channel(self, channel: OwnedType, value: Any) -> None:
if isinstance(value, HasChannel):
channel.connect(value.channel)
else:
self._assign_a_non_channel_value(channel, value)

def __getitem__(self, item) -> OwnedType:
def __getitem__(self, item: str) -> OwnedType:
return self.__getattr__(item)

def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
self.__setattr__(key, value)

@property
Expand All @@ -124,11 +125,11 @@ def connections(self) -> list[OwnedConjugate]:
)

@property
def connected(self):
def connected(self) -> bool:
return any(c.connected for c in self)

@property
def fully_connected(self):
def fully_connected(self) -> bool:
return all(c.connected for c in self)

def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]:
Expand All @@ -145,34 +146,36 @@ def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]:
return destroyed_connections

@property
def labels(self):
def labels(self) -> list[str]:
return list(self.channel_dict.keys())

def items(self):
def items(self) -> ItemsView[str, OwnedType]:
return self.channel_dict.items()

def __iter__(self):
def __iter__(self) -> Iterator[OwnedType]:
return self.channel_dict.values().__iter__()

def __len__(self):
def __len__(self) -> int:
return len(self.channel_dict)

def __dir__(self):
return set(super().__dir__() + self.labels)
return list(set(super().__dir__() + self.labels))

def __str__(self):
def __str__(self) -> str:
return f"{self.__class__.__name__} {self.labels}"

def __getstate__(self):
def __getstate__(self) -> dict[str, Any]:
# Compatibility with python <3.11
return dict(self.__dict__)

def __setstate__(self, state):
def __setstate__(self, state: dict[str, Any]) -> None:
# Because we override getattr, we need to use __dict__ assignment directly in
# __setstate__ the same way we need it in __init__
self.__dict__["channel_dict"] = state["channel_dict"]

def display_state(self, state=None, ignore_private=True):
def display_state(
self, state: dict[str, Any] | None = None, ignore_private: bool = True
) -> dict[str, Any]:
state = dict(self.__getstate__()) if state is None else state
for k, v in state["channel_dict"].items():
state[k] = v
Expand All @@ -192,15 +195,15 @@ class DataIO(IO[DataChannel, DataChannel], ABC):
def _assign_a_non_channel_value(self, channel: DataChannel, value) -> None:
channel.value = value

def to_value_dict(self):
def to_value_dict(self) -> dict[str, Any]:
return {label: channel.value for label, channel in self.channel_dict.items()}

def to_list(self):
def to_list(self) -> list[Any]:
"""A list of channel values (order not guaranteed)"""
return [channel.value for channel in self.channel_dict.values()]

@property
def ready(self):
def ready(self) -> bool:
return all(c.ready for c in self)

def activate_strict_hints(self):
Expand All @@ -215,19 +218,29 @@ class Inputs(InputsIO, DataIO):
def _channel_class(self) -> type[InputData]:
return InputData

def fetch(self):
def fetch(self) -> None:
for c in self:
c.fetch()


class Outputs(OutputsIO, DataIO):
OutputDataType = TypeVar("OutputDataType", bound=OutputData)


class GenericOutputs(OutputsIO, DataIO, Generic[OutputDataType], ABC):
@property
@abstractmethod
def _channel_class(self) -> type[OutputDataType]:
pass


class Outputs(GenericOutputs[OutputData]):
@property
def _channel_class(self) -> type[OutputData]:
return OutputData


class SignalIO(IO[SignalChannel, SignalChannel], ABC):
def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None:
def _assign_a_non_channel_value(self, channel: SignalChannel, value: Any) -> None:
raise TypeError(
f"Tried to assign {value} ({type(value)} to the {channel.full_label}, "
f"which is already a {type(channel)}. Only other signal channels may be "
Expand Down Expand Up @@ -265,9 +278,9 @@ class Signals(HasStateDisplay):
output (OutputSignals): An empty input signals IO container.
"""

def __init__(self):
self.input = InputSignals()
self.output = OutputSignals()
def __init__(self) -> None:
self.input: InputSignals = InputSignals()
self.output: OutputSignals = OutputSignals()

def disconnect(self) -> list[tuple[SignalChannel, SignalChannel]]:
"""
Expand All @@ -283,18 +296,21 @@ def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]:
return self.input.disconnect_run()

@property
def connected(self):
def connected(self) -> bool:
return self.input.connected or self.output.connected

@property
def fully_connected(self):
def fully_connected(self) -> bool:
return self.input.fully_connected and self.output.fully_connected

def __str__(self):
def __str__(self) -> str:
return f"{str(self.input)}\n{str(self.output)}"


class HasIO(HasStateDisplay, HasLabel, HasRun, ABC):
OutputsType = TypeVar("OutputsType", bound=GenericOutputs)


class HasIO(HasStateDisplay, HasLabel, HasRun, Generic[OutputsType], ABC):
"""
A mixin for classes that provide data and signal IO.
Expand All @@ -303,7 +319,7 @@ class HasIO(HasStateDisplay, HasLabel, HasRun, ABC):
interface.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._signals = Signals()
self._signals.input.run = InputSignal("run", self, self.run)
Expand All @@ -329,7 +345,7 @@ def data_input_locked(self) -> bool:

@property
@abstractmethod
def outputs(self) -> Outputs:
def outputs(self) -> OutputsType:
pass

@property
Expand Down Expand Up @@ -362,17 +378,17 @@ def disconnect(self) -> list[tuple[Channel, Channel]]:
destroyed_connections.extend(self.signals.disconnect())
return destroyed_connections

def activate_strict_hints(self):
def activate_strict_hints(self) -> None:
"""Enable type hint checks for all data IO"""
self.inputs.activate_strict_hints()
self.outputs.activate_strict_hints()

def deactivate_strict_hints(self):
def deactivate_strict_hints(self) -> None:
"""Disable type hint checks for all data IO"""
self.inputs.deactivate_strict_hints()
self.outputs.deactivate_strict_hints()

def _connect_output_signal(self, signal: OutputSignal):
def _connect_output_signal(self, signal: OutputSignal) -> None:
self.signals.input.run.connect(signal)

def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO:
Expand All @@ -382,10 +398,12 @@ def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO:
other._connect_output_signal(self.signals.output.ran)
return other

def _connect_accumulating_input_signal(self, signal: AccumulatingInputSignal):
def _connect_accumulating_input_signal(
self, signal: AccumulatingInputSignal
) -> None:
self.signals.output.ran.connect(signal)

def __lshift__(self, others):
def __lshift__(self, others: tuple[OutputSignal | HasIO, ...]):
"""
Connect one or more `ran` signals to `accumulate_and_run` signals like:
`this << some_object, another_object, or_by_channel.signals.output.ran`
Expand Down
3 changes: 2 additions & 1 deletion pyiron_workflow/mixin/display_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from abc import ABC
from json import dumps
from typing import Any

from pyiron_workflow.mixin.has_interface_mixins import UsesState

Expand All @@ -24,7 +25,7 @@ class HasStateDisplay(UsesState, ABC):

def display_state(
self, state: dict | None = None, ignore_private: bool = True
) -> dict:
) -> dict[str, Any]:
"""
A dictionary of JSON-compatible objects based on the object state (plus
whatever modifications to the state the class designer has chosen to make).
Expand Down
12 changes: 2 additions & 10 deletions pyiron_workflow/mixin/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from pyiron_workflow.channels import NOT_DATA, OutputData
from pyiron_workflow.io import HasIO, Outputs
from pyiron_workflow.io import GenericOutputs
from pyiron_workflow.mixin.has_interface_mixins import HasChannel

if TYPE_CHECKING:
Expand Down Expand Up @@ -275,14 +274,7 @@ def __round__(self):
return self._node_injection(Round)


class OutputsWithInjection(Outputs):
class OutputsWithInjection(GenericOutputs[OutputDataWithInjection]):
@property
def _channel_class(self) -> type[OutputDataWithInjection]:
return OutputDataWithInjection


class HasIOWithInjection(HasIO, ABC):
@property
@abstractmethod
def outputs(self) -> OutputsWithInjection:
pass
11 changes: 7 additions & 4 deletions pyiron_workflow/mixin/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run_args(self) -> tuple[tuple, dict]:
Any data needed for :meth:`on_run`, will be passed as (*args, **kwargs).
"""

def process_run_result(self, run_output):
def process_run_result(self, run_output: Any) -> Any:
"""
What to _do_ with the results of :meth:`on_run` once you have them.
Expand Down Expand Up @@ -165,7 +165,9 @@ def _none_to_dict(inp: dict | None) -> dict:
**run_kwargs,
)

def _before_run(self, /, check_readiness, **kwargs) -> tuple[bool, Any]:
def _before_run(
self, /, check_readiness: bool, *args, **kwargs
) -> tuple[bool, Any]:
"""
Things to do _before_ running.
Expand Down Expand Up @@ -194,6 +196,7 @@ def _run(
run_exception_kwargs: dict,
run_finally_kwargs: dict,
finish_run_kwargs: dict,
*args,
**kwargs,
) -> Any | tuple | Future:
"""
Expand Down Expand Up @@ -254,15 +257,15 @@ def _run(
)
return self.future

def _run_exception(self, /, **kwargs):
def _run_exception(self, /, *args, **kwargs):
"""
What to do if an exception is encountered inside :meth:`_run` or
:meth:`_finish_run.
"""
self.running = False
self.failed = True

def _run_finally(self, /, **kwargs):
def _run_finally(self, /, *args, **kwargs):
"""
What to do after :meth:`_finish_run` (whether an exception is encountered or
not), or in :meth:`_run` after an exception is encountered.
Expand Down
7 changes: 5 additions & 2 deletions pyiron_workflow/mixin/single_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from abc import ABC, abstractmethod

from pyiron_workflow.mixin.has_interface_mixins import HasGenericChannel, HasLabel
from pyiron_workflow.io import HasIO
from pyiron_workflow.mixin.has_interface_mixins import HasGenericChannel
from pyiron_workflow.mixin.injection import (
OutputDataWithInjection,
OutputsWithInjection,
Expand All @@ -18,7 +19,9 @@ class AmbiguousOutputError(ValueError):
"""Raised when searching for exactly one output, but multiple are found."""


class ExploitsSingleOutput(HasLabel, HasGenericChannel[OutputDataWithInjection], ABC):
class ExploitsSingleOutput(
HasIO[OutputsWithInjection], HasGenericChannel[OutputDataWithInjection], ABC
):
@property
@abstractmethod
def outputs(self) -> OutputsWithInjection:
Expand Down
Loading

0 comments on commit 4c9af5e

Please sign in to comment.