Add Megaservice support for MMRAG - MultimodalRAGQnAWithVideos usecase #626

Merged 55 commits on Sep 6, 2024

Commits
ac30d1a  updates (tileintel, Aug 29, 2024)
77632f2  cosmetic (siddhivelankar23, Aug 30, 2024)
049c78d  Merge pull request #4 from tileintel/main (tileintel, Sep 3, 2024)
4c3ea33  update redis schema (siddhivelankar23, Sep 3, 2024)
1688033  update multimodal config and docker compose retriever (siddhivelankar23, Sep 3, 2024)
bc1699f  update requirements (siddhivelankar23, Sep 3, 2024)
5ecbfdb  update retriever redis (siddhivelankar23, Sep 3, 2024)
bc23290  multimodal retriever implementation (siddhivelankar23, Sep 3, 2024)
c56151d  test for multimodal retriever (siddhivelankar23, Sep 3, 2024)
4c26828  include prompt preparation for multimodal rag on videos application (sjagtap1803, Sep 3, 2024)
fefcf72  fix template (sjagtap1803, Sep 3, 2024)
4280393  add test for llava for mm_rag_on_videos (sjagtap1803, Sep 3, 2024)
3c4b505  update test (sjagtap1803, Sep 3, 2024)
70993d9  Merge branch 'main' into retriever_lvm_update (tileintel, Sep 3, 2024)
5aa4c02  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 3, 2024)
542e6cd  first update on gateaway (sjagtap1803, Sep 3, 2024)
01b2138  fix index not found (sjagtap1803, Sep 4, 2024)
08ef0ec  Merge branch 'main' into retriever_lvm_update (tileintel, Sep 4, 2024)
3c6e471  Merge branch 'main' into retriever_lvm_update (tileintel, Sep 4, 2024)
1d5c67b  Merge pull request #5 from tileintel/retriever_lvm_update (tileintel, Sep 4, 2024)
4d353e8  add LVMSearchedMultimodalDoc (sjagtap1803, Sep 4, 2024)
b69cef1  Merge branch 'main' into retriever_lvm_update (tileintel, Sep 4, 2024)
ac57ca0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 4, 2024)
78d3902  Merge pull request #6 from tileintel/retriever_lvm_update (tileintel, Sep 4, 2024)
2c1f0ba  implement gateway for MultimodalRagQnAWithVideos (siddhivelankar23, Sep 4, 2024)
dc631db  remove INDEX_SCHEMA (siddhivelankar23, Sep 4, 2024)
ccf3da7  Merge branch 'main' into retriever_lvm_update (tileintel, Sep 4, 2024)
0965882  Merge branch 'main' into retriever_lvm_update (tileintel, Sep 5, 2024)
5a3c764  update MultimodalRAGQnAWithVideosGateway with 2 megaservices (sjagtap1803, Sep 5, 2024)
5842ed1  revise folder structure to comps/retrievers/langchain/redis_multimodal (siddhivelankar23, Sep 5, 2024)
9d83019  Merge pull request #7 from tileintel/retriever_lvm_update (tileintel, Sep 5, 2024)
afb3840  update test (siddhivelankar23, Sep 5, 2024)
426ed04  add unittest for multimodalrag_qna_with_videos_gateway (siddhivelankar23, Sep 5, 2024)
030120e  update test mmrag qna with videos (tileintel, Sep 5, 2024)
1a49d88  change port of redis to resolve CI test (siddhivelankar23, Sep 5, 2024)
44194bd  update test (siddhivelankar23, Sep 5, 2024)
9c485e9  update lvms test (siddhivelankar23, Sep 5, 2024)
1d8598e  update test (siddhivelankar23, Sep 5, 2024)
b757b39  update test (siddhivelankar23, Sep 5, 2024)
98afd97  update test for multimodal rag qna with videos gateway (siddhivelankar23, Sep 5, 2024)
d15d430  Merge branch 'main' into mmrag_videos_megaservice-dev (tileintel, Sep 5, 2024)
407ff2b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 5, 2024)
fb0b834  Merge pull request #8 from tileintel/retriever_lvm_update (tileintel, Sep 5, 2024)
7478b06  add more test to increase coverage (tileintel, Sep 5, 2024)
117b8f2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 5, 2024)
b787eda  cosmetic (tileintel, Sep 5, 2024)
be9124f  Merge branch 'mmrag_videos_megaservice-dev' of https://github.com/til… (tileintel, Sep 5, 2024)
f9fd4b4  add more test (tileintel, Sep 5, 2024)
5e899c7  Merge branch 'main' into mmrag_videos_megaservice-dev (tileintel, Sep 5, 2024)
50dcd31  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 5, 2024)
693b1c9  update name of gateway (tileintel, Sep 6, 2024)
9b79373  Merge branch 'mmrag_videos_megaservice-dev' of https://github.com/til… (tileintel, Sep 6, 2024)
c8ee3e8  Merge branch 'main' into mmrag_videos_megaservice-dev (tileintel, Sep 6, 2024)
241cd31  Merge branch 'main' into mmrag_videos_megaservice-dev (Spycsh, Sep 6, 2024)
76109ef  Merge branch 'main' into mmrag_videos_megaservice-dev (tileintel, Sep 6, 2024)
1 change: 1 addition & 0 deletions comps/__init__.py
@@ -46,6 +46,7 @@
RetrievalToolGateway,
FaqGenGateway,
VisualQnAGateway,
MultimodalRAGWithVideosGateway,
)

# Telemetry
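With this export in place, the gateway class becomes importable from the package root; the unit tests at the end of this diff rely on exactly this import:

from comps import MultimodalRAGWithVideosGateway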
1 change: 1 addition & 0 deletions comps/cores/mega/constants.py
@@ -42,6 +42,7 @@ class MegaServiceEndpoint(Enum):
CODE_TRANS = "/v1/codetrans"
DOC_SUMMARY = "/v1/docsum"
SEARCH_QNA = "/v1/searchqna"
MULTIMODAL_RAG_WITH_VIDEOS = "/v1/mmragvideoqna"
TRANSLATION = "/v1/translation"
RETRIEVALTOOL = "/v1/retrievaltool"
FAQ_GEN = "/v1/faqgen"
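The new enum member maps the gateway to the /v1/mmragvideoqna route. As a smoke test, a first-turn query could look like the minimal sketch below, assuming a gateway running on its default host 0.0.0.0 and port 9999 (the defaults of the class added to gateway.py further down):

import requests

# A text-only first query runs the full embedding -> retrieval -> LVM pipeline.
payload = {"messages": "What is happening in the video?", "max_tokens": 300}
resp = requests.post("http://0.0.0.0:9999/v1/mmragvideoqna", json=payload)
print(resp.json()["choices"][-1]["message"]["content"])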
157 changes: 157 additions & 0 deletions comps/cores/mega/gateway.py
@@ -108,6 +108,7 @@
messages_dict[msg_role] = message["content"]
else:
raise ValueError(f"Unknown role: {msg_role}")

if system_prompt:
prompt = system_prompt + "\n"
for role, message in messages_dict.items():
@@ -582,3 +583,159 @@
response = result_dict[last_node]
print("response is ", response)
return response


class MultimodalRAGWithVideosGateway(Gateway):
def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999):
self.lvm_megaservice = lvm_megaservice
super().__init__(
multimodal_rag_megaservice,
host,
port,
str(MegaServiceEndpoint.MULTIMODAL_RAG_WITH_VIDEOS),
ChatCompletionRequest,
ChatCompletionResponse,
)

# This overrides the _handle_message method of Gateway.
def _handle_message(self, messages):
images = []
messages_dicts = []
if isinstance(messages, str):
prompt = messages

else:
messages_dict = {}
system_prompt = ""
prompt = ""
for message in messages:
msg_role = message["role"]
messages_dict = {}
if msg_role == "system":
system_prompt = message["content"]
elif msg_role == "user":
if isinstance(message["content"], list):
text = ""
text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
text += "\n".join(text_list)
image_list = [
item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
]
if image_list:
messages_dict[msg_role] = (text, image_list)
else:
messages_dict[msg_role] = text

else:
messages_dict[msg_role] = message["content"]
messages_dicts.append(messages_dict)
elif msg_role == "assistant":
messages_dict[msg_role] = message["content"]
messages_dicts.append(messages_dict)
else:
raise ValueError(f"Unknown role: {msg_role}")

if system_prompt:
prompt = system_prompt + "\n"
for messages_dict in messages_dicts:
for i, (role, message) in enumerate(messages_dict.items()):
if isinstance(message, tuple):
text, image_list = message
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if text:
prompt += text + "\n"
else:
if text:
prompt += role.upper() + ": " + text + "\n"

else:
prompt += role.upper() + ":"

for img in image_list:
# URL
if img.startswith("http://") or img.startswith("https://"):
response = requests.get(img)
image = Image.open(BytesIO(response.content)).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
# Local Path
elif os.path.exists(img):
image = Image.open(img).convert("RGBA")
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()

# Bytes
else:
img_b64_str = img

images.append(img_b64_str)
else:
if i == 0:
# do not add role for the very first message.
# this will be added by llava_server
if message:
prompt += role.upper() + ": " + message + "\n"
else:
if message:
prompt += role.upper() + ": " + message + "\n"

else:
prompt += role.upper() + ":"

if images:
return prompt, images
else:
return prompt

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = bool(data.get("stream", False))
if stream_opt:
print("[ MultimodalRAGWithVideosGateway ] stream=True not used, this has not support streaming yet!")
stream_opt = False

chat_request = ChatCompletionRequest.model_validate(data)
# Multimodal RAG QnA With Videos does not yet accept images as direct input during QnA.
prompt_and_image = self._handle_message(chat_request.messages)
if isinstance(prompt_and_image, tuple):
# print(f"This request include image, thus it is a follow-up query. Using lvm megaservice")
prompt, images = prompt_and_image
cur_megaservice = self.lvm_megaservice
initial_inputs = {"prompt": prompt, "image": images[0]}
else:
# print(f"This is the first query, requiring multimodal retrieval. Using multimodal rag megaservice")
prompt = prompt_and_image
cur_megaservice = self.megaservice
initial_inputs = {"text": prompt}

parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
result_dict, runtime_graph = await cur_megaservice.schedule(
initial_inputs=initial_inputs, llm_parameters=parameters
)
for node, response in result_dict.items():
# The last microservice in this megaservice is the LVM.
# Check whether the LVM returned a StreamingResponse.
# Currently, the LVM with LLaVA does not yet support streaming.
# @TODO: test this once the LVM with LLaVA supports streaming.
if (
isinstance(response, StreamingResponse)
and node == runtime_graph.all_leaves()[-1]
and self.megaservice.services[node].service_type == ServiceType.LVM
):
return response

last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="multimodalragwithvideos", choices=choices, usage=usage)
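Taken together, handle_request routes each request down one of two paths: a text-only query is scheduled on the full multimodal RAG megaservice, while a message history containing an image is treated as a follow-up and scheduled on the lvm megaservice alone. A minimal client-side sketch of a follow-up turn, assuming the gateway's default port 9999; the image URL and message contents are the ones used in the unit tests below:

import requests

follow_up = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "hello, "},
                {"type": "image_url", "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"}},
            ],
        },
        {"role": "assistant", "content": "opea project! "},
        {"role": "user", "content": "chao, "},
    ],
    "max_tokens": 300,
}
# The image in the history makes _handle_message return (prompt, images),
# so the gateway bypasses retrieval and schedules the lvm megaservice directly.
response = requests.post("http://0.0.0.0:9999/v1/mmragvideoqna", json=follow_up)
print(response.json()["choices"][-1]["message"]["content"])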
215 changes: 215 additions & 0 deletions tests/cores/mega/test_multimodalrag_with_videos_gateway.py
@@ -0,0 +1,215 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest
from typing import Union

import requests
from fastapi import Request

from comps import (
EmbedDoc,
EmbedMultimodalDoc,
LVMDoc,
LVMSearchedMultimodalDoc,
MultimodalDoc,
MultimodalRAGWithVideosGateway,
SearchedMultimodalDoc,
ServiceOrchestrator,
TextDoc,
opea_microservices,
register_microservice,
)


@register_microservice(name="mm_embedding", host="0.0.0.0", port=8083, endpoint="/v1/mm_embedding")
async def mm_embedding_add(request: MultimodalDoc) -> EmbedDoc:
req = request.model_dump_json()
req_dict = json.loads(req)
text = req_dict["text"]
res = {}
res["text"] = text
res["embedding"] = [0.12, 0.45]
return res


@register_microservice(name="mm_retriever", host="0.0.0.0", port=8084, endpoint="/v1/mm_retriever")
async def mm_retriever_add(request: EmbedMultimodalDoc) -> SearchedMultimodalDoc:
req = request.model_dump_json()
req_dict = json.loads(req)
text = req_dict["text"]
res = {}
res["retrieved_docs"] = []
res["initial_query"] = text
res["top_n"] = 1
res["metadata"] = [
{
"b64_img_str": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC",
"transcript_for_inference": "yellow image",
}
]
res["chat_template"] = "The caption of the image is: '{context}'. {question}"
return res


@register_microservice(name="lvm", host="0.0.0.0", port=8085, endpoint="/v1/lvm")
async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc:
req = request.model_dump_json()
req_dict = json.loads(req)
if isinstance(request, LVMSearchedMultimodalDoc):
print("request is the output of multimodal retriever")
text = req_dict["initial_query"]
text += "opea project!"

else:
print("request is from user.")
text = req_dict["prompt"]
text = f"<image>\nUSER: {text}\nASSISTANT:"

res = {}
res["text"] = text
return res


class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
cls.mm_embedding = opea_microservices["mm_embedding"]
cls.mm_retriever = opea_microservices["mm_retriever"]
cls.lvm = opea_microservices["lvm"]
cls.mm_embedding.start()
cls.mm_retriever.start()
cls.lvm.start()

cls.service_builder = ServiceOrchestrator()

cls.service_builder.add(opea_microservices["mm_embedding"]).add(opea_microservices["mm_retriever"]).add(
opea_microservices["lvm"]
)
cls.service_builder.flow_to(cls.mm_embedding, cls.mm_retriever)
cls.service_builder.flow_to(cls.mm_retriever, cls.lvm)

cls.follow_up_query_service_builder = ServiceOrchestrator()
cls.follow_up_query_service_builder.add(cls.lvm)

cls.gateway = MultimodalRAGWithVideosGateway(
cls.service_builder, cls.follow_up_query_service_builder, port=9898
)

@classmethod
def tearDownClass(cls):
cls.mm_embedding.stop()
cls.mm_retriever.stop()
cls.lvm.stop()
cls.gateway.stop()

async def test_service_builder_schedule(self):
result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "})
self.assertEqual(result_dict[self.lvm.name]["text"], "hello, opea project!")

async def test_follow_up_query_service_builder_schedule(self):
result_dict, _ = await self.follow_up_query_service_builder.schedule(
initial_inputs={"prompt": "chao, ", "image": "some image"}
)
# print(result_dict)
self.assertEqual(result_dict[self.lvm.name]["text"], "<image>\nUSER: chao, \nASSISTANT:")

def test_multimodal_rag_with_videos_gateway(self):
json_data = {"messages": "hello, "}
response = requests.post("http://0.0.0.0:9898/v1/mmragvideoqna", json=json_data)
response = response.json()
self.assertEqual(response["choices"][-1]["message"]["content"], "hello, opea project!")

def test_follow_up_mm_rag_with_videos_gateway(self):
json_data = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello, "},
{
"type": "image_url",
"image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
},
],
},
{"role": "assistant", "content": "opea project! "},
{"role": "user", "content": "chao, "},
],
"max_tokens": 300,
}
response = requests.post("http://0.0.0.0:9898/v1/mmragvideoqna", json=json_data)
response = response.json()
self.assertEqual(
response["choices"][-1]["message"]["content"],
"<image>\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:",
)

def test_handle_message(self):
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "hello, "},
{
"type": "image_url",
"image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
},
],
},
{"role": "assistant", "content": "opea project! "},
{"role": "user", "content": "chao, "},
]
prompt, images = self.gateway._handle_message(messages)
self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: chao, \n")

def test_handle_message_with_system_prompt(self):
messages = [
{"role": "system", "content": "System Prompt"},
{
"role": "user",
"content": [
{"type": "text", "text": "hello, "},
{
"type": "image_url",
"image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
},
],
},
{"role": "assistant", "content": "opea project! "},
{"role": "user", "content": "chao, "},
]
prompt, images = self.gateway._handle_message(messages)
self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n")

async def test_handle_request(self):
json_data = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hello, "},
{
"type": "image_url",
"image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
},
],
},
{"role": "assistant", "content": "opea project! "},
{"role": "user", "content": "chao, "},
],
"max_tokens": 300,
}
mock_request = Request(scope={"type": "http"})
mock_request._json = json_data
res = await self.gateway.handle_request(mock_request)
res = json.loads(res.json())
self.assertEqual(
res["choices"][-1]["message"]["content"],
"<image>\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:",
)


if __name__ == "__main__":
unittest.main()