From 4d242b6fd8554b026917623eb0222a40a4d8fa8d Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 13:03:44 -0800 Subject: [PATCH] black Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 16 ++++++---------- pyiron_workflow/type_hinting.py | 3 ++- tests/unit/test_channels.py | 9 +++++---- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index a95f3806..a7883326 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -38,11 +38,7 @@ class ChannelConnectionError(ChannelError): class Channel( - HasChannel, - HasLabel, - HasStateDisplay, - typing.Generic[ConnectionPartner], - ABC + HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConnectionPartner], ABC ): """ Channels facilitate the flow of information (data or control signals) into and @@ -173,7 +169,7 @@ def _connection_partner_failure_message(self, other: ConnectionPartner) -> str: ) def disconnect( - self, *others: ConnectionPartner + self, *others: ConnectionPartner ) -> list[tuple[Self, ConnectionPartner]]: """ If currently connected to any others, removes this and the other from eachothers @@ -194,9 +190,7 @@ def disconnect( destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all( - self - ) -> list[tuple[Self, ConnectionPartner]]: + def disconnect_all(self) -> list[tuple[Self, ConnectionPartner]]: """ Disconnect from all other channels currently in the connections list. """ @@ -273,6 +267,7 @@ def __bool__(self): DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel") + class DataChannel(Channel[DataConnectionPartner], ABC): """ Data channels control the flow of data on the graph. @@ -488,7 +483,7 @@ def _both_typed(self, other: DataConnectionPartner | Self) -> bool: return self._has_hint and other._has_hint def _figure_out_who_is_who( - self, other: DataConnectionPartner + self, other: DataConnectionPartner ) -> tuple[OutputData, InputData]: if isinstance(self, InputData) and isinstance(other, OutputData): return other, self @@ -575,6 +570,7 @@ def connection_partner_type(cls) -> type[InputData]: "SignalConnectionPartner", bound="SignalChannel" ) + class SignalChannel(Channel[SignalConnectionPartner], ABC): """ Signal channels give the option control execution flow by triggering callback diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index 563bfe2f..28af408b 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -29,7 +29,8 @@ def valid_value(value, type_hint) -> bool: def type_hint_to_tuple(type_hint) -> tuple: if isinstance( - type_hint, types.UnionType | typing._UnionGenericAlias # type: ignore + type_hint, + types.UnionType | typing._UnionGenericAlias, # type: ignore # mypy complains because it thinks typing._UnionGenericAlias doesn't exist # It definitely does, and we may be able to remove this once mypy catches up ): diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index 151a97fa..dce71e3c 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -35,6 +35,7 @@ def data_input_locked(self): class DummyChannel(Channel[ConnectionPartner]): """Just to de-abstract the base class""" + def __str__(self): return "non-abstract input" @@ -403,7 +404,7 @@ def test_aggregating_call(self): self.assertEqual( signals_sent, len(agg.received_signals), - msg="Sanity check on initial conditions" + msg="Sanity check on initial conditions", ) self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions") @@ -412,14 +413,14 @@ def test_aggregating_call(self): self.assertListEqual( [0], owner.foo, - msg="Aggregating calls should only matter when they come from a connection" + msg="Aggregating calls should only matter when they come from a connection", ) agg(out_unrelated) signals_sent += 1 self.assertListEqual( [0], owner.foo, - msg="Aggregating calls should only matter when they come from a connection" + msg="Aggregating calls should only matter when they come from a connection", ) self.out() @@ -427,7 +428,7 @@ def test_aggregating_call(self): self.assertEqual( signals_sent, len(agg.received_signals), - msg="Signals from other channels should be received" + msg="Signals from other channels should be received", ) self.assertListEqual( [0],