Support streaming output for LVM microservice (opea-project#430)
* Support llava-next using TGI

Signed-off-by: lvliang-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test script

Signed-off-by: lvliang-intel <[email protected]>

* fix ci issues

Signed-off-by: lvliang-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support streaming output

Signed-off-by: lvliang-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: lvliang-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
lvliang-intel and pre-commit-ci[bot] authored Aug 9, 2024
1 parent f8d45e5 commit c5a0344
Showing 5 changed files with 92 additions and 16 deletions.
1 change: 1 addition & 0 deletions comps/__init__.py
@@ -37,6 +37,7 @@
SearchQnAGateway,
AudioQnAGateway,
FaqGenGateway,
VisualQnAGateway,
)

# Telemetry
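With `VisualQnAGateway` now exported from `comps`, a downstream app can wire the gateway around a megaservice directly. A minimal sketch, assuming the common `ServiceOrchestrator`/`MicroService` pattern and a remote LVM microservice; the service name, port, and endpoint below are illustrative and not part of this commit:

```python
# Hypothetical wiring only: this follows the common GenAIComps pattern of a
# ServiceOrchestrator megaservice fronted by a gateway; the LVM service name,
# port, and endpoint are assumptions, not values taken from this commit.
from comps import MicroService, ServiceOrchestrator, ServiceType, VisualQnAGateway

megaservice = ServiceOrchestrator()
lvm = MicroService(
    name="lvm",
    host="0.0.0.0",
    port=9399,
    endpoint="/v1/lvm",
    use_remote_service=True,
    service_type=ServiceType.LVM,
)
megaservice.add(lvm)

# VisualQnAGateway(megaservice, host, port) matches the constructor shown in
# the gateway.py diff below; it serves the OpenAI-style chat request handled
# by handle_request.
gateway = VisualQnAGateway(megaservice, host="0.0.0.0", port=8888)
```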
44 changes: 35 additions & 9 deletions comps/cores/mega/gateway.py
@@ -1,6 +1,9 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import base64

import requests
from fastapi import Request
from fastapi.responses import StreamingResponse

@@ -75,6 +78,8 @@ def _handle_message(self, messages):
prompt = messages
else:
messages_dict = {}
system_prompt = ""
prompt = ""
for message in messages:
msg_role = message["role"]
if msg_role == "system":
@@ -84,20 +89,41 @@ def _handle_message(self, messages):
text = ""
text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
text += "\n".join(text_list)
messages_dict[msg_role] = text
image_list = [
item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
]
if image_list:
messages_dict[msg_role] = (text, image_list)
else:
messages_dict[msg_role] = text
else:
messages_dict[msg_role] = message["content"]
elif msg_role == "assistant":
messages_dict[msg_role] = message["content"]
else:
raise ValueError(f"Unknown role: {msg_role}")
prompt = system_prompt + "\n"
if system_prompt:
prompt = system_prompt + "\n"
images = []
for role, message in messages_dict.items():
if message:
prompt += role + ": " + message + "\n"
if isinstance(message, tuple):
text, image_list = message
if text:
prompt += role + ": " + text + "\n"
else:
prompt += role + ":"
for img in image_list:
response = requests.get(img)
images.append(base64.b64encode(response.content).decode("utf-8"))
else:
prompt += role + ":"
return prompt
if message:
prompt += role + ": " + message + "\n"
else:
prompt += role + ":"
if images:
return prompt, images
else:
return prompt


class ChatQnAGateway(Gateway):
@@ -449,9 +475,9 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
stream_opt = data.get("stream", False)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
prompt, images = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -461,7 +487,7 @@ async def handle_request(self, request: Request):
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"query": prompt}, llm_parameters=parameters
initial_inputs={"prompt": prompt, "image": images[0]}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here it assumes the last microservice in the megaservice is LVM.
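The reworked `_handle_message` flattens OpenAI-style multimodal messages into a prompt string plus a list of base64-encoded images, which `handle_request` then passes to the megaservice as `{"prompt": ..., "image": ...}`. A standalone sketch of the same flattening idea, not the gateway's exact code; the message content and image URL are placeholders:

```python
# Illustrative request shape and flattening logic; the image URL is a placeholder.
import base64

import requests

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]

prompt, images = "", []
for message in messages:
    role = message["role"]
    texts = [c["text"] for c in message["content"] if c["type"] == "text"]
    urls = [c["image_url"]["url"] for c in message["content"] if c["type"] == "image_url"]
    if texts:
        prompt += role + ": " + "\n".join(texts) + "\n"
    else:
        prompt += role + ":"
    for url in urls:
        # Images are fetched and base64-encoded, matching what the gateway forwards.
        images.append(base64.b64encode(requests.get(url).content).decode("utf-8"))
```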
5 changes: 4 additions & 1 deletion comps/cores/mega/orchestrator.py
@@ -117,7 +117,10 @@ async def execute(
if inputs.get(field) != value:
inputs[field] = value

if self.services[cur_node].service_type == ServiceType.LLM and llm_parameters.streaming:
if (
self.services[cur_node].service_type == ServiceType.LLM
or self.services[cur_node].service_type == ServiceType.LVM
) and llm_parameters.streaming:
# Still leave to sync requests.post for StreamingResponse
response = requests.post(
url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
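With this change a streaming LVM node is handled like a streaming LLM node: the orchestrator issues a synchronous `requests.post(..., stream=True)` and relays the SSE chunks. A minimal sketch of reading such a stream directly from the LVM microservice; the endpoint, port, and payload values are assumptions:

```python
# Assumes the LVM microservice is reachable at /v1/lvm on port 9399 and emits
# "data: ..." SSE lines ending with "data: [DONE]"; adjust URL and port to
# your deployment. Note that lvm_tgi.py (below) repr()-encodes each chunk.
import json

import requests

payload = {"image": "<base64-encoded PNG>", "prompt": "Describe the image.", "streaming": True}
response = requests.post(
    "http://localhost:9399/v1/lvm",
    data=json.dumps(payload),
    headers={"Content-Type": "application/json"},
    stream=True,
    timeout=1000,
)
for line in response.iter_lines(decode_unicode=True):
    if not line:
        continue
    if line.strip() == "data: [DONE]":
        break
    print(line.removeprefix("data: "), flush=True)  # removeprefix needs Python 3.9+
```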
6 changes: 6 additions & 0 deletions comps/cores/proto/docarray.py
@@ -130,3 +130,9 @@ class LVMDoc(BaseDoc):
image: str
prompt: str
max_new_tokens: conint(ge=0, le=1024) = 512
top_k: int = 10
top_p: float = 0.95
typical_p: float = 0.95
temperature: float = 0.01
repetition_penalty: float = 1.03
streaming: bool = False
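`LVMDoc` now carries the sampling controls and the `streaming` flag that the TGI-backed service forwards. A small example of constructing one with the new fields; the values mirror the schema defaults above, except `streaming`:

```python
# LVMDoc is imported from comps (see the lvm_tgi.py imports below); the image
# value is a placeholder for real base64 content.
from comps import LVMDoc

doc = LVMDoc(
    image="<base64-encoded PNG>",
    prompt="What is shown in this image?",
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
    streaming=True,
)
```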
52 changes: 46 additions & 6 deletions comps/lvms/lvm_tgi.py
Expand Up @@ -4,7 +4,8 @@
import os
import time

from huggingface_hub import InferenceClient
from fastapi.responses import StreamingResponse
from huggingface_hub import AsyncInferenceClient

from comps import (
LVMDoc,
@@ -29,19 +30,58 @@
@register_statistics(names=["opea_service@lvm_tgi"])
async def lvm(request: LVMDoc):
start = time.time()
stream_gen_time = []
img_b64_str = request.image
prompt = request.prompt
max_new_tokens = request.max_new_tokens
streaming = request.streaming
repetition_penalty = request.repetition_penalty
temperature = request.temperature
top_k = request.top_k
top_p = request.top_p

image = f"data:image/png;base64,{img_b64_str}"
image_prompt = f"![]({image})\nUSER: {prompt}\nASSISTANT:"
generated_str = lvm_client.text_generation(image_prompt, max_new_tokens=max_new_tokens)
statistics_dict["opea_service@lvm_tgi"].append_latency(time.time() - start, None)
return TextDoc(text=generated_str)
image_prompt = f"![]({image})\n{prompt}\nASSISTANT:"

if streaming:

async def stream_generator():
chat_response = ""
text_generation = await lvm_client.text_generation(
prompt=prompt,
stream=streaming,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
async for text in text_generation:
stream_gen_time.append(time.time() - start)
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
print(f"[llm - chat_stream] chunk:{chunk_repr}")
yield f"data: {chunk_repr}\n\n"
print(f"[llm - chat_stream] stream response: {chat_response}")
statistics_dict["opea_service@lvm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0])
yield "data: [DONE]\n\n"

return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
generated_str = await lvm_client.text_generation(
image_prompt,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
statistics_dict["opea_service@lvm_tgi"].append_latency(time.time() - start, None)
return TextDoc(text=generated_str)


if __name__ == "__main__":
lvm_endpoint = os.getenv("LVM_ENDPOINT", "http://localhost:8399")
lvm_client = InferenceClient(lvm_endpoint)
lvm_client = AsyncInferenceClient(lvm_endpoint)
print("[LVM] LVM initialized.")
opea_microservices["opea_service@lvm_tgi"].start()
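The streaming branch above drives TGI through `AsyncInferenceClient` and re-emits tokens as SSE events. A standalone sketch of that client-side call against a TGI llava-next endpoint, outside the microservice; the endpoint URL and the base64 placeholder in the markdown-style image prompt are assumptions:

```python
# Standalone sketch, not part of the microservice: the TGI endpoint URL and the
# base64 placeholder are assumptions.
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    client = AsyncInferenceClient("http://localhost:8399")
    image_prompt = "![](data:image/png;base64,<b64>)\nWhat is in this image?\nASSISTANT:"
    # With stream=True the awaited call returns an async iterator of tokens.
    stream = await client.text_generation(
        prompt=image_prompt,
        stream=True,
        max_new_tokens=512,
        repetition_penalty=1.03,
        temperature=0.01,
        top_k=10,
        top_p=0.95,
    )
    async for token in stream:
        print(token, end="", flush=True)


asyncio.run(main())
```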
