Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic HasIO classes to specify data output panel types #551

Merged
merged 8 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 56 additions & 38 deletions pyiron_workflow/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import contextlib
from abc import ABC, abstractmethod
from collections.abc import ItemsView, Iterator
from typing import Any, Generic, TypeVar

from pyiron_snippets.dotdict import DotDict
Expand Down Expand Up @@ -59,7 +60,7 @@ class IO(HasStateDisplay, Generic[OwnedType, OwnedConjugate], ABC):

channel_dict: DotDict[str, OwnedType]

def __init__(self, *channels: OwnedType):
def __init__(self, *channels: OwnedType) -> None:
self.__dict__["channel_dict"] = DotDict(
{
channel.label: channel
Expand All @@ -74,11 +75,11 @@ def _channel_class(self) -> type[OwnedType]:
pass

@abstractmethod
def _assign_a_non_channel_value(self, channel: OwnedType, value) -> None:
def _assign_a_non_channel_value(self, channel: OwnedType, value: Any) -> None:
"""What to do when some non-channel value gets assigned to a channel"""
pass

def __getattr__(self, item) -> OwnedType:
def __getattr__(self, item: str) -> OwnedType:
try:
return self.channel_dict[item]
except KeyError as key_error:
Expand All @@ -88,7 +89,7 @@ def __getattr__(self, item) -> OwnedType:
f"nor in its channels ({self.labels})"
) from key_error

def __setattr__(self, key, value):
def __setattr__(self, key: str, value: Any) -> None:
if key in self.channel_dict:
self._assign_value_to_existing_channel(self.channel_dict[key], value)
elif isinstance(value, self._channel_class):
Expand All @@ -104,16 +105,16 @@ def __setattr__(self, key, value):
f"attribute {key} got assigned {value} of type {type(value)}"
)

def _assign_value_to_existing_channel(self, channel: OwnedType, value) -> None:
def _assign_value_to_existing_channel(self, channel: OwnedType, value: Any) -> None:
if isinstance(value, HasChannel):
channel.connect(value.channel)
else:
self._assign_a_non_channel_value(channel, value)

def __getitem__(self, item) -> OwnedType:
def __getitem__(self, item: str) -> OwnedType:
return self.__getattr__(item)

def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
self.__setattr__(key, value)

@property
Expand All @@ -124,11 +125,11 @@ def connections(self) -> list[OwnedConjugate]:
)

@property
def connected(self):
def connected(self) -> bool:
return any(c.connected for c in self)

@property
def fully_connected(self):
def fully_connected(self) -> bool:
return all(c.connected for c in self)

def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]:
Expand All @@ -145,34 +146,36 @@ def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]:
return destroyed_connections

@property
def labels(self):
def labels(self) -> list[str]:
return list(self.channel_dict.keys())

def items(self):
def items(self) -> ItemsView[str, OwnedType]:
return self.channel_dict.items()

def __iter__(self):
def __iter__(self) -> Iterator[OwnedType]:
return self.channel_dict.values().__iter__()

def __len__(self):
def __len__(self) -> int:
return len(self.channel_dict)

def __dir__(self):
return set(super().__dir__() + self.labels)
return list(set(super().__dir__() + self.labels))

def __str__(self):
def __str__(self) -> str:
return f"{self.__class__.__name__} {self.labels}"

def __getstate__(self):
def __getstate__(self) -> dict[str, Any]:
# Compatibility with python <3.11
return dict(self.__dict__)

def __setstate__(self, state):
def __setstate__(self, state: dict[str, Any]) -> None:
# Because we override getattr, we need to use __dict__ assignment directly in
# __setstate__ the same way we need it in __init__
self.__dict__["channel_dict"] = state["channel_dict"]

def display_state(self, state=None, ignore_private=True):
def display_state(
self, state: dict[str, Any] | None = None, ignore_private: bool = True
) -> dict[str, Any]:
state = dict(self.__getstate__()) if state is None else state
for k, v in state["channel_dict"].items():
state[k] = v
Expand All @@ -192,15 +195,15 @@ class DataIO(IO[DataChannel, DataChannel], ABC):
def _assign_a_non_channel_value(self, channel: DataChannel, value) -> None:
channel.value = value

