diff --git a/comps/__init__.py b/comps/__init__.py
index b483f46a7..cb7ed7a28 100644
--- a/comps/__init__.py
+++ b/comps/__init__.py
@@ -37,6 +37,7 @@
     SearchQnAGateway,
     AudioQnAGateway,
     FaqGenGateway,
+    VisualQnAGateway,
 )
 
 # Telemetry
diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py
index dd05453fc..324f7081e 100644
--- a/comps/cores/mega/gateway.py
+++ b/comps/cores/mega/gateway.py
@@ -1,6 +1,9 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import base64
+
+import requests
 from fastapi import Request
 from fastapi.responses import StreamingResponse
 
@@ -75,6 +78,8 @@ def _handle_message(self, messages):
             prompt = messages
         else:
             messages_dict = {}
+            system_prompt = ""
+            prompt = ""
             for message in messages:
                 msg_role = message["role"]
                 if msg_role == "system":
@@ -84,20 +89,41 @@
                         text = ""
                         text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
                         text += "\n".join(text_list)
-                        messages_dict[msg_role] = text
+                        image_list = [
+                            item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
+                        ]
+                        if image_list:
+                            messages_dict[msg_role] = (text, image_list)
+                        else:
+                            messages_dict[msg_role] = text
                     else:
                         messages_dict[msg_role] = message["content"]
                 elif msg_role == "assistant":
                     messages_dict[msg_role] = message["content"]
                 else:
                     raise ValueError(f"Unknown role: {msg_role}")
-            prompt = system_prompt + "\n"
+            if system_prompt:
+                prompt = system_prompt + "\n"
+            images = []
             for role, message in messages_dict.items():
-                if message:
-                    prompt += role + ": " + message + "\n"
+                if isinstance(message, tuple):
+                    text, image_list = message
+                    if text:
+                        prompt += role + ": " + text + "\n"
+                    else:
+                        prompt += role + ":"
+                    for img in image_list:
+                        response = requests.get(img)
+                        images.append(base64.b64encode(response.content).decode("utf-8"))
                 else:
-                    prompt += role + ":"
-        return prompt
+                    if message:
+                        prompt += role + ": " + message + "\n"
+                    else:
+                        prompt += role + ":"
+        if images:
+            return prompt, images
+        else:
+            return prompt
 
 
 class ChatQnAGateway(Gateway):
@@ -449,9 +475,9 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):
 
     async def handle_request(self, request: Request):
         data = await request.json()
-        stream_opt = data.get("stream", True)
+        stream_opt = data.get("stream", False)
         chat_request = ChatCompletionRequest.parse_obj(data)
-        prompt = self._handle_message(chat_request.messages)
+        prompt, images = self._handle_message(chat_request.messages)
         parameters = LLMParams(
             max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
@@ -461,7 +487,7 @@ async def handle_request(self, request: Request):
             streaming=stream_opt,
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
-            initial_inputs={"query": prompt}, llm_parameters=parameters
+            initial_inputs={"prompt": prompt, "image": images[0]}, llm_parameters=parameters
         )
         for node, response in result_dict.items():
             # Here it suppose the last microservice in the megaservice is LVM.
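Not part of the diff: a minimal sketch of the OpenAI-style request that the updated `handle_request` in gateway.py is meant to accept. The gateway's default port 8888 comes from the hunk above; the `/v1/visualqna` route is an assumption, adjust it to whatever endpoint the gateway actually registers.

```python
# Sketch only (not from the diff). Port 8888 is the gateway default shown above;
# the /v1/visualqna path is a placeholder for the route the gateway registers.
import requests

payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/sample.png"}},
            ],
        }
    ],
    "max_tokens": 128,
    "stream": False,  # handle_request now defaults stream to False
}

# _handle_message() joins the text items into a prompt, downloads each image_url,
# base64-encodes it, and returns (prompt, images); handle_request() then schedules
# initial_inputs={"prompt": prompt, "image": images[0]} on the megaservice.
resp = requests.post("http://localhost:8888/v1/visualqna", json=payload, timeout=120)
print(resp.json())
```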
diff --git a/comps/cores/mega/orchestrator.py b/comps/cores/mega/orchestrator.py
index d4f3ac9b7..616af41c8 100644
--- a/comps/cores/mega/orchestrator.py
+++ b/comps/cores/mega/orchestrator.py
@@ -117,7 +117,10 @@ async def execute(
             if inputs.get(field) != value:
                 inputs[field] = value
 
-        if self.services[cur_node].service_type == ServiceType.LLM and llm_parameters.streaming:
+        if (
+            self.services[cur_node].service_type == ServiceType.LLM
+            or self.services[cur_node].service_type == ServiceType.LVM
+        ) and llm_parameters.streaming:
             # Still leave to sync requests.post for StreamingResponse
             response = requests.post(
                 url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py
index 6a4e55d4c..9e07d618d 100644
--- a/comps/cores/proto/docarray.py
+++ b/comps/cores/proto/docarray.py
@@ -130,3 +130,9 @@ class LVMDoc(BaseDoc):
     image: str
     prompt: str
     max_new_tokens: conint(ge=0, le=1024) = 512
+    top_k: int = 10
+    top_p: float = 0.95
+    typical_p: float = 0.95
+    temperature: float = 0.01
+    repetition_penalty: float = 1.03
+    streaming: bool = False
diff --git a/comps/lvms/lvm_tgi.py b/comps/lvms/lvm_tgi.py
index 7a51b562c..b7383fa0c 100644
--- a/comps/lvms/lvm_tgi.py
+++ b/comps/lvms/lvm_tgi.py
@@ -4,7 +4,8 @@
 import os
 import time
 
-from huggingface_hub import InferenceClient
+from fastapi.responses import StreamingResponse
+from huggingface_hub import AsyncInferenceClient
 
 from comps import (
     LVMDoc,
@@ -29,19 +30,58 @@
 @register_statistics(names=["opea_service@lvm_tgi"])
 async def lvm(request: LVMDoc):
     start = time.time()
+    stream_gen_time = []
     img_b64_str = request.image
     prompt = request.prompt
     max_new_tokens = request.max_new_tokens
+    streaming = request.streaming
+    repetition_penalty = request.repetition_penalty
+    temperature = request.temperature
+    top_k = request.top_k
+    top_p = request.top_p
     image = f"data:image/png;base64,{img_b64_str}"
-    image_prompt = f"![]({image})\nUSER: {prompt}\nASSISTANT:"
-    generated_str = lvm_client.text_generation(image_prompt, max_new_tokens=max_new_tokens)
-    statistics_dict["opea_service@lvm_tgi"].append_latency(time.time() - start, None)
-    return TextDoc(text=generated_str)
+    image_prompt = f"![]({image})\n{prompt}\nASSISTANT:"
+
+    if streaming:
+
+        async def stream_generator():
+            chat_response = ""
+            text_generation = await lvm_client.text_generation(
+                prompt=prompt,
+                stream=streaming,
+                max_new_tokens=max_new_tokens,
+                repetition_penalty=repetition_penalty,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+            )
+            async for text in text_generation:
+                stream_gen_time.append(time.time() - start)
+                chat_response += text
+                chunk_repr = repr(text.encode("utf-8"))
+                print(f"[llm - chat_stream] chunk:{chunk_repr}")
+                yield f"data: {chunk_repr}\n\n"
+            print(f"[llm - chat_stream] stream response: {chat_response}")
+            statistics_dict["opea_service@lvm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0])
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(stream_generator(), media_type="text/event-stream")
+    else:
+        generated_str = await lvm_client.text_generation(
+            image_prompt,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+        statistics_dict["opea_service@lvm_tgi"].append_latency(time.time() - start, None)
+        return TextDoc(text=generated_str)
 
 
 if __name__ == "__main__":
     lvm_endpoint = os.getenv("LVM_ENDPOINT", "http://localhost:8399")
-    lvm_client = InferenceClient(lvm_endpoint)
+    lvm_client = AsyncInferenceClient(lvm_endpoint)
     print("[LVM] LVM initialized.")
     opea_microservices["opea_service@lvm_tgi"].start()
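Also not part of the diff: a hedged sketch that exercises the new `LVMDoc` fields against the `lvm_tgi` microservice directly. The `/v1/lvm` path and port 9399 are assumptions about how the service is registered; adjust them to your deployment.

```python
# Sketch only (not from the diff). The /v1/lvm path and port 9399 are assumptions;
# change them to match how the lvm_tgi microservice is registered in your deployment.
import base64
import requests

with open("sample.png", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

doc = {
    "image": img_b64,
    "prompt": "What is shown in this picture?",
    "max_new_tokens": 256,
    # Fields added to LVMDoc in this diff, shown with their declared defaults:
    "top_k": 10,
    "top_p": 0.95,
    "temperature": 0.01,
    "repetition_penalty": 1.03,
    "streaming": True,
}

# With streaming=True the service returns a text/event-stream of "data: ..." chunks
# terminated by "data: [DONE]"; with streaming=False it returns a TextDoc JSON body.
with requests.post("http://localhost:9399/v1/lvm", json=doc, stream=True, timeout=1000) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)
```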