Add Megaservice support for MMRAG VideoRAGQnA usecase #603

Merged: 14 commits, Sep 9, 2024
1 change: 1 addition & 0 deletions comps/__init__.py
@@ -45,6 +45,7 @@
    AudioQnAGateway,
    RetrievalToolGateway,
    FaqGenGateway,
    VideoRAGQnAGateway,
    VisualQnAGateway,
)

1 change: 1 addition & 0 deletions comps/cores/mega/constants.py
@@ -38,6 +38,7 @@ class MegaServiceEndpoint(Enum):
    CHAT_QNA = "/v1/chatqna"
    AUDIO_QNA = "/v1/audioqna"
    VISUAL_QNA = "/v1/visualqna"
    VIDEO_RAG_QNA = "/v1/videoragqna"
    CODE_GEN = "/v1/codegen"
    CODE_TRANS = "/v1/codetrans"
    DOC_SUMMARY = "/v1/docsum"
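As a quick sanity check, the sketch below shows how the new enum member is expected to resolve to the route string. It assumes MegaServiceEndpoint implements __str__ to return its value, which is implied by the gateway passing str(MegaServiceEndpoint.VIDEO_RAG_QNA) as its endpoint below.

# Hedged sketch: the gateway derives its route from the enum value.
from comps.cores.mega.constants import MegaServiceEndpoint

print(str(MegaServiceEndpoint.VIDEO_RAG_QNA))  # expected: /v1/videoragqna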
49 changes: 49 additions & 0 deletions comps/cores/mega/gateway.py
@@ -547,6 +547,55 @@
        return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)


class VideoRAGQnAGateway(Gateway):
    def __init__(self, megaservice, host="0.0.0.0", port=8888):
        super().__init__(
            megaservice,
            host,
            port,
            str(MegaServiceEndpoint.VIDEO_RAG_QNA),
            ChatCompletionRequest,
            ChatCompletionResponse,
        )

    async def handle_request(self, request: Request):
        data = await request.json()
        stream_opt = data.get("stream", False)
        chat_request = ChatCompletionRequest.parse_obj(data)
        prompt = self._handle_message(chat_request.messages)
        parameters = LLMParams(
            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
            top_k=chat_request.top_k if chat_request.top_k else 10,
            top_p=chat_request.top_p if chat_request.top_p else 0.95,
            temperature=chat_request.temperature if chat_request.temperature else 0.01,
            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
            streaming=stream_opt,
        )
        result_dict, runtime_graph = await self.megaservice.schedule(
            initial_inputs={"text": prompt}, llm_parameters=parameters
        )
        for node, response in result_dict.items():
            # Here it is assumed that the last microservice in the megaservice is the LVM.
            if (
                isinstance(response, StreamingResponse)
                and node == list(self.megaservice.services.keys())[-1]
                and self.megaservice.services[node].service_type == ServiceType.LVM
            ):
                return response
        last_node = runtime_graph.all_leaves()[-1]
        response = result_dict[last_node]["text"]
        choices = []
        usage = UsageInfo()
        choices.append(
            ChatCompletionResponseChoice(
                index=0,
                message=ChatMessage(role="assistant", content=response),
                finish_reason="stop",
            )
        )
        return ChatCompletionResponse(model="videoragqna", choices=choices, usage=usage)


class RetrievalToolGateway(Gateway):
    """embed+retrieve+rerank."""

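For reference, here is a minimal sketch of how a client might exercise the new gateway once the VideoRAGQnA megaservice is running. The host and port (localhost:8888, the gateway's default) and the use of the requests library are assumptions; the payload simply mirrors the ChatCompletionRequest fields that handle_request reads above.

# Hypothetical client call against the new /v1/videoragqna route.
import requests

payload = {
    "messages": [{"role": "user", "content": "What is happening in the video?"}],
    "max_tokens": 512,
    "stream": False,
}
resp = requests.post("http://localhost:8888/v1/videoragqna", json=payload, timeout=120)
print(resp.json())  # expect a ChatCompletionResponse with model="videoragqna"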
@@ -0,0 +1,80 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest

from fastapi.responses import StreamingResponse

from comps import (
    ServiceOrchestrator,
    ServiceType,
    TextDoc,
    VideoRAGQnAGateway,
    opea_microservices,
    register_microservice,
)
from comps.cores.proto.docarray import LLMParams


@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add")
async def s1_add(request: TextDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]
    text += "opea "
    return {"text": text}


@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add", service_type=ServiceType.LVM)
async def s2_add(request: TextDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]

    def streamer(text):
        yield f"{text}".encode("utf-8")
        for i in range(3):
            yield "project!".encode("utf-8")

    return StreamingResponse(streamer(text), media_type="text/event-stream")


class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        self.s1 = opea_microservices["s1"]
        self.s2 = opea_microservices["s2"]
        self.s1.start()
        self.s2.start()

        self.service_builder = ServiceOrchestrator()

        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
        self.service_builder.flow_to(self.s1, self.s2)
        self.gateway = VideoRAGQnAGateway(self.service_builder, port=9898)

    def tearDown(self):
        self.s1.stop()
        self.s2.stop()
        self.gateway.stop()

    async def test_schedule(self):
        result_dict, _ = await self.service_builder.schedule(
            initial_inputs={"text": "hello, "}, llm_parameters=LLMParams(streaming=True)
        )
        streaming_response = result_dict[self.s2.name]

        if isinstance(streaming_response, StreamingResponse):
            content = b""
            async for chunk in streaming_response.body_iterator:
                content += chunk
            final_text = content.decode("utf-8")

        print("Streamed content from s2: ", final_text)

        expected_result = "hello, opea project!project!project!"
        self.assertEqual(final_text, expected_result)


if __name__ == "__main__":
    unittest.main()