Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chain optimizations for Conditional nodes #2895

Merged
merged 1 commit into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions backend/src/chain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,15 @@ def add_edge(self, edge: Edge):
get_or_add(self.__edges_by_source, edge.source.id, list).append(edge)
get_or_add(self.__edges_by_target, edge.target.id, list).append(edge)

def edges_from(self, source: NodeId) -> list[Edge]:
return self.__edges_by_source.get(source, [])
def edges_from(
self,
source: NodeId,
output_id: OutputId | None = None,
) -> list[Edge]:
edges = self.__edges_by_source.get(source, [])
if output_id is not None:
return [e for e in edges if e.source.output_id == output_id]
return edges

def edges_to(self, target: NodeId) -> list[Edge]:
return self.__edges_by_target.get(target, [])
Expand Down
103 changes: 91 additions & 12 deletions backend/src/chain/optimize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from sanic.log import logger

from .chain import Chain
from api import InputId, OutputId

from .chain import Chain, Edge, Node


class _Mutation:
Expand All @@ -11,6 +13,38 @@ def signal(self) -> None:
self.changed = True


def __passthrough(
chain: Chain,
node: Node,
input_id: InputId,
output_id: OutputId = OutputId(0), # noqa: B008
):
"""
Rewires the chain such that the value of the given input is passed through to the given output.

This assumes that the node itself has no effect on the value.

Returns False if the input does not have a value or is not connected. True otherwise.
"""
in_edge = chain.edge_to(node.id, input_id)
if in_edge is not None:
# rewire
for e in chain.edges_from(node.id, output_id):
chain.remove_edge(e)
chain.add_edge(Edge(in_edge.source, e.target))
return True
else:
value = chain.inputs.get(node.id, input_id)
if value is not None:
# constant propagation
for e in chain.edges_from(node.id, output_id):
chain.remove_edge(e)
chain.inputs.set(e.target.id, e.target.input_id, value)
return True

return False


def __removed_dead_nodes(chain: Chain, mutation: _Mutation):
"""
If a node does not have side effects and has no downstream nodes, then it can be removed.
Expand All @@ -24,24 +58,68 @@ def __removed_dead_nodes(chain: Chain, mutation: _Mutation):
logger.debug(f"Chain optimization: Removed {node.schema_id} node {node.id}")


def __static_switch_trim(chain: Chain, mutation: _Mutation):
def __static_switch(chain: Chain, mutation: _Mutation):
"""
If the selected variant of the Switch node is statically known, then we can remove the input edges of all other variants.
If the selected variant of the Switch node is statically known (which should always be the case), then we can statically resolve and remove the Switch node.
"""

for node in list(chain.nodes.values()):
if node.schema_id == "chainner:utility:switch":
value_index = chain.inputs.get(node.id, node.data.inputs[0].id)
if isinstance(value_index, int):
passed = False
for index, i in enumerate(node.data.inputs[1:]):
if index != value_index:
edge = chain.edge_to(node.id, i.id)
if edge is not None:
chain.remove_edge(edge)
mutation.signal()
logger.debug(
f"Chain optimization: Removed edge from {node.id} to {i.label}"
)
if index == value_index:
passed = __passthrough(chain, node, i.id)

if passed:
chain.remove_node(node.id)
mutation.signal()


def __useless_conditional(chain: Chain, mutation: _Mutation):
"""
Removes useless conditional nodes.
"""

if_true = InputId(1)
if_false = InputId(2)

def as_bool(value: object):
if isinstance(value, bool):
return value
if isinstance(value, int):
if value == 0:
return False
if value == 1:
return True
return None

for node in list(chain.nodes.values()):
if node.schema_id == "chainner:utility:conditional":
# the condition is a constant
const_condition = as_bool(chain.inputs.get(node.id, InputId(0)))
if const_condition is not None:
__passthrough(
chain,
node,
input_id=if_true if const_condition else if_false,
)
chain.remove_node(node.id)
mutation.signal()
continue

# identical true and false branches
true_edge = chain.edge_to(node.id, if_true)
false_edge = chain.edge_to(node.id, if_false)
if (
true_edge is not None
and false_edge is not None
and true_edge.source == false_edge.source
):
__passthrough(chain, node, if_true)
chain.remove_node(node.id)
mutation.signal()


def optimize(chain: Chain):
Expand All @@ -50,7 +128,8 @@ def optimize(chain: Chain):
mutation = _Mutation()

__removed_dead_nodes(chain, mutation)
__static_switch_trim(chain, mutation)
__static_switch(chain, mutation)
__useless_conditional(chain, mutation)

if not mutation.changed:
break
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
icon="BsShuffle",
inputs=[
BoolInput("Condition", default=True, has_handle=True).with_id(0),
AnyInput(label="If True").make_lazy(),
AnyInput(label="If False").make_lazy(),
AnyInput(label="If True").with_id(1).make_lazy(),
AnyInput(label="If False").with_id(2).make_lazy(),
],
outputs=[
BaseOutput(
Expand Down
Loading