def to_value_dict(self):
def to_value_dict(self) -> dict[str, Any]:
return {label: channel.value for label, channel in self.channel_dict.items()}

def to_list(self):
def to_list(self) -> list[Any]:
"""A list of channel values (order not guaranteed)"""
return [channel.value for channel in self.channel_dict.values()]

@property
def ready(self):
def ready(self) -> bool:
return all(c.ready for c in self)

def activate_strict_hints(self):
Expand All @@ -215,19 +218,29 @@ class Inputs(InputsIO, DataIO):
def _channel_class(self) -> type[InputData]:
return InputData

def fetch(self):
def fetch(self) -> None:
for c in self:
c.fetch()


class Outputs(OutputsIO, DataIO):
OutputDataType = TypeVar("OutputDataType", bound=OutputData)


class GenericOutputs(OutputsIO, DataIO, Generic[OutputDataType], ABC):
@property
@abstractmethod
def _channel_class(self) -> type[OutputDataType]:
pass


class Outputs(GenericOutputs[OutputData]):
@property
def _channel_class(self) -> type[OutputData]:
return OutputData


class SignalIO(IO[SignalChannel, SignalChannel], ABC):
def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None:
def _assign_a_non_channel_value(self, channel: SignalChannel, value: Any) -> None:
raise TypeError(
f"Tried to assign {value} ({type(value)} to the {channel.full_label}, "
f"which is already a {type(channel)}. Only other signal channels may be "
Expand Down Expand Up @@ -265,9 +278,9 @@ class Signals(HasStateDisplay):
output (OutputSignals): An empty input signals IO container.
"""

def __init__(self):
self.input = InputSignals()
self.output = OutputSignals()
def __init__(self) -> None:
self.input: InputSignals = InputSignals()
self.output: OutputSignals = OutputSignals()

def disconnect(self) -> list[tuple[SignalChannel, SignalChannel]]:
"""
Expand All @@ -283,18 +296,21 @@ def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]:
return self.input.disconnect_run()

@property
def connected(self):
def connected(self) -> bool:
return self.input.connected or self.output.connected

@property
def fully_connected(self):
def fully_connected(self) -> bool:
return self.input.fully_connected and self.output.fully_connected

def __str__(self):
def __str__(self) -> str:
return f"{str(self.input)}\n{str(self.output)}"


class HasIO(HasStateDisplay, HasLabel, HasRun, ABC):
OutputsType = TypeVar("OutputsType", bound=GenericOutputs)


class HasIO(HasStateDisplay, HasLabel, HasRun, Generic[OutputsType], ABC):
"""
A mixin for classes that provide data and signal IO.
Expand All @@ -303,7 +319,7 @@ class HasIO(HasStateDisplay, HasLabel, HasRun, ABC):
interface.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._signals = Signals()
self._signals.input.run = InputSignal("run", self, self.run)
Expand All @@ -329,7 +345,7 @@ def data_input_locked(self) -> bool:

@property
@abstractmethod
def outputs(self) -> Outputs:
def outputs(self) -> OutputsType:
pass

@property
Expand Down Expand Up @@ -362,17 +378,17 @@ def disconnect(self) -> list[tuple[Channel, Channel]]:
destroyed_connections.extend(self.signals.disconnect())
return destroyed_connections

def activate_strict_hints(self):
def activate_strict_hints(self) -> None:
"""Enable type hint checks for all data IO"""
self.inputs.activate_strict_hints()
self.outputs.activate_strict_hints()

def deactivate_strict_hints(self):
def deactivate_strict_hints(self) -> None:
"""Disable type hint checks for all data IO"""
self.inputs.deactivate_strict_hints()
self.outputs.deactivate_strict_hints()

def _connect_output_signal(self, signal: OutputSignal):
def _connect_output_signal(self, signal: OutputSignal) -> None:
self.signals.input.run.connect(signal)

def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO:
Expand All @@ -382,10 +398,12 @@ def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO:
other._connect_output_signal(self.signals.output.ran)
return other

def _connect_accumulating_input_signal(self, signal: AccumulatingInputSignal):
def _connect_accumulating_input_signal(
self, signal: AccumulatingInputSignal
) -> None:
self.signals.output.ran.connect(signal)

def __lshift__(self, others):
def __lshift__(self, others: tuple[OutputSignal | HasIO, ...]):
"""
Connect one or more `ran` signals to `accumulate_and_run` signals like:
`this << some_object, another_object, or_by_channel.signals.output.ran`
Expand Down
3 changes: 2 additions & 1 deletion pyiron_workflow/mixin/display_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from abc import ABC
from json import dumps
from typing import Any

