diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 10decfa1f..69324a650 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -544,18 +544,24 @@ def __init__(self, megaservice, host="0.0.0.0", port=8889): ) async def handle_request(self, request: Request): + def parser_input(data, TypeClass, key): + try: + chat_request = TypeClass.parse_obj(data) + query = getattr(chat_request, key) + except: + query = None + return query + data = await request.json() - if isinstance(request, TextDoc): - chat_request = TextDoc.parse_obj(data) - query = chat_request.text - elif isinstance(request, EmbeddingRequest): - chat_request = EmbeddingRequest.parse_obj(data) - query = chat_request.input - elif isinstance(request, ChatCompletionRequest): - chat_request = ChatCompletionRequest.parse_obj(data) - query = chat_request.input - result_dict = await self.megaservice.schedule(initial_inputs={"text": query}) - for node, response in result_dict.items(): - print("Node: {}\nResponse: {}".format(node, response)) - if self.megaservice.services[node].service_type == ServiceType.RERANK: - return response + query = None + for key, TypeClass in zip(["text", "input", "input"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]): + query = parser_input(data, TypeClass, key) + if query is not None: + break + if query is None: + raise ValueError(f"Unknown request type: {data}") + result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query}) + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node] + print("response is ", response) + return response