Skip to content

Commit

Permalink
Support llava-next using TGI (#423)
Browse files Browse the repository at this point in the history
Signed-off-by: lvliang-intel <[email protected]>
  • Loading branch information
lvliang-intel authored Aug 8, 2024
1 parent 76877c1 commit e156101
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 0 deletions.
44 changes: 44 additions & 0 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,47 @@ async def handle_request(self, request: Request):
)
)
return ChatCompletionResponse(model="faqgen", choices=choices, usage=usage)


class VisualQnAGateway(Gateway):
def __init__(self, megaservice, host="0.0.0.0", port=8888):
super().__init__(
megaservice, host, port, str(MegaServiceEndpoint.VISUAL_QNA), ChatCompletionRequest, ChatCompletionResponse
)

async def handle_request(self, request: Request):
data = await request.json()
stream_opt = data.get("stream", True)
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"query": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LVM.
if (
isinstance(response, StreamingResponse)
and node == list(self.megaservice.services.keys())[-1]
and self.megaservice.services[node].service_type == ServiceType.LVM
):
return response
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
choices.append(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop",
)
)
return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)
19 changes: 19 additions & 0 deletions comps/lvms/Dockerfile_tgi
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

FROM python:3.11-slim

# Set environment variables
ENV LANG=en_US.UTF-8

COPY comps /home/comps

RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r /home/comps/lvms/requirements.txt

ENV PYTHONPATH=$PYTHONPATH:/home

WORKDIR /home/comps/lvms

ENTRYPOINT ["python", "lvm_tgi.py"]

47 changes: 47 additions & 0 deletions comps/lvms/lvm_tgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import time

from huggingface_hub import InferenceClient

from comps import (
LVMDoc,
ServiceType,
TextDoc,
opea_microservices,
register_microservice,
register_statistics,
statistics_dict,
)


@register_microservice(
name="opea_service@lvm_tgi",
service_type=ServiceType.LVM,
endpoint="/v1/lvm",
host="0.0.0.0",
port=9399,
input_datatype=LVMDoc,
output_datatype=TextDoc,
)
@register_statistics(names=["opea_service@lvm_tgi"])
async def lvm(request: LVMDoc):
start = time.time()
img_b64_str = request.image
prompt = request.prompt
max_new_tokens = request.max_new_tokens

image = f"data:image/png;base64,{img_b64_str}"
image_prompt = f"![]({image})\nUSER: {prompt}\nASSISTANT:"
generated_str = lvm_client.text_generation(image_prompt, max_new_tokens=max_new_tokens)
statistics_dict["opea_service@lvm_tgi"].append_latency(time.time() - start, None)
return TextDoc(text=generated_str)


if __name__ == "__main__":
lvm_endpoint = os.getenv("LVM_ENDPOINT", "http://localhost:8399")
lvm_client = InferenceClient(lvm_endpoint)
print("[LVM] LVM initialized.")
opea_microservices["opea_service@lvm_tgi"].start()
1 change: 1 addition & 0 deletions comps/lvms/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
datasets
docarray[full]
fastapi
huggingface_hub
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
Expand Down
57 changes: 57 additions & 0 deletions tests/test_lvms_tgi_llava_next.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/bin/bash
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

set -xe

WORKPATH=$(dirname "$PWD")
ip_address=$(hostname -I | awk '{print $1}')

function build_docker_images() {
cd $WORKPATH
echo $(pwd)
git clone https://github.com/yuanwu2017/tgi-gaudi.git && cd tgi-gaudi && git checkout v2.0.4
docker build -t opea/llava-tgi:latest .
cd ..
docker build --no-cache -t opea/lvm-tgi:latest -f comps/lvms/Dockerfile_tgi .
}

function start_service() {
unset http_proxy
model="llava-hf/llava-v1.6-mistral-7b-hf"
docker run -d --name="test-comps-lvm-llava-tgi" -e http_proxy=$http_proxy -e https_proxy=$https_proxy -p 8399:80 --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e SKIP_TOKENIZER_IN_TGI=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host opea/llava-tgi:latest --model-id $model --max-input-tokens 4096 --max-total-tokens 8192
docker run -d --name="test-comps-lvm-tgi" -e LVM_ENDPOINT=http://$ip_address:8399 -e http_proxy=$http_proxy -e https_proxy=$https_proxy -p 9399:9399 --ipc=host opea/lvm-tgi:latest
sleep 3m
}

function validate_microservice() {
result=$(http_proxy="" curl http://localhost:9399/v1/lvm -XPOST -d '{"image": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC", "prompt":"What is this?"}' -H 'Content-Type: application/json')
if [[ $result == *"yellow"* ]]; then
echo "Result correct."
else
echo "Result wrong."
exit 1
fi

}

function stop_docker() {
cid=$(docker ps -aq --filter "name=test-comps-lvm*")
if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid && sleep 1s; fi
}

function main() {

stop_docker

build_docker_images
start_service

validate_microservice

stop_docker
echo y | docker system prune

}

main

0 comments on commit e156101

Please sign in to comment.