Skip to content

Commit

Permalink
Add Retrieval gateway in core to support IndexRetrivel Megaservice (#314
Browse files Browse the repository at this point in the history
)

* Add Retrieval gateway

Signed-off-by: Chendi.Xue <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update gateway to latest protocal

Signed-off-by: Chendi.Xue <[email protected]>

* tested with DocIndexer, rebased gateway is now workable

Signed-off-by: Chendi.Xue <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 56daf95)
  • Loading branch information
xuechendi authored and chensuyue committed Aug 21, 2024
1 parent f5f360d commit 72b60bb
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
1 change: 1 addition & 0 deletions comps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TranslationGateway,
SearchQnAGateway,
AudioQnAGateway,
RetrievalToolGateway,
FaqGenGateway,
VisualQnAGateway,
)
Expand Down
1 change: 1 addition & 0 deletions comps/cores/mega/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class MegaServiceEndpoint(Enum):
DOC_SUMMARY = "/v1/docsum"
SEARCH_QNA = "/v1/searchqna"
TRANSLATION = "/v1/translation"
RETRIEVALTOOL = "/v1/retrievaltool"
FAQ_GEN = "/v1/faqgen"
# Follow OPENAI
EMBEDDINGS = "/v1/embeddings"
Expand Down
41 changes: 40 additions & 1 deletion comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import base64
import os
from io import BytesIO
from typing import Union

import requests
from fastapi import Request
Expand All @@ -16,9 +17,10 @@
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
EmbeddingRequest,
UsageInfo,
)
from ..proto.docarray import LLMParams
from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc
from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType
from .micro_service import MicroService

Expand Down Expand Up @@ -529,3 +531,40 @@ async def handle_request(self, request: Request):
)
)
return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)


class RetrievalToolGateway(Gateway):
"""embed+retrieve+rerank."""

def __init__(self, megaservice, host="0.0.0.0", port=8889):
super().__init__(
megaservice,
host,
port,
str(MegaServiceEndpoint.RETRIEVALTOOL),
Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], # ChatCompletionRequest,
Union[RerankedDoc, LLMParamsDoc], # ChatCompletionResponse
)

async def handle_request(self, request: Request):
def parser_input(data, TypeClass, key):
try:
chat_request = TypeClass.parse_obj(data)
query = getattr(chat_request, key)
except:
query = None
return query

data = await request.json()
query = None
for key, TypeClass in zip(["text", "input", "input"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
query = parser_input(data, TypeClass, key)
if query is not None:
break
if query is None:
raise ValueError(f"Unknown request type: {data}")
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query})
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]
print("response is ", response)
return response

0 comments on commit 72b60bb

Please sign in to comment.