from pyiron_workflow.mixin.has_interface_mixins import UsesState

Expand All @@ -24,7 +25,7 @@ class HasStateDisplay(UsesState, ABC):

def display_state(
self, state: dict | None = None, ignore_private: bool = True
) -> dict:
) -> dict[str, Any]:
"""
A dictionary of JSON-compatible objects based on the object state (plus
whatever modifications to the state the class designer has chosen to make).
Expand Down
12 changes: 2 additions & 10 deletions pyiron_workflow/mixin/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from pyiron_workflow.channels import NOT_DATA, OutputData
from pyiron_workflow.io import HasIO, Outputs
from pyiron_workflow.io import GenericOutputs
from pyiron_workflow.mixin.has_interface_mixins import HasChannel

if TYPE_CHECKING:
Expand Down Expand Up @@ -275,14 +274,7 @@ def __round__(self):
return self._node_injection(Round)


class OutputsWithInjection(Outputs):
class OutputsWithInjection(GenericOutputs[OutputDataWithInjection]):
@property
def _channel_class(self) -> type[OutputDataWithInjection]:
return OutputDataWithInjection


class HasIOWithInjection(HasIO, ABC):
@property
@abstractmethod
def outputs(self) -> OutputsWithInjection:
pass
11 changes: 7 additions & 4 deletions pyiron_workflow/mixin/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run_args(self) -> tuple[tuple, dict]:
Any data needed for :meth:`on_run`, will be passed as (*args, **kwargs).
"""

def process_run_result(self, run_output):
def process_run_result(self, run_output: Any) -> Any:
"""
What to _do_ with the results of :meth:`on_run` once you have them.
Expand Down Expand Up @@ -165,7 +165,9 @@ def _none_to_dict(inp: dict | None) -> dict:
**run_kwargs,
)

def _before_run(self, /, check_readiness, **kwargs) -> tuple[bool, Any]:
def _before_run(
self, /, check_readiness: bool, *args, **kwargs
) -> tuple[bool, Any]:
"""
Things to do _before_ running.
Expand Down Expand Up @@ -194,6 +196,7 @@ def _run(
run_exception_kwargs: dict,
run_finally_kwargs: dict,
finish_run_kwargs: dict,
*args,
**kwargs,
) -> Any | tuple | Future:
"""
Expand Down Expand Up @@ -254,15 +257,15 @@ def _run(
)
return self.future

def _run_exception(self, /, **kwargs):
def _run_exception(self, /, *args, **kwargs):
"""
What to do if an exception is encountered inside :meth:`_run` or
:meth:`_finish_run.
"""
self.running = False
self.failed = True

def _run_finally(self, /, **kwargs):
def _run_finally(self, /, *args, **kwargs):
"""
What to do after :meth:`_finish_run` (whether an exception is encountered or
not), or in :meth:`_run` after an exception is encountered.
Expand Down
7 changes: 5 additions & 2 deletions pyiron_workflow/mixin/single_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from abc import ABC, abstractmethod

from pyiron_workflow.mixin.has_interface_mixins import HasGenericChannel, HasLabel
from pyiron_workflow.io import HasIO
from pyiron_workflow.mixin.has_interface_mixins import HasGenericChannel
from pyiron_workflow.mixin.injection import (
OutputDataWithInjection,
OutputsWithInjection,
Expand All @@ -18,7 +19,9 @@ class AmbiguousOutputError(ValueError):
"""Raised when searching for exactly one output, but multiple are found."""


class ExploitsSingleOutput(HasLabel, HasGenericChannel[OutputDataWithInjection], ABC):
class ExploitsSingleOutput(
HasIO[OutputsWithInjection], HasGenericChannel[OutputDataWithInjection], ABC
):
@property
@abstractmethod
def outputs(self) -> OutputsWithInjection:
Expand Down
Loading
Loading