From 93babf6433aa247a15c208b32ac50ca2e77b18d1 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Mon, 19 Aug 2024 11:38:32 -0500 Subject: [PATCH] Add Retrieval gateway in core to support IndexRetrivel Megaservice (#314) * Add Retrieval gateway Signed-off-by: Chendi.Xue * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update gateway to latest protocal Signed-off-by: Chendi.Xue * tested with DocIndexer, rebased gateway is now workable Signed-off-by: Chendi.Xue * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Chendi.Xue Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: siddhivelankar23 --- comps/__init__.py | 1 + comps/cores/mega/constants.py | 1 + comps/cores/mega/gateway.py | 41 ++++++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/comps/__init__.py b/comps/__init__.py index 01976be45..3b51046d9 100644 --- a/comps/__init__.py +++ b/comps/__init__.py @@ -40,6 +40,7 @@ TranslationGateway, SearchQnAGateway, AudioQnAGateway, + RetrievalToolGateway, FaqGenGateway, VisualQnAGateway, ) diff --git a/comps/cores/mega/constants.py b/comps/cores/mega/constants.py index 05eab5284..10863c149 100644 --- a/comps/cores/mega/constants.py +++ b/comps/cores/mega/constants.py @@ -43,6 +43,7 @@ class MegaServiceEndpoint(Enum): DOC_SUMMARY = "/v1/docsum" SEARCH_QNA = "/v1/searchqna" TRANSLATION = "/v1/translation" + RETRIEVALTOOL = "/v1/retrievaltool" FAQ_GEN = "/v1/faqgen" # Follow OPENAI EMBEDDINGS = "/v1/embeddings" diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 6eb069e6e..cc8eaf5d2 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -4,6 +4,7 @@ import base64 import os from io import BytesIO +from typing import Union import requests from fastapi import Request @@ -16,9 +17,10 @@ ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, + EmbeddingRequest, UsageInfo, ) -from ..proto.docarray import LLMParams +from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType from .micro_service import MicroService @@ -529,3 +531,40 @@ async def handle_request(self, request: Request): ) ) return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage) + + +class RetrievalToolGateway(Gateway): + """embed+retrieve+rerank.""" + + def __init__(self, megaservice, host="0.0.0.0", port=8889): + super().__init__( + megaservice, + host, + port, + str(MegaServiceEndpoint.RETRIEVALTOOL), + Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], # ChatCompletionRequest, + Union[RerankedDoc, LLMParamsDoc], # ChatCompletionResponse + ) + + async def handle_request(self, request: Request): + def parser_input(data, TypeClass, key): + try: + chat_request = TypeClass.parse_obj(data) + query = getattr(chat_request, key) + except: + query = None + return query + + data = await request.json() + query = None + for key, TypeClass in zip(["text", "input", "input"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]): + query = parser_input(data, TypeClass, key) + if query is not None: + break + if query is None: + raise ValueError(f"Unknown request type: {data}") + result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query}) + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node] + print("response is ", response) + return response