Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

feat: use full connection data to route I/O #148

Merged
merged 30 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f693855
fix sample components
ZanSara Oct 25, 2023
c189793
make sum variadic
ZanSara Oct 25, 2023
3f3dd0a
separate queue and buffer
ZanSara Oct 25, 2023
591bb7e
all works but loops & variadics together
ZanSara Oct 26, 2023
4349077
fix some tests
ZanSara Oct 26, 2023
6840921
fix some tests
ZanSara Oct 26, 2023
66c6cdc
all tests green
ZanSara Oct 26, 2023
c751321
clean up code a bit
ZanSara Oct 26, 2023
6694c7e
refactor code
ZanSara Oct 26, 2023
20f486b
Merge branch 'main' into no-nones-for-skipping
ZanSara Oct 26, 2023
8ae5547
fix tests
ZanSara Oct 26, 2023
00893af
fix self loops
ZanSara Oct 27, 2023
d99b753
fix reused sockets bug
ZanSara Oct 27, 2023
d4d5387
add distinct loops
ZanSara Oct 27, 2023
4b71091
add distinct loops test
ZanSara Oct 27, 2023
606cb1a
break out some code from run()
ZanSara Oct 27, 2023
b419a64
docstring
ZanSara Oct 27, 2023
f457188
improve variadics drawing
ZanSara Nov 4, 2023
7ed1667
black
ZanSara Nov 4, 2023
0781653
document the deepcopy
masci Nov 14, 2023
f0468c1
Merge branch 'main' into no-nones-for-skipping
masci Nov 14, 2023
a5f8735
re-arrange connection dataclass and add tests
masci Nov 14, 2023
cd8592a
consumer -> receiver
masci Nov 14, 2023
ba9bfbe
fix typing
masci Nov 14, 2023
7ae0248
move Connection-related code under component package
masci Nov 14, 2023
866ea36
clean up connect()
masci Nov 14, 2023
9ef1d03
cosmetics and typing
masci Nov 14, 2023
b8d0d12
fix linter, make Connection a dataclass again
masci Nov 14, 2023
49550fb
fix typing
masci Nov 14, 2023
cdb921b
add test case for #105
masci Nov 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,4 @@ dmypy.json
# Canals
drafts/
.canals_debug/
test/**/*.png
test/**/*.png
14 changes: 11 additions & 3 deletions canals/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import inspect
from typing import Protocol, runtime_checkable, Any
from types import new_class
from copy import deepcopy

from canals.component.sockets import InputSocket, OutputSocket
from canals.errors import ComponentError
Expand Down Expand Up @@ -121,10 +122,16 @@ def __call__(cls, *args, **kwargs):
# Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets

# If the __init__ called component.set_output_types(), __canals_output__ is already populated
# If `component.set_output_types()` was called in the component constructor,
# `__canals_output__` is already populated, no need to do anything.
if not hasattr(instance, "__canals_output__"):
# if the run method was decorated, it has a _output_types_cache field assigned
instance.__canals_output__ = getattr(instance.run, "_output_types_cache", {})
# If that's not the case, we need to populate `__canals_output__`
#
# If the `run` method was decorated, it has a `_output_types_cache` field assigned
# that stores the output specification.
# We deepcopy the content of the cache to transfer ownership from the class method
# to the actual instance, so that different instances of the same class won't share this data.
instance.__canals_output__ = deepcopy(getattr(instance.run, "_output_types_cache", {}))

# If the __init__ called component.set_input_types(), __canals_input__ is already populated
if not hasattr(instance, "__canals_input__"):
Expand All @@ -134,6 +141,7 @@ def __call__(cls, *args, **kwargs):
param: InputSocket(
name=param,
type=run_signature.parameters[param].annotation,
is_mandatory=run_signature.parameters[param].default == inspect.Parameter.empty,
)
for param in list(run_signature.parameters)[1:] # First is 'self' and it doesn't matter.
}
Expand Down
167 changes: 167 additions & 0 deletions canals/component/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import itertools
from typing import Optional, List, Tuple
from dataclasses import dataclass

from canals.component.sockets import InputSocket, OutputSocket
from canals.type_utils import _type_name, _types_are_compatible
from canals.errors import PipelineConnectError


