Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize static switches #2477

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions backend/src/chain/chain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, TypeVar, Union

from api import InputId, NodeData, NodeId, OutputId, registry
Expand Down Expand Up @@ -52,27 +53,42 @@ def has_side_effects(self) -> bool:
Node = Union[FunctionNode, NewIteratorNode, CollectorNode]


@dataclass(frozen=True)
class EdgeSource:
def __init__(self, node_id: NodeId, output_id: OutputId):
self.id: NodeId = node_id
self.output_id: OutputId = output_id
id: NodeId
output_id: OutputId


@dataclass(frozen=True)
class EdgeTarget:
def __init__(self, node_id: NodeId, input_id: InputId):
self.id: NodeId = node_id
self.input_id: InputId = input_id
id: NodeId
input_id: InputId


@dataclass(frozen=True)
class Edge:
def __init__(self, source: EdgeSource, target: EdgeTarget):
self.source = source
self.target = target
source: EdgeSource
target: EdgeTarget


class ChainInputs:
    """
    Holds the values assigned directly to node inputs.

    Values are kept in a two-level mapping: node id -> input id -> value.
    """

    def __init__(self) -> None:
        # Outer key is the node, inner key is the input of that node.
        self.inputs: dict[NodeId, dict[InputId, object]] = {}

    def get(self, node_id: NodeId, input_id: InputId) -> object | None:
        """
        Returns the value stored for the given input of the given node.

        Returns None if the node has no stored values or the input has no
        stored value.
        """
        per_node = self.inputs.get(node_id)
        return None if per_node is None else per_node.get(input_id)

    def set(self, node_id: NodeId, input_id: InputId, value: object) -> None:
        """
        Stores `value` for the given input of the given node.
        """
        # NOTE(review): get_or_add is defined outside this block; presumably it
        # returns the existing per-node dict or inserts a fresh `dict()` first.
        per_node = get_or_add(self.inputs, node_id, dict)
        per_node[input_id] = value


class Chain:
def __init__(self):
self.nodes: dict[NodeId, Node] = {}
self.inputs: ChainInputs = ChainInputs()
self.__edges_by_source: dict[NodeId, list[Edge]] = {}
self.__edges_by_target: dict[NodeId, list[Edge]] = {}

Expand All @@ -90,12 +106,24 @@ def edges_from(self, source: NodeId) -> list[Edge]:
def edges_to(self, target: NodeId) -> list[Edge]:
return self.__edges_by_target.get(target, [])

def edge_to(self, target: NodeId, input_id: InputId) -> Edge | None:
    """
    Returns the edge connected to the given input (if any).

    Only the edges registered for `target` are searched; the first edge whose
    target input id equals `input_id` is returned, otherwise None.
    """
    incoming = self.__edges_by_target.get(target)
    if incoming is None:
        return None
    return next((edge for edge in incoming if edge.target.input_id == input_id), None)

def remove_node(self, node_id: NodeId):
"""
Removes the node with the given id.
If the node is an iterator node, then all of its children will also be removed.
"""

self.inputs.inputs.pop(node_id, None)
node = self.nodes.pop(node_id, None)
if node is None:
return
Expand All @@ -105,6 +133,17 @@ def remove_node(self, node_id: NodeId):
for e in self.__edges_by_target.pop(node_id, []):
self.__edges_by_source[e.source.id].remove(e)

def remove_edge(self, edge: Edge) -> None:
    """
    Removes the given edge from the chain (if it is part of the chain).

    The edge is dropped from both the by-target and the by-source index.
    If the edge is not present, this is a no-op.
    """
    # `list.remove` raises ValueError for a missing element. Checking membership
    # first keeps the "(if any)" contract and prevents the two indexes from
    # getting out of sync when the edge is only present in one of them.
    edges_target = self.__edges_by_target.get(edge.target.id)
    if edges_target is not None and edge in edges_target:
        edges_target.remove(edge)
    edges_source = self.__edges_by_source.get(edge.source.id)
    if edges_source is not None and edge in edges_source:
        edges_source.remove(edge)

