diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index af602a31..f96bb48c 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -118,12 +118,6 @@ def full_label(self) -> str: """A label combining the channel's usual label and its owner's semantic path""" return f"{self.owner.full_label}.{self.label}" - @abstractmethod - def _valid_connection(self, other: object) -> bool: - """ - Logic for determining if a connection is valid. - """ - def connect(self, *others: ConjugateType) -> None: """ Form a connection between this and one or more other channels. @@ -146,22 +140,30 @@ def connect(self, *others: ConjugateType) -> None: for other in others: if other in self.connections: continue - elif self._valid_connection(other): - # Prepend new connections - # so that connection searches run newest to oldest - self.connections.insert(0, other) - other.connections.insert(0, self) - else: - if isinstance(other, self.connection_conjugate()): + elif isinstance(other, self.connection_conjugate()): + if self._valid_connection(other): + # Prepend new connections + # so that connection searches run newest to oldest + self.connections.insert(0, other) + other.connections.insert(0, self) + else: raise ChannelConnectionError( self._connection_conjugate_failure_message(other) ) from None - else: - raise TypeError( - f"Can only connect to {self.connection_conjugate()} " - f"objects, but {self.full_label} ({self.__class__}) " - f"got {other} ({type(other)})" - ) + else: + raise TypeError( + f"Can only connect to {self.connection_conjugate()} " + f"objects, but {self.full_label} ({self.__class__}) " + f"got {other} ({type(other)})" + ) + + def _valid_connection(self, other: ConjugateType) -> bool: + """ + Logic for determining if a connection to a conjugate partner is valid. + + Override in child classes as necessary. + """ + return True def _connection_conjugate_failure_message(self, other: ConjugateType) -> str: return ( @@ -466,21 +468,18 @@ def _value_is_data(self) -> bool: def _has_hint(self) -> bool: return self.type_hint is not None - def _valid_connection(self, other: object) -> bool: - if isinstance(other, self.connection_conjugate()): - if self._both_typed(other): - out, inp = self._figure_out_who_is_who(other) - if not inp.strict_hints: - return True - else: - return type_hint_is_as_or_more_specific_than( - out.type_hint, inp.type_hint - ) - else: - # If either is untyped, don't do type checking + def _valid_connection(self, other: DataChannel) -> bool: + if self._both_typed(other): + out, inp = self._figure_out_who_is_who(other) + if not inp.strict_hints: return True + else: + return type_hint_is_as_or_more_specific_than( + out.type_hint, inp.type_hint + ) else: - return False + # If either is untyped, don't do type checking + return True def _connection_conjugate_failure_message(self, other: DataChannel) -> str: msg = super()._connection_conjugate_failure_message(other) @@ -599,9 +598,6 @@ class SignalChannel(FlavorChannel[SignalType], ABC): def __call__(self) -> None: pass - def _valid_connection(self, other: object) -> bool: - return isinstance(other, self.connection_conjugate()) - class BadCallbackError(ValueError): pass