From 056f3c985bf68c7a601f85d1257efc03095dad72 Mon Sep 17 00:00:00 2001 From: RunDevelopment Date: Tue, 21 May 2024 17:27:21 +0200 Subject: [PATCH] Add chain optimizations for Conditional nodes --- backend/src/chain/chain.py | 11 +- backend/src/chain/optimize.py | 103 ++++++++++++++++-- .../utility/value/conditional.py | 4 +- 3 files changed, 102 insertions(+), 16 deletions(-) diff --git a/backend/src/chain/chain.py b/backend/src/chain/chain.py index cefa9e4e0..0622dfcf3 100644 --- a/backend/src/chain/chain.py +++ b/backend/src/chain/chain.py @@ -100,8 +100,15 @@ def add_edge(self, edge: Edge): get_or_add(self.__edges_by_source, edge.source.id, list).append(edge) get_or_add(self.__edges_by_target, edge.target.id, list).append(edge) - def edges_from(self, source: NodeId) -> list[Edge]: - return self.__edges_by_source.get(source, []) + def edges_from( + self, + source: NodeId, + output_id: OutputId | None = None, + ) -> list[Edge]: + edges = self.__edges_by_source.get(source, []) + if output_id is not None: + return [e for e in edges if e.source.output_id == output_id] + return edges def edges_to(self, target: NodeId) -> list[Edge]: return self.__edges_by_target.get(target, []) diff --git a/backend/src/chain/optimize.py b/backend/src/chain/optimize.py index 6c2e614ad..005db7375 100644 --- a/backend/src/chain/optimize.py +++ b/backend/src/chain/optimize.py @@ -1,6 +1,8 @@ from sanic.log import logger -from .chain import Chain +from api import InputId, OutputId + +from .chain import Chain, Edge, Node class _Mutation: @@ -11,6 +13,38 @@ def signal(self) -> None: self.changed = True +def __passthrough( + chain: Chain, + node: Node, + input_id: InputId, + output_id: OutputId = OutputId(0), # noqa: B008 +): + """ + Rewires the chain such that the value of the given input is passed through to the given output. + + This assumes that the node itself has no effect on the value. + + Returns False if the input does not have a value or is not connected. True otherwise. + """ + in_edge = chain.edge_to(node.id, input_id) + if in_edge is not None: + # rewire + for e in chain.edges_from(node.id, output_id): + chain.remove_edge(e) + chain.add_edge(Edge(in_edge.source, e.target)) + return True + else: + value = chain.inputs.get(node.id, input_id) + if value is not None: + # constant propagation + for e in chain.edges_from(node.id, output_id): + chain.remove_edge(e) + chain.inputs.set(e.target.id, e.target.input_id, value) + return True + + return False + + def __removed_dead_nodes(chain: Chain, mutation: _Mutation): """ If a node does not have side effects and has no downstream nodes, then it can be removed. @@ -24,24 +58,68 @@ def __removed_dead_nodes(chain: Chain, mutation: _Mutation): logger.debug(f"Chain optimization: Removed {node.schema_id} node {node.id}") -def __static_switch_trim(chain: Chain, mutation: _Mutation): +def __static_switch(chain: Chain, mutation: _Mutation): """ - If the selected variant of the Switch node is statically known, then we can remove the input edges of all other variants. + If the selected variant of the Switch node is statically known (which should always be the case), then we can statically resolve and remove the Switch node. """ for node in list(chain.nodes.values()): if node.schema_id == "chainner:utility:switch": value_index = chain.inputs.get(node.id, node.data.inputs[0].id) if isinstance(value_index, int): + passed = False for index, i in enumerate(node.data.inputs[1:]): - if index != value_index: - edge = chain.edge_to(node.id, i.id) - if edge is not None: - chain.remove_edge(edge) - mutation.signal() - logger.debug( - f"Chain optimization: Removed edge from {node.id} to {i.label}" - ) + if index == value_index: + passed = __passthrough(chain, node, i.id) + + if passed: + chain.remove_node(node.id) + mutation.signal() + + +def __useless_conditional(chain: Chain, mutation: _Mutation): + """ + Removes useless conditional nodes. + """ + + if_true = InputId(1) + if_false = InputId(2) + + def as_bool(value: object): + if isinstance(value, bool): + return value + if isinstance(value, int): + if value == 0: + return False + if value == 1: + return True + return None + + for node in list(chain.nodes.values()): + if node.schema_id == "chainner:utility:conditional": + # the condition is a constant + const_condition = as_bool(chain.inputs.get(node.id, InputId(0))) + if const_condition is not None: + __passthrough( + chain, + node, + input_id=if_true if const_condition else if_false, + ) + chain.remove_node(node.id) + mutation.signal() + continue + + # identical true and false branches + true_edge = chain.edge_to(node.id, if_true) + false_edge = chain.edge_to(node.id, if_false) + if ( + true_edge is not None + and false_edge is not None + and true_edge.source == false_edge.source + ): + __passthrough(chain, node, if_true) + chain.remove_node(node.id) + mutation.signal() def optimize(chain: Chain): @@ -50,7 +128,8 @@ def optimize(chain: Chain): mutation = _Mutation() __removed_dead_nodes(chain, mutation) - __static_switch_trim(chain, mutation) + __static_switch(chain, mutation) + __useless_conditional(chain, mutation) if not mutation.changed: break diff --git a/backend/src/packages/chaiNNer_standard/utility/value/conditional.py b/backend/src/packages/chaiNNer_standard/utility/value/conditional.py index 42d47f704..f9cdc3cd0 100644 --- a/backend/src/packages/chaiNNer_standard/utility/value/conditional.py +++ b/backend/src/packages/chaiNNer_standard/utility/value/conditional.py @@ -14,8 +14,8 @@ icon="BsShuffle", inputs=[ BoolInput("Condition", default=True, has_handle=True).with_id(0), - AnyInput(label="If True").make_lazy(), - AnyInput(label="If False").make_lazy(), + AnyInput(label="If True").with_id(1).make_lazy(), + AnyInput(label="If False").with_id(2).make_lazy(), ], outputs=[ BaseOutput(