def topological_order(self) -> list[NodeId]:
"""
Returns all nodes in topological order.
Expand Down
63 changes: 37 additions & 26 deletions backend/src/chain/input.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,60 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Union

from api import NodeId
from api import NodeData, NodeId, OutputId

from .chain import Chain


@dataclass(frozen=True)
class EdgeInput:
def __init__(self, node_id: NodeId, index: int) -> None:
self.id = node_id
self.index = index
id: NodeId
index: int


@dataclass(frozen=True)
class ValueInput:
def __init__(self, value: object) -> None:
self.value: object = value
value: object


Input = Union[EdgeInput, ValueInput]


class InputMap:
def __init__(self, parent: InputMap | None = None) -> None:
self.__data: dict[NodeId, list[Input]] = {}
self.parent: InputMap | None = parent
def __init__(self) -> None:
self.data: dict[NodeId, list[Input]] = {}

def get(self, node_id: NodeId) -> list[Input]:
values = self.__data.get(node_id, None)
if values is not None:
return values
@staticmethod
def from_chain(chain: Chain) -> InputMap:
input_map = InputMap()

if self.parent:
return self.parent.get(node_id)
def get_output_index(data: NodeData, output_id: OutputId) -> int:
for i, output in enumerate(data.outputs):
if output.id == output_id:
return i
raise AssertionError(f"Unknown output id {output_id}")

raise AssertionError(f"Unknown node id {node_id}")
for node in chain.nodes.values():
inputs: list[Input] = []

def set(self, node_id: NodeId, values: list[Input]):
self.__data[node_id] = values
for i in node.data.inputs:
edge = chain.edge_to(node.id, i.id)
if edge is not None:
source = chain.nodes[edge.source.id]
output_index = get_output_index(source.data, edge.source.output_id)
inputs.append(EdgeInput(edge.source.id, output_index))
else:
inputs.append(ValueInput(chain.inputs.get(node.id, i.id)))

def set_values(self, node_id: NodeId, values: list[object]):
self.__data[node_id] = [ValueInput(x) for x in values]
input_map.data[node.id] = inputs

def set_append(self, node_id: NodeId, values: list[Input]):
inputs = [*self.get(node_id), *values]
self.set(node_id, inputs)
return input_map

def set_append_values(self, node_id: NodeId, values: list[object]):
inputs = [*self.get(node_id), *[ValueInput(x) for x in values]]
self.set(node_id, inputs)
def get(self, node_id: NodeId) -> list[Input]:
values = self.data.get(node_id, None)
if values is not None:
return values

raise AssertionError(f"Unknown node id {node_id}")
12 changes: 4 additions & 8 deletions backend/src/chain/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
FunctionNode,
NewIteratorNode,
)
from .input import EdgeInput, Input, InputMap, ValueInput


class JsonEdgeInput(TypedDict):
Expand Down Expand Up @@ -48,9 +47,8 @@ def __init__(
self.to_index = to_index


def parse_json(json: list[JsonNode]) -> tuple[Chain, InputMap]:
def parse_json(json: list[JsonNode]) -> Chain:
chain = Chain()
input_map = InputMap()

index_edges: list[IndexEdge] = []

Expand All @@ -63,14 +61,12 @@ def parse_json(json: list[JsonNode]) -> tuple[Chain, InputMap]:
node = FunctionNode(json_node["id"], json_node["schemaId"])
chain.add_node(node)

inputs: list[Input] = []
inputs = node.data.inputs
for index, i in enumerate(json_node["inputs"]):
if i["type"] == "edge":
inputs.append(EdgeInput(i["id"], i["index"]))
index_edges.append(IndexEdge(i["id"], i["index"], node.id, index))
else:
inputs.append(ValueInput(i["value"]))
input_map.set(node.id, inputs)
chain.inputs.set(node.id, inputs[index].id, i["value"])

for index_edge in index_edges:
source_node = chain.nodes[index_edge.from_id].data
Expand All @@ -89,4 +85,4 @@ def parse_json(json: list[JsonNode]) -> tuple[Chain, InputMap]:
)
)

return chain, input_map
return chain
45 changes: 38 additions & 7 deletions backend/src/chain/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,54 @@
from .chain import Chain


def __removed_dead_nodes(chain: Chain) -> bool:
class _Mutation:
    """
    Tracks whether an optimization pass has changed the chain.
    """

    def __init__(self) -> None:
        # Starts out unset; flipped to True by `signal` and never reset.
        self.changed: bool = False

    def signal(self) -> None:
        """
        Records that a change was made.
        """
        self.changed = True


def __removed_dead_nodes(chain: Chain, mutation: _Mutation):
"""
If a node does not have side effects and has no downstream nodes, then it can be removed.
"""
changed = False

for node in list(chain.nodes.values()):
is_dead = len(chain.edges_from(node.id)) == 0 and not node.has_side_effects()
if is_dead:
chain.remove_node(node.id)
changed = True
mutation.signal()
logger.debug(f"Chain optimization: Removed {node.schema_id} node {node.id}")

return changed

def __static_switch_trim(chain: Chain, mutation: _Mutation):
    """
    If the selected variant of the Switch node is statically known, then we can remove the input edges of all other variants.

    Every removed edge is reported through `mutation.signal()`.
    """

    # Iterate over a snapshot of the nodes, as the original does.
    for node in list(chain.nodes.values()):
        if node.schema_id != "chainner:utility:switch":
            continue

        # The first input holds the selected index; it is only usable here when
        # it was stored as a plain int in the chain's static inputs.
        value_index = chain.inputs.get(node.id, node.data.inputs[0].id)
        if not isinstance(value_index, int):
            continue

        # The remaining inputs are the variants, numbered from 0.
        for index, i in enumerate(node.data.inputs[1:]):
            if index == value_index:
                continue

            edge = chain.edge_to(node.id, i.id)
            if edge is None:
                continue

            chain.remove_edge(edge)
            mutation.signal()
            logger.debug(
                f"Chain optimization: Removed edge from {node.id} to {i.label}"
            )


def optimize(chain: Chain):
changed = True
while changed:
changed = __removed_dead_nodes(chain)
max_passes = 10
for _ in range(max_passes):
mutation = _Mutation()

__removed_dead_nodes(chain, mutation)
__static_switch_trim(chain, mutation)

if not mutation.changed:
break
3 changes: 1 addition & 2 deletions backend/src/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def __init__(
self,
id: ExecutionId,
chain: Chain,
inputs: InputMap,
send_broadcast_data: bool,
options: ExecutionOptions,
loop: asyncio.AbstractEventLoop,
Expand All @@ -323,7 +322,7 @@ def __init__(
):
self.id: ExecutionId = id
self.chain = chain
self.inputs = inputs
self.inputs: InputMap = InputMap.from_chain(chain)
self.send_broadcast_data: bool = send_broadcast_data
self.options: ExecutionOptions = options
self.cache: OutputCache[NodeOutput] = OutputCache(parent=parent_cache)
Expand Down
9 changes: 3 additions & 6 deletions backend/src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
)
from chain.cache import OutputCache
from chain.chain import Chain, FunctionNode
from chain.input import InputMap
from chain.json import JsonNode, parse_json
from chain.optimize import optimize
from custom_types import UpdateProgressFn
Expand Down Expand Up @@ -168,14 +167,13 @@ async def run(request: Request):

full_data: RunRequest = dict(request.json) # type: ignore
logger.debug(full_data)
chain, inputs = parse_json(full_data["data"])
chain = parse_json(full_data["data"])
optimize(chain)

logger.info("Running new executor...")
executor = Executor(
id=ExecutionId("main-executor " + uuid.uuid4().hex),
chain=chain,
inputs=inputs,
send_broadcast_data=full_data["sendBroadcastData"],
options=ExecutionOptions.parse(full_data["options"]),
loop=app.loop,
Expand Down Expand Up @@ -236,8 +234,8 @@ async def run_individual(request: Request):
chain = Chain()
chain.add_node(node)

input_map = InputMap()
input_map.set_values(node_id, full_data["inputs"])
for index, i in enumerate(full_data["inputs"]):
chain.inputs.set(node_id, node.data.inputs[index].id, i)

# only yield certain types of events
queue = EventConsumer.filter(
Expand All @@ -248,7 +246,6 @@ async def run_individual(request: Request):
executor = Executor(
id=execution_id,
chain=chain,
inputs=input_map,
send_broadcast_data=True,
options=ExecutionOptions.parse(full_data["options"]),
loop=app.loop,
Expand Down