diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py
index d92cbd75c..47f0a29a3 100644
--- a/comps/cores/mega/gateway.py
+++ b/comps/cores/mega/gateway.py
@@ -118,8 +118,8 @@ async def handle_request(self, request: Request):
             repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
             streaming=stream_opt,
         )
-        await self.megaservice.schedule(initial_inputs={"text": prompt}, llm_parameters=parameters)
-        for node, response in self.megaservice.result_dict.items():
+        result_dict = await self.megaservice.schedule(initial_inputs={"text": prompt}, llm_parameters=parameters)
+        for node, response in result_dict.items():
             # Here it suppose the last microservice in the megaservice is LLM.
             if (
                 isinstance(response, StreamingResponse)
@@ -127,8 +127,8 @@ async def handle_request(self, request: Request):
                 and self.megaservice.services[node].service_type == ServiceType.LLM
             ):
                 return response
-        last_node = self.megaservice.all_leaves()[-1]
-        response = self.megaservice.result_dict[last_node]["text"]
+        last_node = self.megaservice.get_all_final_outputs()[-1]
+        response = result_dict[last_node]["text"]
         choices = []
         usage = UsageInfo()
         choices.append(
@@ -160,8 +160,8 @@ async def handle_request(self, request: Request):
             repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
             streaming=stream_opt,
         )
-        await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters)
-        for node, response in self.megaservice.result_dict.items():
+        result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters)
+        for node, response in result_dict.items():
             # Here it suppose the last microservice in the megaservice is LLM.
             if (
                 isinstance(response, StreamingResponse)
@@ -169,8 +169,8 @@ async def handle_request(self, request: Request):
                 and self.megaservice.services[node].service_type == ServiceType.LLM
             ):
                 return response
-        last_node = self.megaservice.all_leaves()[-1]
-        response = self.megaservice.result_dict[last_node]["text"]
+        last_node = self.megaservice.get_all_final_outputs()[-1]
+        response = result_dict[last_node]["text"]
         choices = []
         usage = UsageInfo()
         choices.append(
@@ -207,8 +207,8 @@ async def handle_request(self, request: Request):
         ### Translated codes:
         """
         prompt = prompt_template.format(language_from=language_from, language_to=language_to, source_code=source_code)
-        await self.megaservice.schedule(initial_inputs={"query": prompt})
-        for node, response in self.megaservice.result_dict.items():
+        result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt})
+        for node, response in result_dict.items():
             # Here it suppose the last microservice in the megaservice is LLM.
             if (
                 isinstance(response, StreamingResponse)
@@ -216,8 +216,8 @@ async def handle_request(self, request: Request):
                 and self.megaservice.services[node].service_type == ServiceType.LLM
             ):
                 return response
-        last_node = self.megaservice.all_leaves()[-1]
-        response = self.megaservice.result_dict[last_node]["text"]
+        last_node = self.megaservice.get_all_final_outputs()[-1]
+        response = result_dict[last_node]["text"]
         choices = []
         usage = UsageInfo()
         choices.append(
@@ -249,8 +249,8 @@ async def handle_request(self, request: Request):
             repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
             streaming=stream_opt,
         )
-        await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters)
-        for node, response in self.megaservice.result_dict.items():
+        result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters)
+        for node, response in result_dict.items():
             # Here it suppose the last microservice in the megaservice is LLM.
             if (
                 isinstance(response, StreamingResponse)
@@ -258,8 +258,8 @@ async def handle_request(self, request: Request):
                 and self.megaservice.services[node].service_type == ServiceType.LLM
             ):
                 return response
-        last_node = self.megaservice.all_leaves()[-1]
-        response = self.megaservice.result_dict[last_node]["text"]
+        last_node = self.megaservice.get_all_final_outputs()[-1]
+        response = result_dict[last_node]["text"]
         choices = []
         usage = UsageInfo()
         choices.append(
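Note: the gateway handlers above now consume the dictionary returned by schedule() instead of reading the shared self.megaservice.result_dict attribute. A toy sketch (illustration only, not OPEA code, all names invented) of why per-call results matter once requests run concurrently:

import asyncio


class SharedStateScheduler:
    # Old pattern: results live on the scheduler object and are shared by all callers.
    def __init__(self):
        self.result_dict = {}

    async def schedule(self, text):
        await asyncio.sleep(0.01)  # stand-in for real work
        self.result_dict["llm"] = text  # last writer wins


class IsolatedScheduler:
    # New pattern: each call builds and returns its own result dict.
    async def schedule(self, text):
        await asyncio.sleep(0.01)
        return {"llm": text}


async def main():
    old = SharedStateScheduler()
    await asyncio.gather(old.schedule("hello"), old.schedule("hi"))
    print(old.result_dict)  # one caller's result was silently overwritten

    new = IsolatedScheduler()
    r1, r2 = await asyncio.gather(new.schedule("hello"), new.schedule("hi"))
    print(r1, r2)  # {'llm': 'hello'} {'llm': 'hi'}: each caller keeps its own


asyncio.run(main())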
diff --git a/comps/cores/mega/orchestrator.py b/comps/cores/mega/orchestrator.py
index 0f0631595..f421b2d44 100644
--- a/comps/cores/mega/orchestrator.py
+++ b/comps/cores/mega/orchestrator.py
@@ -1,9 +1,11 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import asyncio
 import json
 from typing import Dict, List
 
+import aiohttp
 import requests
 from fastapi.responses import StreamingResponse
 
@@ -17,7 +19,6 @@ class ServiceOrchestrator(DAG):
 
     def __init__(self) -> None:
         self.services = {}  # all services, id -> service
-        self.result_dict = {}
         super().__init__()
 
     def add(self, service):
@@ -37,23 +38,41 @@ def flow_to(self, from_service, to_service):
         return False
 
     async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()):
-        for node in self.topological_sort():
-            if node in self.ind_nodes():
-                inputs = initial_inputs
-            else:
-                inputs = self.process_outputs(self.predecessors(node))
-            response = await self.execute(node, inputs, llm_parameters)
-            self.dump_outputs(node, response)
-
-    def process_outputs(self, prev_nodes: List) -> Dict:
+        result_dict = {}
+
+        async with aiohttp.ClientSession(trust_env=True) as session:
+            pending = {asyncio.create_task(self.execute(session, node, initial_inputs)) for node in self.ind_nodes()}
+
+            while pending:
+                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+                for done_task in done:
+                    response, node = await done_task
+                    self.dump_outputs(node, response, result_dict)
+
+                    # traverse the current node's downstream nodes and execute if all one's predecessors are finished
+                    downstreams = self.downstream(node)
+                    for d_node in downstreams:
+                        if all(i in result_dict for i in self.predecessors(d_node)):
+                            inputs = self.process_outputs(self.predecessors(d_node), result_dict)
+                            pending.add(asyncio.create_task(self.execute(session, d_node, inputs, llm_parameters)))
+
+        return result_dict
+
+    def process_outputs(self, prev_nodes: List, result_dict: Dict) -> Dict:
         all_outputs = {}
 
         # assume all prev_nodes outputs' keys are not duplicated
         for prev_node in prev_nodes:
-            all_outputs.update(self.result_dict[prev_node])
+            all_outputs.update(result_dict[prev_node])
         return all_outputs
 
-    async def execute(self, cur_node: str, inputs: Dict, llm_parameters: LLMParams = LLMParams()):
+    async def execute(
+        self,
+        session: aiohttp.client.ClientSession,
+        cur_node: str,
+        inputs: Dict,
+        llm_parameters: LLMParams = LLMParams(),
+    ):
         # send the cur_node request/reply
         endpoint = self.services[cur_node].endpoint_path
         llm_parameters_dict = llm_parameters.dict()
@@ -61,6 +80,7 @@ async def execute(self, cur_node: str, inputs: Dict, llm_parameters: LLMParams =
             if inputs.get(field) != value:
                 inputs[field] = value
         if self.services[cur_node].service_type == ServiceType.LLM and llm_parameters.streaming:
+            # Still leave to sync requests.post for StreamingResponse
            response = requests.post(
                 url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
             )
@@ -71,15 +91,17 @@ def generate():
                     if chunk:
                         yield chunk
 
-            return StreamingResponse(generate(), media_type="text/event-stream")
+            return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
         else:
-            response = requests.post(url=endpoint, data=json.dumps(inputs), proxies={"http": None})
-            print(response)
-            return response.json()
+            async with session.post(endpoint, json=inputs) as response:
+                print(response.status)
+                return await response.json(), cur_node
 
-    def dump_outputs(self, node, response):
-        self.result_dict[node] = response
+    def dump_outputs(self, node, response, result_dict):
+        result_dict[node] = response
 
-    def get_all_final_outputs(self):
+    def get_all_final_outputs(self, result_dict):
+        final_output_dict = {}
         for leaf in self.all_leaves():
-            print(self.result_dict[leaf])
+            final_output_dict[leaf] = result_dict[leaf]
+        return final_output_dict
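Note: schedule() now drives the DAG with asyncio.wait(..., return_when=FIRST_COMPLETED): every root node starts immediately, and whenever a node finishes, each downstream node whose predecessors all appear in result_dict is launched. A minimal standalone sketch of that loop (the diamond DAG and execute() below are hand-rolled stand-ins; the real class gets ind_nodes/downstream/predecessors from its DAG base and POSTs via aiohttp):

import asyncio

# Hypothetical diamond DAG for illustration: a -> {b, c} -> d
EDGES = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
PREDS = {"a": [], "b": ["a"], "c": ["a"], "d": ["b", "c"]}


async def execute(node, inputs):
    await asyncio.sleep(0.1)  # stand-in for the HTTP round trip
    return {node: f"{node} saw {sorted(inputs)}"}, node


async def schedule(initial_inputs):
    result_dict = {}
    # start all root nodes (those with no predecessors) at once
    pending = {asyncio.create_task(execute(n, initial_inputs)) for n, p in PREDS.items() if not p}
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            response, node = await task
            result_dict[node] = response
            for d_node in EDGES[node]:
                # launch a downstream node exactly once, when its last predecessor lands
                if all(p in result_dict for p in PREDS[d_node]):
                    inputs = {}
                    for p in PREDS[d_node]:
                        inputs.update(result_dict[p])
                    pending.add(asyncio.create_task(execute(d_node, inputs)))
    return result_dict


print(asyncio.run(schedule({"text": "hello, "})))

With every node taking 0.1 s, b and c run in parallel, so the whole diamond completes in roughly 0.3 s instead of the 0.4 s a sequential topological walk would take.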
diff --git a/requirements.txt b/requirements.txt
index 1fb7d2181..12e0cdd9c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+aiohttp
 docarray
 fastapi
 httpx
diff --git a/tests/cores/mega/test_aio.py b/tests/cores/mega/test_aio.py
new file mode 100644
index 000000000..f86f7f188
--- /dev/null
+++ b/tests/cores/mega/test_aio.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import json
+import time
+import unittest
+
+from comps import ServiceOrchestrator, TextDoc, opea_microservices, register_microservice
+
+
+@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add")
+async def s1_add(request: TextDoc) -> TextDoc:
+    time.sleep(5)
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += "opea"
+    return {"text": text}
+
+
+@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add")
+async def s2_add(request: TextDoc) -> TextDoc:
+    time.sleep(5)
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += " project1!"
+    return {"text": text}
+
+
+@register_microservice(name="s3", host="0.0.0.0", port=8085, endpoint="/v1/add")
+async def s3_add(request: TextDoc) -> TextDoc:
+    time.sleep(5)
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += " project2!"
+    return {"text": text}
+
+
+class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.s1 = opea_microservices["s1"]
+        self.s2 = opea_microservices["s2"]
+        self.s3 = opea_microservices["s3"]
+        self.s1.start()
+        self.s2.start()
+        self.s3.start()
+
+        self.service_builder = ServiceOrchestrator()
+
+        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"]).add(opea_microservices["s3"])
+        self.service_builder.flow_to(self.s1, self.s2)
+        self.service_builder.flow_to(self.s1, self.s3)
+
+    def tearDown(self):
+        self.s1.stop()
+        self.s2.stop()
+        self.s3.stop()
+
+    async def test_schedule(self):
+        t = time.time()
+        task1 = asyncio.create_task(self.service_builder.schedule(initial_inputs={"text": "hello, "}))
+        task2 = asyncio.create_task(self.service_builder.schedule(initial_inputs={"text": "hi, "}))
+        await asyncio.gather(task1, task2)
+
+        result_dict1 = task1.result()
+        result_dict2 = task2.result()
+        self.assertEqual(result_dict1[self.s2.name]["text"], "hello, opea project1!")
+        self.assertEqual(result_dict1[self.s3.name]["text"], "hello, opea project2!")
+        self.assertEqual(result_dict2[self.s2.name]["text"], "hi, opea project1!")
+        self.assertEqual(result_dict2[self.s3.name]["text"], "hi, opea project2!")
+        self.assertEqual(len(self.service_builder.get_all_final_outputs(result_dict1).keys()), 2)
+        self.assertEqual(int(time.time() - t), 15)
+
+
+if __name__ == "__main__":
+    unittest.main()
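Note: the 15-second assertion in test_schedule is the concurrency check. Assuming each registered microservice runs as its own single-threaded server, so the blocking time.sleep(5) serializes the requests one server receives while distinct servers run in parallel, the two scheduled pipelines interleave roughly as:

t =  0..5    s1 serves task1 (task2's s1 request queues behind it)
t =  5..10   s1 serves task2 while s2 and s3 serve task1 in parallel
t = 10..15   s2 and s3 serve task2

so both DAGs finish in about 15 s, where a fully serial execution of the six 5-second calls would take 30 s.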
+ return {"text": text} + + +@register_microservice(name="s3", host="0.0.0.0", port=8085, endpoint="/v1/add") +async def s3_add(request: TextDoc) -> TextDoc: + time.sleep(5) + req = request.model_dump_json() + req_dict = json.loads(req) + text = req_dict["text"] + text += " project2!" + return {"text": text} + + +class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.s1 = opea_microservices["s1"] + self.s2 = opea_microservices["s2"] + self.s3 = opea_microservices["s3"] + self.s1.start() + self.s2.start() + self.s3.start() + + self.service_builder = ServiceOrchestrator() + + self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"]).add(opea_microservices["s3"]) + self.service_builder.flow_to(self.s1, self.s2) + self.service_builder.flow_to(self.s1, self.s3) + + def tearDown(self): + self.s1.stop() + self.s2.stop() + self.s3.stop() + + async def test_schedule(self): + t = time.time() + task1 = asyncio.create_task(self.service_builder.schedule(initial_inputs={"text": "hello, "})) + task2 = asyncio.create_task(self.service_builder.schedule(initial_inputs={"text": "hi, "})) + await asyncio.gather(task1, task2) + + result_dict1 = task1.result() + result_dict2 = task2.result() + self.assertEqual(result_dict1[self.s2.name]["text"], "hello, opea project1!") + self.assertEqual(result_dict1[self.s3.name]["text"], "hello, opea project2!") + self.assertEqual(result_dict2[self.s2.name]["text"], "hi, opea project1!") + self.assertEqual(result_dict2[self.s3.name]["text"], "hi, opea project2!") + self.assertEqual(len(self.service_builder.get_all_final_outputs(result_dict1).keys()), 2) + self.assertEqual(int(time.time() - t), 15) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/cores/mega/test_service_orchestrator.py b/tests/cores/mega/test_service_orchestrator.py index ea3251dac..3fc9ce4dd 100644 --- a/tests/cores/mega/test_service_orchestrator.py +++ b/tests/cores/mega/test_service_orchestrator.py @@ -42,8 +42,7 @@ def tearDown(self): self.s2.stop() async def test_schedule(self): - await self.service_builder.schedule(initial_inputs={"text": "hello, "}) - result_dict = self.service_builder.result_dict + result_dict = await self.service_builder.schedule(initial_inputs={"text": "hello, "}) self.assertEqual(result_dict[self.s2.name]["text"], "hello, opea project!") diff --git a/tests/cores/mega/test_service_orchestrator_with_gateway.py b/tests/cores/mega/test_service_orchestrator_with_gateway.py index 0289b046a..82c73c85a 100644 --- a/tests/cores/mega/test_service_orchestrator_with_gateway.py +++ b/tests/cores/mega/test_service_orchestrator_with_gateway.py @@ -44,8 +44,7 @@ def tearDown(self): self.gateway.stop() async def test_schedule(self): - await self.service_builder.schedule(initial_inputs={"text": "hello, "}) - result_dict = self.service_builder.result_dict + result_dict = await self.service_builder.schedule(initial_inputs={"text": "hello, "}) self.assertEqual(result_dict[self.s2.name]["text"], "hello, opea project!")