From f2995ab5f55c8917b865a405fb9ffe99b70ff86d Mon Sep 17 00:00:00 2001 From: Sihan Chen <39623753+Spycsh@users.noreply.github.com> Date: Wed, 17 Jul 2024 21:10:32 +0800 Subject: [PATCH] Add dynamic DAG (#317) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- comps/cores/mega/gateway.py | 36 +++++--- comps/cores/mega/orchestrator.py | 58 +++++++++--- comps/cores/proto/docarray.py | 8 +- .../langchain/guardrails_tgi_gaudi.py | 6 +- tests/cores/mega/test_aio.py | 6 +- tests/cores/mega/test_base_statistics.py | 2 +- tests/cores/mega/test_runtime_graph.py | 92 +++++++++++++++++++ tests/cores/mega/test_service_orchestrator.py | 2 +- .../test_service_orchestrator_with_gateway.py | 2 +- tests/test_workflow_chatqna.py | 5 +- 10 files changed, 181 insertions(+), 36 deletions(-) create mode 100644 tests/cores/mega/test_runtime_graph.py diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 40dda88aa..f1bfbabb6 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -119,7 +119,9 @@ async def handle_request(self, request: Request): repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03, streaming=stream_opt, ) - result_dict = await self.megaservice.schedule(initial_inputs={"text": prompt}, llm_parameters=parameters) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"text": prompt}, llm_parameters=parameters + ) for node, response in result_dict.items(): # Here it suppose the last microservice in the megaservice is LLM. if ( @@ -128,7 +130,7 @@ async def handle_request(self, request: Request): and self.megaservice.services[node].service_type == ServiceType.LLM ): return response - last_node = self.megaservice.all_leaves()[-1] + last_node = runtime_graph.all_leaves()[-1] response = result_dict[last_node]["text"] choices = [] usage = UsageInfo() @@ -161,7 +163,9 @@ async def handle_request(self, request: Request): repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03, streaming=stream_opt, ) - result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"query": prompt}, llm_parameters=parameters + ) for node, response in result_dict.items(): # Here it suppose the last microservice in the megaservice is LLM. if ( @@ -170,7 +174,7 @@ async def handle_request(self, request: Request): and self.megaservice.services[node].service_type == ServiceType.LLM ): return response - last_node = self.megaservice.all_leaves()[-1] + last_node = runtime_graph.all_leaves()[-1] response = result_dict[last_node]["text"] choices = [] usage = UsageInfo() @@ -208,7 +212,7 @@ async def handle_request(self, request: Request): ### Translated codes: """ prompt = prompt_template.format(language_from=language_from, language_to=language_to, source_code=source_code) - result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt}) + result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt}) for node, response in result_dict.items(): # Here it suppose the last microservice in the megaservice is LLM. if ( @@ -217,7 +221,7 @@ async def handle_request(self, request: Request): and self.megaservice.services[node].service_type == ServiceType.LLM ): return response - last_node = self.megaservice.all_leaves()[-1] + last_node = runtime_graph.all_leaves()[-1] response = result_dict[last_node]["text"] choices = [] usage = UsageInfo() @@ -253,7 +257,7 @@ async def handle_request(self, request: Request): prompt = prompt_template.format( language_from=language_from, language_to=language_to, source_language=source_language ) - result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt}) + result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt}) for node, response in result_dict.items(): # Here it suppose the last microservice in the megaservice is LLM. if ( @@ -262,7 +266,7 @@ async def handle_request(self, request: Request): and self.megaservice.services[node].service_type == ServiceType.LLM ): return response - last_node = self.megaservice.all_leaves()[-1] + last_node = runtime_graph.all_leaves()[-1] response = result_dict[last_node]["text"] choices = [] usage = UsageInfo() @@ -295,7 +299,9 @@ async def handle_request(self, request: Request): repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03, streaming=stream_opt, ) - result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"query": prompt}, llm_parameters=parameters + ) for node, response in result_dict.items(): # Here it suppose the last microservice in the megaservice is LLM. if ( @@ -304,7 +310,7 @@ async def handle_request(self, request: Request): and self.megaservice.services[node].service_type == ServiceType.LLM ): return response - last_node = self.megaservice.all_leaves()[-1] + last_node = runtime_graph.all_leaves()[-1] response = result_dict[last_node]["text"] choices = [] usage = UsageInfo() @@ -342,11 +348,11 @@ async def handle_request(self, request: Request): repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03, streaming=False, # TODO add streaming LLM output as input to TTS ) - result_dict = await self.megaservice.schedule( + result_dict, runtime_graph = await self.megaservice.schedule( initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters ) - last_node = self.megaservice.all_leaves()[-1] + last_node = runtime_graph.all_leaves()[-1] response = result_dict[last_node]["byte_str"] return response @@ -371,7 +377,9 @@ async def handle_request(self, request: Request): repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03, streaming=stream_opt, ) - result_dict = await self.megaservice.schedule(initial_inputs={"text": prompt}, llm_parameters=parameters) + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"text": prompt}, llm_parameters=parameters + ) for node, response in result_dict.items(): # Here it suppose the last microservice in the megaservice is LLM. if ( @@ -380,7 +388,7 @@ async def handle_request(self, request: Request): and self.megaservice.services[node].service_type == ServiceType.LLM ): return response - last_node = self.megaservice.all_leaves()[-1] + last_node = runtime_graph.all_leaves()[-1] response = result_dict[last_node]["text"] choices = [] usage = UsageInfo() diff --git a/comps/cores/mega/orchestrator.py b/comps/cores/mega/orchestrator.py index 956dbbbe9..723f0db5d 100644 --- a/comps/cores/mega/orchestrator.py +++ b/comps/cores/mega/orchestrator.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import copy import json +import re from typing import Dict, List import aiohttp @@ -39,10 +41,16 @@ def flow_to(self, from_service, to_service): async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()): result_dict = {} + runtime_graph = DAG() + runtime_graph.graph = copy.deepcopy(self.graph) timeout = aiohttp.ClientTimeout(total=1000) async with aiohttp.ClientSession(trust_env=True, timeout=timeout) as session: - pending = {asyncio.create_task(self.execute(session, node, initial_inputs)) for node in self.ind_nodes()} + pending = { + asyncio.create_task(self.execute(session, node, initial_inputs, runtime_graph)) + for node in self.ind_nodes() + } + ind_nodes = self.ind_nodes() while pending: done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) @@ -51,13 +59,40 @@ async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMPa self.dump_outputs(node, response, result_dict) # traverse the current node's downstream nodes and execute if all one's predecessors are finished - downstreams = self.downstream(node) - for d_node in downstreams: - if all(i in result_dict for i in self.predecessors(d_node)): - inputs = self.process_outputs(self.predecessors(d_node), result_dict) - pending.add(asyncio.create_task(self.execute(session, d_node, inputs, llm_parameters))) + downstreams = runtime_graph.downstream(node) + + # remove all the black nodes that are skipped to be forwarded to + if not isinstance(response, StreamingResponse) and "downstream_black_list" in response: + for black_node in response["downstream_black_list"]: + for downstream in reversed(downstreams): + try: + if re.findall(black_node, downstream): + print(f"skip forwardding to {downstream}...") + runtime_graph.delete_edge(node, downstream) + downstreams.remove(downstream) + except re.error as e: + print("Pattern invalid! Operation cancelled.") - return result_dict + for d_node in downstreams: + if all(i in result_dict for i in runtime_graph.predecessors(d_node)): + inputs = self.process_outputs(runtime_graph.predecessors(d_node), result_dict) + pending.add( + asyncio.create_task( + self.execute(session, d_node, inputs, runtime_graph, llm_parameters) + ) + ) + nodes_to_keep = [] + for i in ind_nodes: + nodes_to_keep.append(i) + nodes_to_keep.extend(runtime_graph.all_downstreams(i)) + + all_nodes = list(runtime_graph.graph.keys()) + + for node in all_nodes: + if node not in nodes_to_keep: + runtime_graph.delete_node_if_exists(node) + + return result_dict, runtime_graph def process_outputs(self, prev_nodes: List, result_dict: Dict) -> Dict: all_outputs = {} @@ -72,6 +107,7 @@ async def execute( session: aiohttp.client.ClientSession, cur_node: str, inputs: Dict, + runtime_graph: DAG, llm_parameters: LLMParams = LLMParams(), ): # send the cur_node request/reply @@ -97,8 +133,8 @@ def generate(): else: if ( self.services[cur_node].service_type == ServiceType.LLM - and self.predecessors(cur_node) - and "asr" in self.predecessors(cur_node)[0] + and runtime_graph.predecessors(cur_node) + and "asr" in runtime_graph.predecessors(cur_node)[0] ): inputs["query"] = inputs["text"] del inputs["text"] @@ -109,8 +145,8 @@ def generate(): def dump_outputs(self, node, response, result_dict): result_dict[node] = response - def get_all_final_outputs(self, result_dict): + def get_all_final_outputs(self, result_dict, runtime_graph): final_output_dict = {} - for leaf in self.all_leaves(): + for leaf in runtime_graph.all_leaves(): final_output_dict[leaf] = result_dict[leaf] return final_output_dict diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py index 819cb11c8..a5034aa28 100644 --- a/comps/cores/proto/docarray.py +++ b/comps/cores/proto/docarray.py @@ -10,7 +10,13 @@ from pydantic import Field, conint, conlist -class TextDoc(BaseDoc): +class TopologyInfo: + # will not keep forwarding to the downstream nodes in the black list + # should be a pattern string + downstream_black_list: Optional[list] = [] + + +class TextDoc(BaseDoc, TopologyInfo): text: str diff --git a/comps/guardrails/langchain/guardrails_tgi_gaudi.py b/comps/guardrails/langchain/guardrails_tgi_gaudi.py index 34cecf5d8..9de0193be 100644 --- a/comps/guardrails/langchain/guardrails_tgi_gaudi.py +++ b/comps/guardrails/langchain/guardrails_tgi_gaudi.py @@ -70,9 +70,11 @@ def safety_guard(input: TextDoc) -> TextDoc: policy_violation_level = response_input_guard.split("\n")[1].strip() policy_violations = unsafe_dict[policy_violation_level] print(f"Violated policies: {policy_violations}") - res = TextDoc(text=f"Violated policies: {policy_violations}, please check your input.") + res = TextDoc( + text=f"Violated policies: {policy_violations}, please check your input.", downstream_black_list=[".*"] + ) else: - res = TextDoc(text="safe") + res = TextDoc(text=input.text) return res diff --git a/tests/cores/mega/test_aio.py b/tests/cores/mega/test_aio.py index f86f7f188..fc735e70a 100644 --- a/tests/cores/mega/test_aio.py +++ b/tests/cores/mega/test_aio.py @@ -76,13 +76,13 @@ async def test_schedule(self): task2 = asyncio.create_task(self.service_builder.schedule(initial_inputs={"text": "hi, "})) await asyncio.gather(task1, task2) - result_dict1 = task1.result() - result_dict2 = task2.result() + result_dict1, runtime_graph1 = task1.result() + result_dict2, runtime_graph2 = task2.result() self.assertEqual(result_dict1[self.s2.name]["text"], "hello, opea project1!") self.assertEqual(result_dict1[self.s3.name]["text"], "hello, opea project2!") self.assertEqual(result_dict2[self.s2.name]["text"], "hi, opea project1!") self.assertEqual(result_dict2[self.s3.name]["text"], "hi, opea project2!") - self.assertEqual(len(self.service_builder.get_all_final_outputs(result_dict1).keys()), 2) + self.assertEqual(len(self.service_builder.get_all_final_outputs(result_dict1, runtime_graph1).keys()), 2) self.assertEqual(int(time.time() - t), 15) diff --git a/tests/cores/mega/test_base_statistics.py b/tests/cores/mega/test_base_statistics.py index b9b78437f..ef4e7da3e 100644 --- a/tests/cores/mega/test_base_statistics.py +++ b/tests/cores/mega/test_base_statistics.py @@ -46,7 +46,7 @@ async def test_base_statistics(self): for _ in range(2): task1 = asyncio.create_task(self.service_builder.schedule(initial_inputs={"text": "hello, "})) await asyncio.gather(task1) - result_dict1 = task1.result() + result_dict1, _ = task1.result() response = requests.get("http://localhost:8083/v1/statistics") res = response.json() diff --git a/tests/cores/mega/test_runtime_graph.py b/tests/cores/mega/test_runtime_graph.py new file mode 100644 index 000000000..1dbe36419 --- /dev/null +++ b/tests/cores/mega/test_runtime_graph.py @@ -0,0 +1,92 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import unittest + +from fastapi.testclient import TestClient + +from comps import ServiceOrchestrator, TextDoc, opea_microservices, register_microservice + + +@register_microservice(name="s1", host="0.0.0.0", port=8080, endpoint="/v1/add") +async def add_s1(request: TextDoc) -> TextDoc: + text = request.text + if "Hi" in text: + text += "OPEA Project!" + return TextDoc(text=text, downstream_black_list=[]) + elif "Bye" in text: + text += "OPEA Project!" + return TextDoc(text=text, downstream_black_list=[".*"]) + elif "Hola" in text: + text += "OPEA Project!" + return TextDoc(text=text, downstream_black_list=["s2"]) + else: + text += "OPEA Project!" + return TextDoc(text=text, downstream_black_list=["s3"]) + + +@register_microservice(name="s2", host="0.0.0.0", port=8081, endpoint="/v1/add") +async def add_s2(request: TextDoc) -> TextDoc: + text = request.text + text += "add s2!" + return TextDoc(text=text) + + +@register_microservice(name="s3", host="0.0.0.0", port=8082, endpoint="/v1/add") +async def add_s3(request: TextDoc) -> TextDoc: + text = request.text + text += "add s3!" + return TextDoc(text=text) + + +@register_microservice(name="s4", host="0.0.0.0", port=8083, endpoint="/v1/add") +async def add_s4(request: TextDoc) -> TextDoc: + text = request.text + text += "add s4!" + return TextDoc(text=text) + + +class TestMicroService(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.client1 = TestClient(opea_microservices["s1"].app) + self.s1 = opea_microservices["s1"] + self.s2 = opea_microservices["s2"] + self.s3 = opea_microservices["s3"] + self.s4 = opea_microservices["s4"] + + self.s1.start() + self.s2.start() + self.s3.start() + self.s4.start() + + self.service_builder = ServiceOrchestrator() + self.service_builder.add(self.s1).add(self.s2).add(self.s3).add(self.s4) + self.service_builder.flow_to(self.s1, self.s2) + self.service_builder.flow_to(self.s1, self.s3) + self.service_builder.flow_to(self.s3, self.s4) + + def tearDown(self): + self.s1.stop() + self.s2.stop() + self.s3.stop() + self.s4.stop() + + async def test_add_route(self): + result_dict, runtime_graph = await self.service_builder.schedule(initial_inputs={"text": "Hi!"}) + assert len(result_dict) == 4 + assert len(runtime_graph.all_leaves()) == 2 + result_dict, runtime_graph = await self.service_builder.schedule(initial_inputs={"text": "Bye!"}) + assert len(result_dict) == 1 + assert len(runtime_graph.all_leaves()) == 1 + result_dict, runtime_graph = await self.service_builder.schedule(initial_inputs={"text": "Hola!"}) + assert len(result_dict) == 3 + assert len(runtime_graph.all_leaves()) == 1 + result_dict, runtime_graph = await self.service_builder.schedule(initial_inputs={"text": "Other!"}) + print(runtime_graph.graph) + assert len(result_dict) == 2 + assert len(runtime_graph.all_leaves()) == 1 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/cores/mega/test_service_orchestrator.py b/tests/cores/mega/test_service_orchestrator.py index 3fc9ce4dd..78a30fc59 100644 --- a/tests/cores/mega/test_service_orchestrator.py +++ b/tests/cores/mega/test_service_orchestrator.py @@ -42,7 +42,7 @@ def tearDown(self): self.s2.stop() async def test_schedule(self): - result_dict = await self.service_builder.schedule(initial_inputs={"text": "hello, "}) + result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "}) self.assertEqual(result_dict[self.s2.name]["text"], "hello, opea project!") diff --git a/tests/cores/mega/test_service_orchestrator_with_gateway.py b/tests/cores/mega/test_service_orchestrator_with_gateway.py index 82c73c85a..42bad2a2f 100644 --- a/tests/cores/mega/test_service_orchestrator_with_gateway.py +++ b/tests/cores/mega/test_service_orchestrator_with_gateway.py @@ -44,7 +44,7 @@ def tearDown(self): self.gateway.stop() async def test_schedule(self): - result_dict = await self.service_builder.schedule(initial_inputs={"text": "hello, "}) + result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "}) self.assertEqual(result_dict[self.s2.name]["text"], "hello, opea project!") diff --git a/tests/test_workflow_chatqna.py b/tests/test_workflow_chatqna.py index bf893d5df..a2ea0f2d0 100644 --- a/tests/test_workflow_chatqna.py +++ b/tests/test_workflow_chatqna.py @@ -68,8 +68,9 @@ def add_remote_service(self): self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port) async def schedule(self): - await self.megaservice.schedule(initial_inputs={"text": "What is the revenue of Nike in 2023?"}) - result_dict = self.megaservice.result_dict + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs={"text": "What is the revenue of Nike in 2023?"} + ) print(result_dict)