Skip to content

Commit

Permalink
mypy channels (#534)
Browse files Browse the repository at this point in the history
* Leverage generics for connection partners

Signed-off-by: liamhuber <[email protected]>

* Break apart connection error message

So we only reference type hints when they're there

Signed-off-by: liamhuber <[email protected]>

* Hint connections type more specifically

Signed-off-by: liamhuber <[email protected]>

* Hint disconnect more specifically

Signed-off-by: liamhuber <[email protected]>

* Use Self in disconnection hints

Signed-off-by: liamhuber <[email protected]>

* Use Self to hint value_receiver

Signed-off-by: liamhuber <[email protected]>

* Devolve responsibility for connection validity

Otherwise mypy has trouble telling that data channels really are operating on a connection partner, since the `super()` call could wind up pointing anywhere.

Signed-off-by: liamhuber <[email protected]>

* Fix typing in channel tests

Signed-off-by: liamhuber <[email protected]>

* 🐛 Return the message

Signed-off-by: liamhuber <[email protected]>

* Fix typing in figuring out who is I/O

Signed-off-by: liamhuber <[email protected]>

* Recast connection parters as class method

mypy complained about the class-level attribute access I was using to get around circular references. This is a bit more verbose, but otherwise a fine alternative.

Signed-off-by: liamhuber <[email protected]>

* Match Accumulating input signal call to parent

Signed-off-by: liamhuber <[email protected]>

---------

Signed-off-by: liamhuber <[email protected]>
  • Loading branch information
liamhuber authored Jan 8, 2025
1 parent 9895187 commit c594760
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 82 deletions.
167 changes: 106 additions & 61 deletions pyiron_workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from pyiron_snippets.singleton import Singleton

from pyiron_workflow.compatibility import Self
from pyiron_workflow.mixin.display_state import HasStateDisplay
from pyiron_workflow.mixin.has_interface_mixins import HasChannel, HasLabel
from pyiron_workflow.type_hinting import (
Expand All @@ -25,24 +26,34 @@
from pyiron_workflow.io import HasIO


class ChannelConnectionError(Exception):
class ChannelError(Exception):
pass


class Channel(HasChannel, HasLabel, HasStateDisplay, ABC):
class ChannelConnectionError(ChannelError):
pass


ConnectionPartner = typing.TypeVar("ConnectionPartner", bound="Channel")


class Channel(
HasChannel,
HasLabel,
HasStateDisplay,
typing.Generic[ConnectionPartner],
ABC
):
"""
Channels facilitate the flow of information (data or control signals) into and
out of :class:`HasIO` objects (namely nodes).
They must have an identifier (`label: str`) and belong to an
`owner: pyiron_workflow.io.HasIO`.
Non-abstract channel classes should come in input/output pairs and specify the
a necessary ancestor for instances they can connect to
(`connection_partner_type: type[Channel]`).
Channels may form (:meth:`connect`/:meth:`disconnect`) and store
(:attr:`connections: list[Channel]`) connections with other channels.
(:attr:`connections`) connections with other channels.
This connection information is reflexive, and is duplicated to be stored on _both_
channels in the form of a reference to their counterpart in the connection.
Expand All @@ -51,10 +62,10 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC):
these (dis)connections is guaranteed to be handled, and new connections are
subjected to a validity test.
In this abstract class the only requirement is that the connecting channels form a
"conjugate pair" of classes, i.e. they are children of each other's partner class
(:attr:`connection_partner_type: type[Channel]`) -- input/output connects to
output/input.
In this abstract class the only requirements are that the connecting channels form
a "conjugate pair" of classes, i.e. they are children of each other's partner class
and thus have the same "flavor", but are an input/output pair; and that they define
a string representation.
Iterating over channels yields their connections.
Expand All @@ -80,7 +91,7 @@ def __init__(
"""
self._label = label
self.owner: HasIO = owner
self.connections: list[Channel] = []
self.connections: list[ConnectionPartner] = []

@property
def label(self) -> str:
Expand All @@ -90,12 +101,12 @@ def label(self) -> str:
def __str__(self):
pass

@property
@classmethod
@abstractmethod
def connection_partner_type(self) -> type[Channel]:
def connection_partner_type(cls) -> type[ConnectionPartner]:
"""
Input and output class pairs must specify a parent class for their valid
connection partners.
The class forming a conjugate pair with this channel class -- i.e. the same
"flavor" of channel, but opposite in I/O.
"""

@property
Expand All @@ -108,21 +119,18 @@ def full_label(self) -> str:
"""A label combining the channel's usual label and its owner's semantic path"""
return f"{self.owner.full_label}.{self.label}"

def _valid_connection(self, other: Channel) -> bool:
@abstractmethod
def _valid_connection(self, other: object) -> bool:
"""
Logic for determining if a connection is valid.
Connections only allowed to instances with the right parent type -- i.e.
connection pairs should be an input/output.
"""
return isinstance(other, self.connection_partner_type)

def connect(self, *others: Channel) -> None:
def connect(self, *others: ConnectionPartner) -> None:
"""
Form a connection between this and one or more other channels.
Connections are reflexive, and should only occur between input and output
channels, i.e. they are instances of each others
:attr:`connection_partner_type`.
:meth:`connection_partner_type()`.
New connections get _prepended_ to the connection lists, so they appear first
when searching over connections.
Expand All @@ -145,24 +153,28 @@ def connect(self, *others: Channel) -> None:
self.connections.insert(0, other)
other.connections.insert(0, self)
else:
if isinstance(other, self.connection_partner_type):
if isinstance(other, self.connection_partner_type()):
raise ChannelConnectionError(
f"The channel {other.full_label} ({other.__class__.__name__}"
f") has the correct type "
f"({self.connection_partner_type.__name__}) to connect with "
f"{self.full_label} ({self.__class__.__name__}), but is not "
f"a valid connection. Please check type hints, etc."
f"{other.full_label}.type_hint = {other.type_hint}; "
f"{self.full_label}.type_hint = {self.type_hint}"
self._connection_partner_failure_message(other)
) from None
else:
raise TypeError(
f"Can only connect to {self.connection_partner_type.__name__} "
f"objects, but {self.full_label} ({self.__class__.__name__}) "
f"Can only connect to {self.connection_partner_type()} "
f"objects, but {self.full_label} ({self.__class__}) "
f"got {other} ({type(other)})"
)

def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]:
def _connection_partner_failure_message(self, other: ConnectionPartner) -> str:
return (
f"The channel {other.full_label} ({other.__class__}) has the "
f"correct type ({self.connection_partner_type()}) to connect with "
f"{self.full_label} ({self.__class__}), but is not a valid "
f"connection."
)

def disconnect(
self, *others: ConnectionPartner
) -> list[tuple[Self, ConnectionPartner]]:
"""
If currently connected to any others, removes this and the other from eachothers
respective connections lists.
Expand All @@ -182,7 +194,9 @@ def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]:
destroyed_connections.append((self, other))
return destroyed_connections

def disconnect_all(self) -> list[tuple[Channel, Channel]]:
def disconnect_all(
self
) -> list[tuple[Self, ConnectionPartner]]:
"""
Disconnect from all other channels currently in the connections list.
"""
Expand Down Expand Up @@ -257,8 +271,9 @@ def __bool__(self):

NOT_DATA = NotData()

DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel")

class DataChannel(Channel, ABC):
class DataChannel(Channel[DataConnectionPartner], ABC):
"""
Data channels control the flow of data on the graph.
Expand Down Expand Up @@ -331,7 +346,7 @@ class DataChannel(Channel, ABC):
when this channel is a value receiver. This can potentially be expensive, so
consider deactivating strict hints everywhere for production runs. (Default
is True, raise exceptions when type hints get violated.)
value_receiver (pyiron_workflow.channel.DataChannel|None): Another channel of
value_receiver (pyiron_workflow.compatibility.Self|None): Another channel of
the same class whose value will always get updated when this channel's
value gets updated.
"""
Expand All @@ -343,7 +358,7 @@ def __init__(
default: typing.Any | None = NOT_DATA,
type_hint: typing.Any | None = None,
strict_hints: bool = True,
value_receiver: InputData | None = None,
value_receiver: Self | None = None,
):
super().__init__(label=label, owner=owner)
self._value = NOT_DATA
Expand All @@ -352,7 +367,7 @@ def __init__(
self.strict_hints = strict_hints
self.default = default
self.value = default # Implicitly type check your default by assignment
self.value_receiver = value_receiver
self.value_receiver: Self = value_receiver

@property
def value(self):
Expand All @@ -379,7 +394,7 @@ def _type_check_new_value(self, new_value):
)

@property
def value_receiver(self) -> InputData | OutputData | None:
def value_receiver(self) -> Self | None:
"""
Another data channel of the same type to whom new values are always pushed
(without type checking of any sort, not even when forming the couple!)
Expand All @@ -390,7 +405,7 @@ def value_receiver(self) -> InputData | OutputData | None:
return self._value_receiver

@value_receiver.setter
def value_receiver(self, new_partner: InputData | OutputData | None):
def value_receiver(self, new_partner: Self | None):
if new_partner is not None:
if not isinstance(new_partner, self.__class__):
raise TypeError(
Expand Down Expand Up @@ -445,8 +460,8 @@ def _value_is_data(self) -> bool:
def _has_hint(self) -> bool:
return self.type_hint is not None

def _valid_connection(self, other: DataChannel) -> bool:
if super()._valid_connection(other):
def _valid_connection(self, other: object) -> bool:
if isinstance(other, self.connection_partner_type()):
if self._both_typed(other):
out, inp = self._figure_out_who_is_who(other)
if not inp.strict_hints:
Expand All @@ -461,13 +476,32 @@ def _valid_connection(self, other: DataChannel) -> bool:
else:
return False

def _both_typed(self, other: DataChannel) -> bool:
def _connection_partner_failure_message(self, other: DataConnectionPartner) -> str:
msg = super()._connection_partner_failure_message(other)
msg += (
f"Please check type hints, etc. {other.full_label}.type_hint = "
f"{other.type_hint}; {self.full_label}.type_hint = {self.type_hint}"
)
return msg

def _both_typed(self, other: DataConnectionPartner | Self) -> bool:
return self._has_hint and other._has_hint

def _figure_out_who_is_who(
self, other: DataChannel
self, other: DataConnectionPartner
) -> tuple[OutputData, InputData]:
return (self, other) if isinstance(self, OutputData) else (other, self)
if isinstance(self, InputData) and isinstance(other, OutputData):
return other, self
elif isinstance(self, OutputData) and isinstance(other, InputData):
return self, other
else:
raise ChannelError(
f"This should be unreachable; data channel conjugate pairs should "
f"always be input/output, but got {type(self)} for {self.full_label} "
f"and {type(other)} for {other.full_label}. If you don't believe you "
f"are responsible for this error, please contact the maintainers via "
f"GitHub."
)

def __str__(self):
return str(self.value)
Expand All @@ -491,9 +525,10 @@ def display_state(self, state=None, ignore_private=True):
return super().display_state(state=state, ignore_private=ignore_private)


class InputData(DataChannel):
@property
def connection_partner_type(self):
class InputData(DataChannel["OutputData"]):

@classmethod
def connection_partner_type(cls) -> type[OutputData]:
return OutputData

def fetch(self) -> None:
Expand Down Expand Up @@ -530,13 +565,17 @@ def value(self, new_value):
self._value = new_value


class OutputData(DataChannel):
@property
def connection_partner_type(self):
class OutputData(DataChannel["InputData"]):
@classmethod
def connection_partner_type(cls) -> type[InputData]:
return InputData


class SignalChannel(Channel, ABC):
SignalConnectionPartner = typing.TypeVar(
"SignalConnectionPartner", bound="SignalChannel"
)

class SignalChannel(Channel[SignalConnectionPartner], ABC):
"""
Signal channels give the option control execution flow by triggering callback
functions when the channel is called.
Expand All @@ -555,15 +594,15 @@ class SignalChannel(Channel, ABC):
def __call__(self) -> None:
pass

def _valid_connection(self, other: object) -> bool:
return isinstance(other, self.connection_partner_type())


class BadCallbackError(ValueError):
pass


class InputSignal(SignalChannel):
@property
def connection_partner_type(self):
return OutputSignal
class InputSignal(SignalChannel["OutputSignal"]):

def __init__(
self,
Expand Down Expand Up @@ -591,6 +630,10 @@ def __init__(
f"all args are optional: {self._all_args_arg_optional(callback)} "
)

@classmethod
def connection_partner_type(cls) -> type[OutputSignal]:
return OutputSignal

def _is_method_on_owner(self, callback):
try:
return callback == getattr(self.owner, callback.__name__)
Expand Down Expand Up @@ -644,14 +687,15 @@ def __init__(
super().__init__(label=label, owner=owner, callback=callback)
self.received_signals: set[str] = set()

def __call__(self, other: OutputSignal) -> None:
def __call__(self, other: OutputSignal | None = None) -> None:
"""
Fire callback iff you have received at least one signal from each of your
current connections.
Resets the collection of received signals when firing.
"""
self.received_signals.update([other.scoped_label])
if isinstance(other, OutputSignal):
self.received_signals.update([other.scoped_label])
if (
len(
set(c.scoped_label for c in self.connections).difference(
Expand All @@ -675,9 +719,10 @@ def __lshift__(self, others):
other._connect_accumulating_input_signal(self)


class OutputSignal(SignalChannel):
@property
def connection_partner_type(self):
class OutputSignal(SignalChannel["InputSignal"]):

@classmethod
def connection_partner_type(cls) -> type[InputSignal]:
return InputSignal

def __call__(self) -> None:
Expand Down
Loading

0 comments on commit c594760

Please sign in to comment.