Skip to content

Commit

Permalink
Bugfix/flows with multiple starts plus ands breaking (#1531)
Browse files Browse the repository at this point in the history
* bugfix/flows-with-multiple-starts-plus-ands-breaking

* fix user found issue

* remove prints
  • Loading branch information
bhancockio authored Oct 29, 2024
1 parent b43f398 commit cdfbd5f
Showing 1 changed file with 49 additions and 26 deletions.
75 changes: 49 additions & 26 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# flow.py

import asyncio
import inspect
from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union
Expand Down Expand Up @@ -120,6 +118,8 @@ def __new__(mcs, name, bases, dct):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)

# TODO: should we add a check for __condition_type__ 'AND'?
elif hasattr(attr_value, "__is_router__"):
routers[attr_value.__router_for__] = attr_name
possible_returns = get_possible_return_constants(attr_value)
Expand Down Expand Up @@ -159,7 +159,8 @@ class _FlowGeneric(cls): # type: ignore
def __init__(self) -> None:
self._methods: Dict[str, Callable] = {}
self._state: T = self._create_initial_state()
self._completed_methods: Set[str] = set()
self._executed_methods: Set[str] = set()
self._scheduled_tasks: Set[str] = set()
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs

Expand Down Expand Up @@ -216,50 +217,65 @@ async def kickoff_async(self) -> Any:
else:
return None # Or raise an exception if no methods were executed

async def _execute_start_method(self, start_method: str) -> None:
result = await self._execute_method(self._methods[start_method])
await self._execute_listeners(start_method, result)
async def _execute_start_method(self, start_method_name: str) -> None:
result = await self._execute_method(
start_method_name, self._methods[start_method_name]
)
await self._execute_listeners(start_method_name, result)

async def _execute_method(self, method: Callable, *args: Any, **kwargs: Any) -> Any:
async def _execute_method(
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
) -> Any:
result = (
await method(*args, **kwargs)
if asyncio.iscoroutinefunction(method)
else method(*args, **kwargs)
)
self._method_outputs.append(result) # Store the output

self._executed_methods.add(method_name)

return result

async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
listener_tasks = []

if trigger_method in self._routers:
router_method = self._methods[self._routers[trigger_method]]
path = await self._execute_method(router_method)
path = await self._execute_method(
trigger_method, router_method
) # TODO: Change or not?
# Use the path as the new trigger method
trigger_method = path

for listener, (condition_type, methods) in self._listeners.items():
for listener_name, (condition_type, methods) in self._listeners.items():
if condition_type == "OR":
if trigger_method in methods:
listener_tasks.append(
self._execute_single_listener(listener, result)
)
if (
listener_name not in self._executed_methods
and listener_name not in self._scheduled_tasks
):
self._scheduled_tasks.add(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
elif condition_type == "AND":
if listener not in self._pending_and_listeners:
self._pending_and_listeners[listener] = set()
self._pending_and_listeners[listener].add(trigger_method)
if set(methods) == self._pending_and_listeners[listener]:
listener_tasks.append(
self._execute_single_listener(listener, result)
)
del self._pending_and_listeners[listener]
if all(method in self._executed_methods for method in methods):
if (
listener_name not in self._executed_methods
and listener_name not in self._scheduled_tasks
):
self._scheduled_tasks.add(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)

# Run all listener tasks concurrently and wait for them to complete
await asyncio.gather(*listener_tasks)

async def _execute_single_listener(self, listener: str, result: Any) -> None:
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
try:
method = self._methods[listener]
method = self._methods[listener_name]
sig = inspect.signature(method)
params = list(sig.parameters.values())

Expand All @@ -268,15 +284,22 @@ async def _execute_single_listener(self, listener: str, result: Any) -> None:

if method_params:
# If listener expects parameters, pass the result
listener_result = await self._execute_method(method, result)
listener_result = await self._execute_method(
listener_name, method, result
)
else:
# If listener does not expect parameters, call without arguments
listener_result = await self._execute_method(method)
listener_result = await self._execute_method(listener_name, method)

# Remove from scheduled tasks after execution
self._scheduled_tasks.discard(listener_name)

# Execute listeners of this listener
await self._execute_listeners(listener, listener_result)
await self._execute_listeners(listener_name, listener_result)
except Exception as e:
print(f"[Flow._execute_single_listener] Error in method {listener}: {e}")
print(
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
)
import traceback

traceback.print_exc()
Expand Down

0 comments on commit cdfbd5f

Please sign in to comment.