diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index a8c29890..0481bcdd 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -326,7 +326,7 @@ def remove_child(self, child: Node | str) -> Node: def replace_child( self, owned_node: Node | str, replacement: Node | type[Node] - ) -> Node: + ) -> tuple[Node, Node]: """ Replaces a node currently owned with a new node instance. The replacement must not belong to any other parent or have any connections. @@ -348,7 +348,7 @@ def replace_child( and simply gets instantiated.) Returns: - (Node): The node that got removed + (Node, Node): The node that got removed and the new one that replaced it. """ if isinstance(owned_node, str): owned_node = self.children[owned_node] @@ -367,15 +367,18 @@ def replace_child( ) if replacement.connected: raise ValueError("Replacement node must not have any connections") + replacement_node = replacement elif issubclass(replacement, Node): - replacement = replacement(label=owned_node.label) + replacement_node = replacement(label=owned_node.label) else: raise TypeError( f"Expected replacement node to be a node instance or node subclass, but " f"got {replacement}" ) - replacement.copy_io(owned_node) # If the replacement is incompatible, we'll + replacement_node.copy_io( + owned_node + ) # If the replacement is incompatible, we'll # fail here before we've changed the parent at all. Since the replacement was # first guaranteed to be an unconnected orphan, there is not yet any permanent # damage @@ -388,23 +391,29 @@ def replace_child( if sending_channel.value_receiver in owned_node.inputs ] outbound_links = [ - (replacement.outputs[sending_channel.label], sending_channel.value_receiver) + ( + replacement_node.outputs[sending_channel.label], + sending_channel.value_receiver, + ) for sending_channel in owned_node.outputs if sending_channel.value_receiver in self.outputs ] self.remove_child(owned_node) - replacement.label, owned_node.label = owned_node.label, replacement.label - self.add_child(replacement) + replacement_node.label, owned_node.label = ( + owned_node.label, + replacement_node.label, + ) + self.add_child(replacement_node) if is_starting_node: - self.starting_nodes.append(replacement) + self.starting_nodes.append(replacement_node) for sending_channel, receiving_channel in inbound_links + outbound_links: sending_channel.value_receiver = receiving_channel # Clear caches self._cached_inputs = None - replacement._cached_inputs = None + replacement_node._cached_inputs = None - return owned_node + return owned_node, replacement_node def executor_shutdown(self, wait=True, *, cancel_futures=False): """ diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index d7a3fe53..ced1e533 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -219,7 +219,7 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC): >>> # With the replace method >>> # (replacement target can be specified by label or instance, >>> # the replacing node can be specified by instance or class) - >>> replaced = adds_six_macro.replace_child(adds_six_macro.one, add_two()) + >>> replaced, _ = adds_six_macro.replace_child(adds_six_macro.one, add_two()) >>> # With the replace_with method >>> adds_six_macro.two.replace_with(add_two()) >>> # And by assignment of a compatible class to an occupied node label diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 3094683c..aa39caff 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -497,8 +497,10 @@ def _owned_io_panels(self) -> list[IO]: def replace_child( self, owned_node: Node | str, replacement: Node | type[Node] - ) -> Node: - super().replace_child(owned_node=owned_node, replacement=replacement) + ) -> tuple[Node, Node]: + replaced, replacement_node = super().replace_child( + owned_node=owned_node, replacement=replacement + ) # Finally, make sure the IO is constructible with this new node, which will # catch things like incompatible IO maps @@ -509,11 +511,11 @@ def replace_child( except Exception as e: # If IO can't be successfully rebuilt using this node, revert changes and # raise the exception - self.replace_child(replacement, owned_node) # Guaranteed to work since + self.replace_child(replacement_node, replaced) # Guaranteed to work since # replacement in the other direction was already a success raise e - return owned_node + return replaced, replacement_node @property def parent(self) -> None: