Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Retrieval gateway in core to support IndexRetrieval Megaservice #314

Merged
merged 5 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions comps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TranslationGateway,
SearchQnAGateway,
AudioQnAGateway,
RetrievalToolGateway,
FaqGenGateway,
VisualQnAGateway,
)
Expand Down
1 change: 1 addition & 0 deletions comps/cores/mega/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class MegaServiceEndpoint(Enum):
DOC_SUMMARY = "/v1/docsum"
SEARCH_QNA = "/v1/searchqna"
TRANSLATION = "/v1/translation"
RETRIEVALTOOL = "/v1/retrievaltool"
FAQ_GEN = "/v1/faqgen"
# Follow OPENAI
EMBEDDINGS = "/v1/embeddings"
Expand Down
41 changes: 40 additions & 1 deletion comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import base64
import os
from io import BytesIO
from typing import Union

import requests
from fastapi import Request
Expand All @@ -16,9 +17,10 @@
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
EmbeddingRequest,
UsageInfo,
)
from ..proto.docarray import LLMParams
from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc
from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType
from .micro_service import MicroService

Expand Down Expand Up @@ -529,3 +531,40 @@
)
)
return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)


class RetrievalToolGateway(Gateway):
"""embed+retrieve+rerank."""

def __init__(self, megaservice, host="0.0.0.0", port=8889):
super().__init__(

Check warning on line 540 in comps/cores/mega/gateway.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/gateway.py#L540

Added line #L540 was not covered by tests
megaservice,
host,
port,
str(MegaServiceEndpoint.RETRIEVALTOOL),
Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], # ChatCompletionRequest,
Union[RerankedDoc, LLMParamsDoc], # ChatCompletionResponse
)

async def handle_request(self, request: Request):
def parser_input(data, TypeClass, key):
    """Parse ``data`` with ``TypeClass`` and return the attribute named ``key``.

    Args:
        data: Decoded request payload handed to ``TypeClass.parse_obj``.
        TypeClass: Candidate request model; only its ``parse_obj`` method is used.
        key: Name of the attribute on the parsed object that holds the query.

    Returns:
        The value of attribute ``key`` on the parsed object, or ``None`` when
        parsing fails or the parsed object has no such attribute.
    """
    try:
        return getattr(TypeClass.parse_obj(data), key)
    except Exception:
        # A mismatch is reported as None so the caller can try another schema.
        # Only Exception is caught: a bare ``except`` would also swallow
        # KeyboardInterrupt, SystemExit and asyncio.CancelledError.
        return None

Check warning on line 556 in comps/cores/mega/gateway.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/gateway.py#L550-L556

Added lines #L550 - L556 were not covered by tests

data = await request.json()
query = None
for key, TypeClass in zip(["text", "input", "input"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
query = parser_input(data, TypeClass, key)
if query is not None:
break
if query is None:
raise ValueError(f"Unknown request type: {data}")
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query})
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]
print("response is ", response)
return response

Check warning on line 570 in comps/cores/mega/gateway.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/gateway.py#L558-L570

Added lines #L558 - L570 were not covered by tests