From e18369ba0d7f9ab5ddc35ed6dedcf03899766902 Mon Sep 17 00:00:00 2001
From: lkk <33276950+lkk12014402@users.noreply.github.com>
Date: Sat, 14 Dec 2024 13:19:51 +0800
Subject: [PATCH] remove examples gateway. (#1250)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 AudioQnA/audioqna.py                | 14 +++++++++-----
 AudioQnA/audioqna_multilang.py      | 15 ++++++++++-----
 AvatarChatbot/avatarchatbot.py      | 14 +++++++++-----
 CodeGen/codegen.py                  | 17 +++++++++++------
 CodeTrans/code_translation.py       | 14 +++++++++-----
 DocIndexRetriever/retrieval_tool.py | 14 +++++++++-----
 EdgeCraftRAG/chatqna.py             | 14 +++++++++-----
 GraphRAG/graphrag.py                | 17 +++++++++++------
 MultimodalQnA/multimodalqna.py      | 15 +++++++++------
 SearchQnA/searchqna.py              | 17 +++++++++++------
 Translation/translation.py          | 14 +++++++++-----
 VideoQnA/videoqna.py                | 17 +++++++++++------
 VisualQnA/visualqna.py              | 17 +++++++++++------
 13 files changed, 128 insertions(+), 71 deletions(-)

diff --git a/AudioQnA/audioqna.py b/AudioQnA/audioqna.py
index fbede59489..8e2e68fbe8 100644
--- a/AudioQnA/audioqna.py
+++ b/AudioQnA/audioqna.py
@@ -4,7 +4,7 @@
 import asyncio
 import os
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
 from comps.cores.proto.docarray import LLMParams
 from fastapi import Request
@@ -18,11 +18,12 @@
 TTS_SERVICE_PORT = int(os.getenv("TTS_SERVICE_PORT", 9088))
 
 
-class AudioQnAService(Gateway):
+class AudioQnAService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.AUDIO_QNA)
 
     def add_remote_service(self):
         asr = MicroService(
@@ -78,14 +79,17 @@ async def handle_request(self, request: Request):
         return response
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.AUDIO_QNA),
+            endpoint=self.endpoint,
             input_datatype=AudioChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/AudioQnA/audioqna_multilang.py b/AudioQnA/audioqna_multilang.py
index 33a1e1d61a..7d8c1ae801 100644
--- a/AudioQnA/audioqna_multilang.py
+++ b/AudioQnA/audioqna_multilang.py
@@ -5,7 +5,7 @@
 import base64
 import os
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
 from comps.cores.proto.docarray import LLMParams
 from fastapi import Request
@@ -54,7 +54,7 @@ def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_di
     return data
 
 
-class AudioQnAService(Gateway):
+class AudioQnAService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -62,6 +62,8 @@ def __init__(self, host="0.0.0.0", port=8000):
         ServiceOrchestrator.align_outputs = align_outputs
         self.megaservice = ServiceOrchestrator()
 
+        self.endpoint = str(MegaServiceEndpoint.AUDIO_QNA)
+
     def add_remote_service(self):
         asr = MicroService(
             name="asr",
@@ -118,14 +120,17 @@ async def handle_request(self, request: Request):
         return response
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.AUDIO_QNA),
+            endpoint=self.endpoint,
             input_datatype=AudioChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/AvatarChatbot/avatarchatbot.py b/AvatarChatbot/avatarchatbot.py
index 0893fc6f93..54bd0d29af 100644
--- a/AvatarChatbot/avatarchatbot.py
+++ b/AvatarChatbot/avatarchatbot.py
@@ -5,7 +5,7 @@
 import os
 import sys
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
 from comps.cores.proto.docarray import LLMParams
 from fastapi import Request
@@ -29,11 +29,12 @@ def check_env_vars(env_var_list):
     print("All environment variables are set.")
 
 
-class AvatarChatbotService(Gateway):
+class AvatarChatbotService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.AVATAR_CHATBOT)
 
     def add_remote_service(self):
         asr = MicroService(
@@ -97,14 +98,17 @@ async def handle_request(self, request: Request):
         return response
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.AVATAR_CHATBOT),
+            endpoint=self.endpoint,
             input_datatype=AudioChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/CodeGen/codegen.py b/CodeGen/codegen.py
index 5ae4329d08..9769d682da 100644
--- a/CodeGen/codegen.py
+++ b/CodeGen/codegen.py
@@ -4,7 +4,8 @@
 import asyncio
 import os
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
+from comps.cores.mega.utils import handle_message
 from comps.cores.proto.api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -21,11 +22,12 @@
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class CodeGenService(Gateway):
+class CodeGenService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.CODE_GEN)
 
     def add_remote_service(self):
         llm = MicroService(
@@ -42,7 +44,7 @@ async def handle_request(self, request: Request):
         data = await request.json()
         stream_opt = data.get("stream", True)
         chat_request = ChatCompletionRequest.parse_obj(data)
-        prompt = self._handle_message(chat_request.messages)
+        prompt = handle_message(chat_request.messages)
         parameters = LLMParams(
             max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -78,14 +80,17 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="codegen", choices=choices, usage=usage)
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.CODE_GEN),
+            endpoint=self.endpoint,
             input_datatype=ChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/CodeTrans/code_translation.py b/CodeTrans/code_translation.py
index c163c847fa..5e6ba79408 100644
--- a/CodeTrans/code_translation.py
+++ b/CodeTrans/code_translation.py
@@ -4,7 +4,7 @@
 import asyncio
 import os
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.proto.api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -20,11 +20,12 @@
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class CodeTransService(Gateway):
+class CodeTransService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.CODE_TRANS)
 
     def add_remote_service(self):
         llm = MicroService(
@@ -77,14 +78,17 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="codetrans", choices=choices, usage=usage)
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.CODE_TRANS),
+            endpoint=self.endpoint,
             input_datatype=ChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/DocIndexRetriever/retrieval_tool.py b/DocIndexRetriever/retrieval_tool.py
index 9581612a50..3c3d8a5571 100644
--- a/DocIndexRetriever/retrieval_tool.py
+++ b/DocIndexRetriever/retrieval_tool.py
@@ -6,7 +6,7 @@
 import os
 from typing import Union
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest
 from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
 from fastapi import Request
@@ -21,11 +21,12 @@
 RERANK_SERVICE_PORT = os.getenv("RERANK_SERVICE_PORT", 8000)
 
 
-class RetrievalToolService(Gateway):
+class RetrievalToolService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.RETRIEVALTOOL)
 
     def add_remote_service(self):
         embedding = MicroService(
@@ -116,14 +117,17 @@ def parser_input(data, TypeClass, key):
         return response
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.RETRIEVALTOOL),
+            endpoint=self.endpoint,
             input_datatype=Union[TextDoc, EmbeddingRequest, ChatCompletionRequest],
             output_datatype=Union[RerankedDoc, LLMParamsDoc],
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
     def add_remote_service_without_rerank(self):
         embedding = MicroService(
diff --git a/EdgeCraftRAG/chatqna.py b/EdgeCraftRAG/chatqna.py
index 31c701b5fa..d9441d09f9 100644
--- a/EdgeCraftRAG/chatqna.py
+++ b/EdgeCraftRAG/chatqna.py
@@ -9,7 +9,7 @@
 PIPELINE_SERVICE_HOST_IP = os.getenv("PIPELINE_SERVICE_HOST_IP", "127.0.0.1")
 PIPELINE_SERVICE_PORT = int(os.getenv("PIPELINE_SERVICE_PORT", 16010))
 
-from comps import Gateway, MegaServiceEndpoint
+from comps import MegaServiceEndpoint, ServiceRoleType
 from comps.cores.proto.api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -22,11 +22,12 @@
 from fastapi.responses import StreamingResponse
 
 
-class EdgeCraftRagService(Gateway):
+class EdgeCraftRagService:
     def __init__(self, host="0.0.0.0", port=16010):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.CHAT_QNA)
 
     def add_remote_service(self):
         edgecraftrag = MicroService(
@@ -72,14 +73,17 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="edgecraftrag", choices=choices, usage=usage)
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.CHAT_QNA),
+            endpoint=self.endpoint,
             input_datatype=ChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/GraphRAG/graphrag.py b/GraphRAG/graphrag.py
index f675095147..6d3e3b9829 100644
--- a/GraphRAG/graphrag.py
+++ b/GraphRAG/graphrag.py
@@ -6,7 +6,8 @@
 import os
 import re
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
+from comps.cores.mega.utils import handle_message
 from comps.cores.proto.api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -127,7 +128,7 @@ def align_generator(self, gen, **kwargs):
     yield "data: [DONE]\n\n"
 
 
-class GraphRAGService(Gateway):
+class GraphRAGService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
@@ -135,6 +136,7 @@ def __init__(self, host="0.0.0.0", port=8000):
         ServiceOrchestrator.align_outputs = align_outputs
         ServiceOrchestrator.align_generator = align_generator
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.GRAPH_RAG)
 
     def add_remote_service(self):
         retriever = MicroService(
@@ -180,7 +182,7 @@ def parser_input(data, TypeClass, key):
             raise ValueError(f"Unknown request type: {data}")
         if chat_request is None:
             raise ValueError(f"Unknown request type: {data}")
-        prompt = self._handle_message(chat_request.messages)
+        prompt = handle_message(chat_request.messages)
         parameters = LLMParams(
             max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -223,14 +225,17 @@ def parser_input(data, TypeClass, key):
         return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.GRAPH_RAG),
+            endpoint=self.endpoint,
             input_datatype=ChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/MultimodalQnA/multimodalqna.py b/MultimodalQnA/multimodalqna.py
index 87565a5b8a..fb70128613 100644
--- a/MultimodalQnA/multimodalqna.py
+++ b/MultimodalQnA/multimodalqna.py
@@ -7,7 +7,7 @@
 from io import BytesIO
 
 import requests
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.proto.api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -29,7 +29,7 @@
 LVM_SERVICE_PORT = int(os.getenv("LVM_SERVICE_PORT", 9399))
 
 
-class MultimodalQnAService(Gateway):
+class MultimodalQnAService:
     asr_port = int(os.getenv("ASR_SERVICE_PORT", 3001))
     asr_endpoint = os.getenv("ASR_SERVICE_ENDPOINT", "http://0.0.0.0:{}/v1/audio/transcriptions".format(asr_port))
 
@@ -38,6 +38,7 @@ def __init__(self, host="0.0.0.0", port=8000):
         self.port = port
         self.lvm_megaservice = ServiceOrchestrator()
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.MULTIMODAL_QNA)
 
     def add_remote_service(self):
         mm_embedding = MicroService(
@@ -74,7 +75,6 @@ def add_remote_service(self):
         # for lvm megaservice
         self.lvm_megaservice.add(lvm)
 
-    # this overrides _handle_message method of Gateway
     def _handle_message(self, messages):
         images = []
         audios = []
@@ -303,14 +303,17 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="multimodalqna", choices=choices, usage=usage)
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.MULTIMODAL_QNA),
+            endpoint=self.endpoint,
             input_datatype=ChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/SearchQnA/searchqna.py b/SearchQnA/searchqna.py
index 1a04a97a60..318aae3a5e 100644
--- a/SearchQnA/searchqna.py
+++ b/SearchQnA/searchqna.py
@@ -3,7 +3,8 @@
 
 import os
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
+from comps.cores.mega.utils import handle_message
 from comps.cores.proto.api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -26,11 +27,12 @@
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class SearchQnAService(Gateway):
+class SearchQnAService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.SEARCH_QNA)
 
     def add_remote_service(self):
         embedding = MicroService(
@@ -74,7 +76,7 @@ async def handle_request(self, request: Request):
         data = await request.json()
         stream_opt = data.get("stream", True)
         chat_request = ChatCompletionRequest.parse_obj(data)
-        prompt = self._handle_message(chat_request.messages)
+        prompt = handle_message(chat_request.messages)
         parameters = LLMParams(
             max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -110,14 +112,17 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="searchqna", choices=choices, usage=usage)
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.SEARCH_QNA),
+            endpoint=self.endpoint,
             input_datatype=ChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/Translation/translation.py b/Translation/translation.py
index 92e8b6f0dd..8a5d8aad6a 100644
--- a/Translation/translation.py
+++ b/Translation/translation.py
@@ -15,7 +15,7 @@
 import asyncio
 import os
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.proto.api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -31,11 +31,12 @@
 LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000))
 
 
-class TranslationService(Gateway):
+class TranslationService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.TRANSLATION)
 
     def add_remote_service(self):
         llm = MicroService(
@@ -87,14 +88,17 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="translation", choices=choices, usage=usage)
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.TRANSLATION),
+            endpoint=self.endpoint,
             input_datatype=ChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/VideoQnA/videoqna.py b/VideoQnA/videoqna.py
index 7632a1c10b..3b699faa7a 100644
--- a/VideoQnA/videoqna.py
+++ b/VideoQnA/videoqna.py
@@ -3,7 +3,8 @@
 
 import os
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
+from comps.cores.mega.utils import handle_message
 from comps.cores.proto.api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -26,11 +27,12 @@
 LVM_SERVICE_PORT = int(os.getenv("LVM_SERVICE_PORT", 9000))
 
 
-class VideoQnAService(Gateway):
+class VideoQnAService:
    def __init__(self, host="0.0.0.0", port=8888):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.VIDEO_RAG_QNA)
 
     def add_remote_service(self):
         embedding = MicroService(
@@ -74,7 +76,7 @@ async def handle_request(self, request: Request):
         data = await request.json()
         stream_opt = data.get("stream", False)
         chat_request = ChatCompletionRequest.parse_obj(data)
-        prompt = self._handle_message(chat_request.messages)
+        prompt = handle_message(chat_request.messages)
         parameters = LLMParams(
             max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -110,14 +112,17 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="videoqna", choices=choices, usage=usage)
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.VIDEO_RAG_QNA),
+            endpoint=self.endpoint,
             input_datatype=ChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
diff --git a/VisualQnA/visualqna.py b/VisualQnA/visualqna.py
index f6519c1d27..312239615a 100644
--- a/VisualQnA/visualqna.py
+++ b/VisualQnA/visualqna.py
@@ -3,7 +3,8 @@
 
 import os
 
-from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
+from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
+from comps.cores.mega.utils import handle_message
 from comps.cores.proto.api_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -20,11 +21,12 @@
 LVM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9399))
 
 
-class VisualQnAService(Gateway):
+class VisualQnAService:
     def __init__(self, host="0.0.0.0", port=8000):
         self.host = host
         self.port = port
         self.megaservice = ServiceOrchestrator()
+        self.endpoint = str(MegaServiceEndpoint.VISUAL_QNA)
 
     def add_remote_service(self):
         llm = MicroService(
@@ -41,7 +43,7 @@ async def handle_request(self, request: Request):
         data = await request.json()
         stream_opt = data.get("stream", False)
         chat_request = ChatCompletionRequest.parse_obj(data)
-        prompt, images = self._handle_message(chat_request.messages)
+        prompt, images = handle_message(chat_request.messages)
         parameters = LLMParams(
             max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -77,14 +79,17 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)
 
     def start(self):
-        super().__init__(
-            megaservice=self.megaservice,
+        self.service = MicroService(
+            self.__class__.__name__,
+            service_role=ServiceRoleType.MEGASERVICE,
             host=self.host,
             port=self.port,
-            endpoint=str(MegaServiceEndpoint.VISUAL_QNA),
+            endpoint=self.endpoint,
             input_datatype=ChatCompletionRequest,
             output_datatype=ChatCompletionResponse,
         )
+        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
+        self.service.start()
 
 
 if __name__ == "__main__":
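
Every file in this patch applies the same refactor: the removed Gateway base class is replaced by a plain service class that stores its endpoint, builds a MEGASERVICE-role MicroService, registers its own POST route, and starts it. The following minimal sketch of that pattern is not part of the patch; the ExampleService name and the bare-bones handle_request body are illustrative assumptions, while the comps calls mirror those visible in the hunks above.

# Sketch of the post-Gateway wiring (illustrative; the comps calls below mirror the diff,
# but ExampleService and the trivial handler are hypothetical stand-ins).
from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType
from comps.cores.proto.api_protocol import ChatCompletionRequest, ChatCompletionResponse
from fastapi import Request


class ExampleService:
    def __init__(self, host="0.0.0.0", port=8000):
        self.host = host
        self.port = port
        self.megaservice = ServiceOrchestrator()
        # The endpoint moves from a Gateway constructor argument to an instance attribute.
        self.endpoint = str(MegaServiceEndpoint.CHAT_QNA)

    async def handle_request(self, request: Request):
        # Minimal body for illustration; the real handlers first add remote services to
        # the orchestrator, build LLMParams, and post-process the scheduled result.
        data = await request.json()
        result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs=data)
        return result_dict

    def start(self):
        # Instead of calling super().__init__() on Gateway, the class now owns a
        # MEGASERVICE-role MicroService, adds its own POST route, and starts it.
        self.service = MicroService(
            self.__class__.__name__,
            service_role=ServiceRoleType.MEGASERVICE,
            host=self.host,
            port=self.port,
            endpoint=self.endpoint,
            input_datatype=ChatCompletionRequest,
            output_datatype=ChatCompletionResponse,
        )
        self.service.add_route(self.endpoint, self.handle_request, methods=["POST"])
        self.service.start()


if __name__ == "__main__":
    service = ExampleService()
    service.start()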