@dataclass
class Connection:
    """
    A directed link from one component's output socket to another component's input socket.

    Any of the four fields may be `None`: `__repr__` renders a missing sender side as
    "input needed" and a missing receiver side as "output".

    Creating a fully specified Connection (all four fields set) has a side effect: the
    receiver's name is appended to `sender_socket.receivers` and the sender's name is
    appended to `receiver_socket.senders`.

    :param sender: name of the component that sends the value, if any.
    :param sender_socket: output socket the value is sent from, if any.
    :param receiver: name of the component that receives the value, if any.
    :param receiver_socket: input socket the value is delivered to, if any.
    :raises PipelineConnectError: if the connection is fully specified and the receiving
        socket already has a sender and is not variadic.
    """

    sender: Optional[str]
    sender_socket: Optional[OutputSocket]
    receiver: Optional[str]
    receiver_socket: Optional[InputSocket]

    def __post_init__(self):
        """
        Validate a fully specified connection and register it on both sockets.
        """
        if self.sender and self.sender_socket and self.receiver and self.receiver_socket:
            # Make sure the receiving socket isn't already connected, unless it's variadic. Sending sockets can be
            # connected as many times as needed, so they don't need this check
            if self.receiver_socket.senders and not self.receiver_socket.is_variadic:
                raise PipelineConnectError(
                    f"Cannot connect '{self.sender}.{self.sender_socket.name}' with '{self.receiver}.{self.receiver_socket.name}': "
                    f"{self.receiver}.{self.receiver_socket.name} is already connected to {self.receiver_socket.senders}.\n"
                )

            self.sender_socket.receivers.append(self.receiver)
            self.receiver_socket.senders.append(self.sender)

    def __repr__(self):
        """
        Render the connection as `<sender side> --> <receiver side>`.
        """
        if self.sender and self.sender_socket:
            sender_repr = f"{self.sender}.{self.sender_socket.name} ({_type_name(self.sender_socket.type)})"
        else:
            sender_repr = "input needed"

        if self.receiver and self.receiver_socket:
            receiver_repr = f"({_type_name(self.receiver_socket.type)}) {self.receiver}.{self.receiver_socket.name}"
        else:
            receiver_repr = "output"

        return f"{sender_repr} --> {receiver_repr}"

    def __hash__(self):
        """
        Connection is used as a dictionary key in Pipeline, it must be hashable
        """
        # The hash only depends on component and socket names; missing sides fall back to
        # fixed placeholder strings so that partially specified connections are hashable too.
        return hash(
            "-".join(
                [
                    self.sender if self.sender else "input",
                    self.sender_socket.name if self.sender_socket else "",
                    self.receiver if self.receiver else "output",
                    self.receiver_socket.name if self.receiver_socket else "",
                ]
            )
        )

    @property
    def is_mandatory(self) -> bool:
        """
        Returns True if the connection goes to a mandatory input socket, False otherwise
        """
        if self.receiver_socket:
            return self.receiver_socket.is_mandatory
        return False

    @staticmethod
    def from_list_of_sockets(
        sender_node: str, sender_sockets: List[OutputSocket], receiver_node: str, receiver_sockets: List[InputSocket]
    ) -> "Connection":
        """
        Find one single possible connection between two lists of sockets.

        Candidate pairs are all sender/receiver socket combinations with compatible types.
        If exactly one candidate exists it is used. If several exist, the pair whose sender
        and receiver sockets share the same name is used, provided there is exactly one such pair.

        :param sender_node: name of the sending component.
        :param sender_sockets: output sockets of the sender to consider.
        :param receiver_node: name of the receiving component.
        :param receiver_sockets: input sockets of the receiver to consider.
        :returns: the Connection between the selected pair of sockets.
        :raises PipelineConnectError: if no pair of sockets has compatible types, or if more
            than one pair is possible and matching by name does not single one out.
        """
        # List all sender/receiver combinations of sockets that match by type
        possible_connections = [
            (sender_sock, receiver_sock)
            for sender_sock, receiver_sock in itertools.product(sender_sockets, receiver_sockets)
            if _types_are_compatible(sender_sock.type, receiver_sock.type)
        ]

        # No connections seem to be possible
        if not possible_connections:
            connections_status_str = _connections_status(
                sender_node=sender_node,
                sender_sockets=sender_sockets,
                receiver_node=receiver_node,
                receiver_sockets=receiver_sockets,
            )

            # Both sockets were specified: explain why the types don't match
            if len(sender_sockets) == len(receiver_sockets) and len(sender_sockets) == 1:
                raise PipelineConnectError(
                    f"Cannot connect '{sender_node}.{sender_sockets[0].name}' with '{receiver_node}.{receiver_sockets[0].name}': "
                    f"their declared input and output types do not match.\n{connections_status_str}"
                )

            # Not both sockets were specified: explain there's no possible match on any pair
            raise PipelineConnectError(
                f"Cannot connect '{sender_node}' with '{receiver_node}': "
                f"no matching connections available.\n{connections_status_str}"
            )

        if len(possible_connections) == 1:
            match = possible_connections[0]
        else:
            # There's more than one possible connection: try to match by name
            name_matches = [
                (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
            ]
            if len(name_matches) != 1:
                # TODO allow for multiple connections at once if there is no ambiguity?
                # TODO give priority to sockets that have no default values?
                connections_status_str = _connections_status(
                    sender_node=sender_node,
                    sender_sockets=sender_sockets,
                    receiver_node=receiver_node,
                    receiver_sockets=receiver_sockets,
                )
                raise PipelineConnectError(
                    f"Cannot connect '{sender_node}' with '{receiver_node}': more than one connection is possible "
                    "between these components. Please specify the connection name, like: "
                    f"pipeline.connect('{sender_node}.{possible_connections[0][0].name}', "
                    f"'{receiver_node}.{possible_connections[0][1].name}').\n{connections_status_str}"
                )
            # Use the pair selected by name, not simply the first type-compatible pair:
            # the two can differ when the name match is not first in the product order.
            match = name_matches[0]

        return Connection(sender_node, match[0], receiver_node, match[1])


def _connections_status(
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
):
    """
    Build a human-readable summary of both components' sockets, for error messages.

    Each sender socket is listed with its type name. Each receiver socket is listed with
    its type name and either the comma-joined names of its current senders or "available".
    """

    def _availability(socket):
        # A receiver socket with at least one sender reports who feeds it.
        if socket.senders:
            return f"sent by {','.join(socket.senders)}"
        return "available"

    sender_sockets_list = "\n".join(f" - {socket.name}: {_type_name(socket.type)}" for socket in sender_sockets)
    receiver_sockets_list = "\n".join(
        f" - {socket.name}: {_type_name(socket.type)} ({_availability(socket)})" for socket in receiver_sockets
    )

    return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"


def parse_connect_string(connection: str) -> Tuple[str, Optional[str]]:
    """
    Split a connect_to/from string into a (component, socket) pair.

    The string is split on its first dot only, so everything after that dot is the socket
    name. When the string contains no dot, the socket part is None.
    """
    component, dot, socket = connection.partition(".")
    if not dot:
        return connection, None
    return (component, socket)
8 changes: 4 additions & 4 deletions canals/component/sockets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import get_origin, get_args, List, Type, Union
from typing import get_args, List, Type
import logging
from dataclasses import dataclass, field

Expand All @@ -15,12 +15,11 @@
class InputSocket:
name: str
type: Type
is_optional: bool = field(init=False)
is_mandatory: bool = True
is_variadic: bool = field(init=False)
sender: List[str] = field(default_factory=list)
senders: List[str] = field(default_factory=list)

def __post_init__(self):
self.is_optional = get_origin(self.type) is Union and type(None) in get_args(self.type)
try:
# __metadata__ is a tuple
self.is_variadic = self.type.__metadata__[0] == CANALS_VARIADIC_ANNOTATION
Expand All @@ -39,3 +38,4 @@ def __post_init__(self):
class OutputSocket:
    """
    Describes a single named, typed output of a component.
    """

    # Name of the socket.
    name: str
    # Declared type of the socket; compared against input socket types when connecting.
    type: type
    # Names of the components this socket sends to. `Connection.__post_init__` appends the
    # receiver's name here when a fully specified connection is created.
    receivers: List[str] = field(default_factory=list)
115 changes: 0 additions & 115 deletions canals/pipeline/connections.py

This file was deleted.

11 changes: 5 additions & 6 deletions canals/pipeline/descriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,18 @@ def find_pipeline_inputs(graph: networkx.MultiDiGraph) -> Dict[str, List[InputSo
input sockets, including all such sockets with default values.
"""
return {
name: [socket for socket in data.get("input_sockets", {}).values() if not socket.sender]
name: [socket for socket in data.get("input_sockets", {}).values() if not socket.senders or socket.is_variadic]
for name, data in graph.nodes(data=True)
}


def find_pipeline_outputs(graph) -> Dict[str, List[OutputSocket]]:
def find_pipeline_outputs(graph: networkx.MultiDiGraph) -> Dict[str, List[OutputSocket]]:
"""
Collect components that have disconnected output sockets. They define the pipeline output.
"""
return {
node: list(data.get("output_sockets", {}).values())
for node, data in graph.nodes(data=True)
if not graph.out_edges(node)
name: [socket for socket in data.get("output_sockets", {}).values() if not socket.receivers]
for name, data in graph.nodes(data=True)
}


Expand All @@ -40,7 +39,7 @@ def describe_pipeline_inputs(graph: networkx.MultiDiGraph):
Returns a dictionary with the input names and types that this pipeline accepts.
"""
inputs = {
comp: {socket.name: {"type": socket.type, "is_optional": socket.is_optional} for socket in data}
comp: {socket.name: {"type": socket.type, "is_mandatory": socket.is_mandatory} for socket in data}
for comp, data in find_pipeline_inputs(graph).items()
if data
}
Expand Down
2 changes: 1 addition & 1 deletion canals/pipeline/draw/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str]
graph.add_node("input")
for node, in_sockets in find_pipeline_inputs(graph).items():
for in_socket in in_sockets:
if not in_socket.sender and not in_socket.is_optional:
if not in_socket.senders and in_socket.is_mandatory:
# If this socket has no sender it could be a socket that receives input
# directly when running the Pipeline. We can't know that for sure, in doubt
# we draw it as receiving input directly.
Expand Down
Loading