diff --git a/comps/__init__.py b/comps/__init__.py index cb7ed7a287..15fb64e7b0 100644 --- a/comps/__init__.py +++ b/comps/__init__.py @@ -36,6 +36,7 @@ TranslationGateway, SearchQnAGateway, AudioQnAGateway, + RetrievalToolGateway, FaqGenGateway, VisualQnAGateway, ) diff --git a/comps/cores/mega/constants.py b/comps/cores/mega/constants.py index 05eab52844..10863c1495 100644 --- a/comps/cores/mega/constants.py +++ b/comps/cores/mega/constants.py @@ -43,6 +43,7 @@ class MegaServiceEndpoint(Enum): DOC_SUMMARY = "/v1/docsum" SEARCH_QNA = "/v1/searchqna" TRANSLATION = "/v1/translation" + RETRIEVALTOOL = "/v1/retrievaltool" FAQ_GEN = "/v1/faqgen" # Follow OPENAI EMBEDDINGS = "/v1/embeddings" diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 6eb069e6ec..cc8eaf5d22 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -4,6 +4,7 @@ import base64 import os from io import BytesIO +from typing import Union import requests from fastapi import Request @@ -16,9 +17,10 @@ ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, + EmbeddingRequest, UsageInfo, ) -from ..proto.docarray import LLMParams +from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType from .micro_service import MicroService @@ -529,3 +531,40 @@ async def handle_request(self, request: Request): ) ) return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage) + + +class RetrievalToolGateway(Gateway): + """embed+retrieve+rerank.""" + + def __init__(self, megaservice, host="0.0.0.0", port=8889): + super().__init__( + megaservice, + host, + port, + str(MegaServiceEndpoint.RETRIEVALTOOL), + Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], # ChatCompletionRequest, + Union[RerankedDoc, LLMParamsDoc], # ChatCompletionResponse + ) + + async def handle_request(self, request: Request): + def parser_input(data, TypeClass, key): + try: + chat_request = TypeClass.parse_obj(data) + query = getattr(chat_request, key) + except: + query = None + return query + + data = await request.json() + query = None + for key, TypeClass in zip(["text", "input", "input"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]): + query = parser_input(data, TypeClass, key) + if query is not None: + break + if query is None: + raise ValueError(f"Unknown request type: {data}") + result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query}) + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node] + print("response is ", response) + return response