Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
Signed-off-by: liamhuber <[email protected]>
  • Loading branch information
liamhuber committed Jan 8, 2025
1 parent e13caf5 commit 4d242b6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 15 deletions.
16 changes: 6 additions & 10 deletions pyiron_workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ class ChannelConnectionError(ChannelError):


class Channel(
HasChannel,
HasLabel,
HasStateDisplay,
typing.Generic[ConnectionPartner],
ABC
HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConnectionPartner], ABC
):
"""
Channels facilitate the flow of information (data or control signals) into and
Expand Down Expand Up @@ -173,7 +169,7 @@ def _connection_partner_failure_message(self, other: ConnectionPartner) -> str:
)

def disconnect(
self, *others: ConnectionPartner
self, *others: ConnectionPartner
) -> list[tuple[Self, ConnectionPartner]]:
"""
If currently connected to any others, removes this and the other from eachothers
Expand All @@ -194,9 +190,7 @@ def disconnect(
destroyed_connections.append((self, other))
return destroyed_connections

def disconnect_all(
self
) -> list[tuple[Self, ConnectionPartner]]:
def disconnect_all(self) -> list[tuple[Self, ConnectionPartner]]:
"""
Disconnect from all other channels currently in the connections list.
"""
Expand Down Expand Up @@ -273,6 +267,7 @@ def __bool__(self):

DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel")


class DataChannel(Channel[DataConnectionPartner], ABC):
"""
Data channels control the flow of data on the graph.
Expand Down Expand Up @@ -488,7 +483,7 @@ def _both_typed(self, other: DataConnectionPartner | Self) -> bool:
return self._has_hint and other._has_hint

def _figure_out_who_is_who(
self, other: DataConnectionPartner
self, other: DataConnectionPartner
) -> tuple[OutputData, InputData]:
if isinstance(self, InputData) and isinstance(other, OutputData):
return other, self
Expand Down Expand Up @@ -575,6 +570,7 @@ def connection_partner_type(cls) -> type[InputData]:
"SignalConnectionPartner", bound="SignalChannel"
)


class SignalChannel(Channel[SignalConnectionPartner], ABC):
"""
Signal channels give the option control execution flow by triggering callback
Expand Down
3 changes: 2 additions & 1 deletion pyiron_workflow/type_hinting.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def valid_value(value, type_hint) -> bool:

def type_hint_to_tuple(type_hint) -> tuple:
if isinstance(
type_hint, types.UnionType | typing._UnionGenericAlias # type: ignore
type_hint,
types.UnionType | typing._UnionGenericAlias, # type: ignore
# mypy complains because it thinks typing._UnionGenericAlias doesn't exist
# It definitely does, and we may be able to remove this once mypy catches up
):
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def data_input_locked(self):

class DummyChannel(Channel[ConnectionPartner]):
"""Just to de-abstract the base class"""

def __str__(self):
return "non-abstract input"

Expand Down Expand Up @@ -403,7 +404,7 @@ def test_aggregating_call(self):
self.assertEqual(
signals_sent,
len(agg.received_signals),
msg="Sanity check on initial conditions"
msg="Sanity check on initial conditions",
)
self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions")

Expand All @@ -412,22 +413,22 @@ def test_aggregating_call(self):
self.assertListEqual(
[0],
owner.foo,
msg="Aggregating calls should only matter when they come from a connection"
msg="Aggregating calls should only matter when they come from a connection",
)
agg(out_unrelated)
signals_sent += 1
self.assertListEqual(
[0],
owner.foo,
msg="Aggregating calls should only matter when they come from a connection"
msg="Aggregating calls should only matter when they come from a connection",
)

self.out()
signals_sent += 1
self.assertEqual(
signals_sent,
len(agg.received_signals),
msg="Signals from other channels should be received"
msg="Signals from other channels should be received",
)
self.assertListEqual(
[0],
Expand Down

0 comments on commit 4d242b6

Please sign in to comment.