Add Megaservice support for MMRAG VideoRAGQnA usecase #603

Merged: 14 commits, Sep 9, 2024
1 change: 1 addition & 0 deletions comps/__init__.py
@@ -45,6 +45,7 @@
    AudioQnAGateway,
    RetrievalToolGateway,
    FaqGenGateway,
    VideoRAGQnAGateway,
    VisualQnAGateway,
    MultimodalRAGWithVideosGateway,
)
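With this export in place, the new gateway is importable straight from the top-level package, which is exactly what the unit test at the bottom of this diff relies on:

from comps import VideoRAGQnAGateway  # now re-exported at package level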
1 change: 1 addition & 0 deletions comps/cores/mega/constants.py
@@ -38,6 +38,7 @@ class MegaServiceEndpoint(Enum):
    CHAT_QNA = "/v1/chatqna"
    AUDIO_QNA = "/v1/audioqna"
    VISUAL_QNA = "/v1/visualqna"
    VIDEO_RAG_QNA = "/v1/videoragqna"
    CODE_GEN = "/v1/codegen"
    CODE_TRANS = "/v1/codetrans"
    DOC_SUMMARY = "/v1/docsum"
49 changes: 49 additions & 0 deletions comps/cores/mega/gateway.py
@@ -548,6 +548,55 @@
return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)


class VideoRAGQnAGateway(Gateway):
def __init__(self, megaservice, host="0.0.0.0", port=8888):
super().__init__(
megaservice,
host,
port,
str(MegaServiceEndpoint.VIDEO_RAG_QNA),
ChatCompletionRequest,
ChatCompletionResponse,
)

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", False)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(

Check warning on line 567 in comps/cores/mega/gateway.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/gateway.py#L563-L567

Added lines #L563 - L567 were not covered by tests
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(

Check warning on line 575 in comps/cores/mega/gateway.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/gateway.py#L575

Added line #L575 was not covered by tests
initial_inputs={"text": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():

Check warning on line 578 in comps/cores/mega/gateway.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/gateway.py#L578

Added line #L578 was not covered by tests
# Here it suppose the last microservice in the megaservice is LVM.
if (

Check warning on line 580 in comps/cores/mega/gateway.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/gateway.py#L580

Added line #L580 was not covered by tests
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LVM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(

Check warning on line 590 in comps/cores/mega/gateway.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/gateway.py#L585-L590

Added lines #L585 - L590 were not covered by tests
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="videoragqna", choices=choices, usage=usage)

Check warning on line 597 in comps/cores/mega/gateway.py

View check run for this annotation

Codecov / codecov/patch

comps/cores/mega/gateway.py#L597

Added line #L597 was not covered by tests


class RetrievalToolGateway(Gateway):
"""embed+retrieve+rerank."""

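For context, a request to the new endpoint could look like the sketch below. This is illustrative only: it assumes the gateway is reachable on its default port 8888 and that the body carries the OpenAI-style ChatCompletionRequest fields the handler reads (messages, stream, max_tokens, and so on).

# Hypothetical client sketch, not part of this PR.
# Assumes a VideoRAGQnA megaservice with this gateway is running on localhost:8888.
import requests

payload = {
    "messages": "What did the man wearing a red hat do after entering the store?",
    "stream": False,  # set to True to get the LVM output back as a streaming response
    "max_tokens": 512,
}
resp = requests.post("http://localhost:8888/v1/videoragqna", json=payload)
print(resp.json())  # ChatCompletionResponse with model="videoragqna"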
3 changes: 2 additions & 1 deletion comps/embeddings/multimodal_clip/embeddings_clip.py
@@ -27,7 +27,8 @@ def embed_query(self, texts):
        return text_features

    def get_embedding_length(self):
-        return len(self.embed_query("sample_text"))
+        text_features = self.embed_query("sample_text")
+        return text_features.shape[1]

    def get_image_embeddings(self, images):
        """Input is list of images."""
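The get_embedding_length() fix matters because embed_query() returns a 2-D tensor shaped (number of texts, embedding dimension); calling len() on it yields the row count (1 for a single query), not the vector width, so the retriever was previously initialized with the wrong dimensionality. A quick illustration, assuming a (1, 512) output for clip-vit-base-patch32:

# Illustration only; "clip_embedder" is a hypothetical vCLIP instance, and the
# (1, 512) shape assumes clip-vit-base-patch32.
text_features = clip_embedder.embed_query("sample_text")  # tensor of shape (1, 512)
len(text_features)         # 1   -> number of query rows (what the old code returned)
text_features.shape[1]     # 512 -> the actual embedding dimension the VDMS index needs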
2 changes: 1 addition & 1 deletion comps/retrievers/langchain/vdms/retriever_vdms.py
@@ -89,7 +89,7 @@ def retrieve(input: EmbedDoc) -> SearchedMultimodalDoc:
# Create vectorstore

if use_clip:
-    embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 4})
+    embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 64})
    dimensions = embeddings.get_embedding_length()
elif tei_embedding_endpoint:
    embeddings = HuggingFaceEndpointEmbeddings(model=tei_embedding_endpoint, huggingfacehub_api_token=hf_token)
2 changes: 1 addition & 1 deletion comps/retrievers/langchain/vdms/vdms_config.py
@@ -77,4 +77,4 @@ def get_boolean_env_var(var_name, default_value=False):
# VDMS_SCHEMA = os.getenv("VDMS_SCHEMA", "vdms_schema.yml")
# INDEX_SCHEMA = os.path.join(parent_dir, VDMS_SCHEMA)
SEARCH_ENGINE = "FaissFlat"
DISTANCE_STRATEGY = "L2"
DISTANCE_STRATEGY = "IP"
@@ -0,0 +1,80 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest

from fastapi.responses import StreamingResponse

from comps import (
    ServiceOrchestrator,
    ServiceType,
    TextDoc,
    VideoRAGQnAGateway,
    opea_microservices,
    register_microservice,
)
from comps.cores.proto.docarray import LLMParams


@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add")
async def s1_add(request: TextDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]
    text += "opea "
    return {"text": text}


@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add", service_type=ServiceType.LVM)
async def s2_add(request: TextDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]

    def streamer(text):
        yield f"{text}".encode("utf-8")
        for i in range(3):
            yield "project!".encode("utf-8")

    return StreamingResponse(streamer(text), media_type="text/event-stream")


class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        self.s1 = opea_microservices["s1"]
        self.s2 = opea_microservices["s2"]
        self.s1.start()
        self.s2.start()

        self.service_builder = ServiceOrchestrator()

        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
        self.service_builder.flow_to(self.s1, self.s2)
        self.gateway = VideoRAGQnAGateway(self.service_builder, port=9898)

    def tearDown(self):
        self.s1.stop()
        self.s2.stop()
        self.gateway.stop()

    async def test_schedule(self):
        result_dict, _ = await self.service_builder.schedule(
            initial_inputs={"text": "hello, "}, llm_parameters=LLMParams(streaming=True)
        )
        streaming_response = result_dict[self.s2.name]

        if isinstance(streaming_response, StreamingResponse):
            content = b""
            async for chunk in streaming_response.body_iterator:
                content += chunk
            final_text = content.decode("utf-8")

        print("Streamed content from s2: ", final_text)

        expected_result = "hello, opea project!project!project!"
        self.assertEqual(final_text, expected_result)


if __name__ == "__main__":
    unittest.main()