From 550325d8cbce24a5c4af73bf0beb741a88582e9b Mon Sep 17 00:00:00 2001
From: sgurunat
Date: Wed, 13 Nov 2024 12:50:15 +0530
Subject: [PATCH] vLLM support for DocSum (#885)

* Add model parameter for DocSumGateway in gateway.py file

Signed-off-by: sgurunat

* Add langchain vllm support for DocSum along with authentication support for vllm endpoints

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Updated docker_compose_llm.yaml and README file with vLLM information

Signed-off-by: sgurunat

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Updated docsum-vllm Dockerfile into llm-compose-cd.yaml under github workflows

Signed-off-by: sgurunat

* Updated llm-compose.yaml file to include vllm summarization docker build

Signed-off-by: sgurunat

---------

Signed-off-by: sgurunat
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: lvliang-intel
---
 .../docker/compose/llms-compose.yaml          |   4 +
 comps/cores/mega/gateway.py                   |   2 +
 .../summarization/vllm/langchain/Dockerfile   |  28 +++++
 .../summarization/vllm/langchain/README.md    | 112 +++++++++++++++++
 .../summarization/vllm/langchain/__init__.py  |   2 +
 .../vllm/langchain/docker_compose_llm.yaml    |  44 +++++++
 .../vllm/langchain/entrypoint.sh              |   8 ++
 .../llms/summarization/vllm/langchain/llm.py  | 118 ++++++++++++++++++
 .../vllm/langchain/requirements-runtime.txt   |   1 +
 .../vllm/langchain/requirements.txt           |  15 +++
 10 files changed, 334 insertions(+)
 create mode 100644 comps/llms/summarization/vllm/langchain/Dockerfile
 create mode 100644 comps/llms/summarization/vllm/langchain/README.md
 create mode 100644 comps/llms/summarization/vllm/langchain/__init__.py
 create mode 100644 comps/llms/summarization/vllm/langchain/docker_compose_llm.yaml
 create mode 100644 comps/llms/summarization/vllm/langchain/entrypoint.sh
 create mode 100644 comps/llms/summarization/vllm/langchain/llm.py
 create mode 100644 comps/llms/summarization/vllm/langchain/requirements-runtime.txt
 create mode 100644 comps/llms/summarization/vllm/langchain/requirements.txt

diff --git a/.github/workflows/docker/compose/llms-compose.yaml b/.github/workflows/docker/compose/llms-compose.yaml
index ff93075f6..91fbb46d4 100644
--- a/.github/workflows/docker/compose/llms-compose.yaml
+++ b/.github/workflows/docker/compose/llms-compose.yaml
@@ -58,6 +58,10 @@ services:
     build:
       dockerfile: comps/llms/text-generation/predictionguard/Dockerfile
     image: ${REGISTRY:-opea}/llm-textgen-predictionguard:${TAG:-latest}
+  llm-docsum-vllm:
+    build:
+      dockerfile: comps/llms/summarization/vllm/langchain/Dockerfile
+    image: ${REGISTRY:-opea}/llm-docsum-vllm:${TAG:-latest}
   llm-faqgen-vllm:
     build:
       dockerfile: comps/llms/faq-generation/vllm/langchain/Dockerfile
diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py
index d27d2d708..1dc94074f 100644
--- a/comps/cores/mega/gateway.py
+++ b/comps/cores/mega/gateway.py
@@ -433,6 +433,8 @@ async def handle_request(self, request: Request):
             presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
             repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
+            language=chat_request.language if chat_request.language else "auto",
+            model=chat_request.model if chat_request.model else None,
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
             initial_inputs={data["type"]: prompt}, llm_parameters=parameters
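The gateway change above only threads two extra request fields through to the LLM parameters that the DocSum megaservice schedules. A minimal sketch of that fallback behaviour, with `chat_request` standing in for the parsed request object (the helper name is illustrative, not part of the patch):

```python
# Illustrative sketch of the fallbacks added in handle_request() above.
def docsum_llm_overrides(chat_request) -> dict:
    """Collect the new DocSum-specific LLM overrides from an incoming request."""
    return {
        # Summaries default to automatic language detection when unset.
        "language": chat_request.language if chat_request.language else "auto",
        # None lets the downstream microservice fall back to its LLM_MODEL_ID.
        "model": chat_request.model if chat_request.model else None,
    }
```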
diff --git a/comps/llms/summarization/vllm/langchain/Dockerfile b/comps/llms/summarization/vllm/langchain/Dockerfile
new file mode 100644
index 000000000..3a1cd5a8f
--- /dev/null
+++ b/comps/llms/summarization/vllm/langchain/Dockerfile
@@ -0,0 +1,28 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+FROM python:3.11-slim
+
+ARG ARCH="cpu"
+
+RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
+    libgl1-mesa-glx \
+    libjemalloc-dev
+
+RUN useradd -m -s /bin/bash user && \
+    mkdir -p /home/user && \
+    chown -R user /home/user/
+
+USER user
+
+COPY comps /home/user/comps
+
+RUN pip install --no-cache-dir --upgrade pip setuptools && \
+    if [ ${ARCH} = "cpu" ]; then pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu; fi && \
+    pip install --no-cache-dir -r /home/user/comps/llms/summarization/vllm/langchain/requirements.txt
+
+ENV PYTHONPATH=$PYTHONPATH:/home/user
+
+WORKDIR /home/user/comps/llms/summarization/vllm/langchain
+
+ENTRYPOINT ["bash", "entrypoint.sh"]
diff --git a/comps/llms/summarization/vllm/langchain/README.md b/comps/llms/summarization/vllm/langchain/README.md
new file mode 100644
index 000000000..bdb8f9beb
--- /dev/null
+++ b/comps/llms/summarization/vllm/langchain/README.md
@@ -0,0 +1,112 @@
+# Document Summary vLLM Microservice
+
+This microservice leverages LangChain to implement summarization strategies and facilitate LLM inference using vLLM.
+[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use library for LLM inference and serving. It delivers state-of-the-art serving throughput with advanced features such as PagedAttention and continuous batching. Besides GPUs, vLLM already supports [Intel CPUs](https://www.intel.com/content/www/us/en/products/overview.html) and [Gaudi accelerators](https://habana.ai/products).
+
+## 🚀1. Start Microservice with Python 🐍 (Option 1)
+
+To start the LLM microservice, you need to install the required Python packages first.
+
+### 1.1 Install Requirements
+
+```bash
+pip install -r requirements.txt
+```
+
+### 1.2 Start LLM Service
+
+```bash
+export HF_TOKEN=${your_hf_api_token}
+export LLM_MODEL_ID=${your_hf_llm_model}
+docker run -p 8008:80 -v ./data:/data --name llm-docsum-vllm --shm-size 1g opea/vllm:hpu --model ${LLM_MODEL_ID}
+```
+
+### 1.3 Verify the vLLM Service
+
+```bash
+curl http://${your_ip}:8008/v1/chat/completions \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "user", "content": "What is Deep Learning? "}]}'
+```
+
+### 1.4 Start LLM Service with Python Script
+
+```bash
+export vLLM_ENDPOINT="http://${your_ip}:8008"
+python llm.py
+```
+
+## 🚀2. Start Microservice with Docker 🐳 (Option 2)
+
+If you start the LLM microservice with Docker, the `docker_compose_llm.yaml` file will automatically start a vLLM service with Docker as well.
+
+To set up or build the vLLM image, follow the instructions provided in [vLLM Gaudi](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/text-generation/vllm/langchain#22-vllm-on-gaudi).
+
+### 2.1 Setup Environment Variables
+
+In order to start the vLLM and LLM services, you need to set up the following environment variables first.
+
+```bash
+export HF_TOKEN=${your_hf_api_token}
+export vLLM_ENDPOINT="http://${your_ip}:8008"
+export LLM_MODEL_ID=${your_hf_llm_model}
+```
+
+### 2.2 Build Docker Image
+
+```bash
+cd ../../../../../
+docker build -t opea/llm-docsum-vllm:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/llms/summarization/vllm/langchain/Dockerfile .
+```
+
+To start a docker container, you have two options:
+
+- A. Run Docker with CLI
+- B. Run Docker with Docker Compose
+
+You can choose one as needed.
+
+### 2.3 Run Docker with CLI (Option A)
+
+```bash
+docker run -d --name="llm-docsum-vllm-server" -p 9000:9000 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e vLLM_ENDPOINT=$vLLM_ENDPOINT -e HF_TOKEN=$HF_TOKEN opea/llm-docsum-vllm:latest
+```
+
+### 2.4 Run Docker with Docker Compose (Option B)
+
+```bash
+docker compose -f docker_compose_llm.yaml up -d
+```
+
+## 🚀3. Consume LLM Service
+
+### 3.1 Check Service Status
+
+```bash
+curl http://${your_ip}:9000/v1/health_check \
+  -X GET \
+  -H 'Content-Type: application/json'
+```
+
+### 3.2 Consume LLM Service
+
+```bash
+# Enable streaming to receive a streaming response. By default, this is set to True.
+curl http://${your_ip}:9000/v1/chat/docsum \
+  -X POST \
+  -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en"}' \
+  -H 'Content-Type: application/json'
+
+# Disable streaming to receive a non-streaming response.
+curl http://${your_ip}:9000/v1/chat/docsum \
+  -X POST \
+  -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en", "streaming":false}' \
+  -H 'Content-Type: application/json'
+
+# Use Chinese mode. By default, language is set to "en".
+curl http://${your_ip}:9000/v1/chat/docsum \
+  -X POST \
+  -d '{"query":"2024年9月26日,北京——今日,英特尔正式发布英特尔® 至强® 6性能核处理器(代号Granite Rapids),为AI、数据分析、科学计算等计算密集型业务提供卓越性能。", "max_tokens":32, "language":"zh", "streaming":false}' \
+  -H 'Content-Type: application/json'
+```
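For programmatic use, the curl calls in the README above translate directly into a small Python client. This is a sketch, assuming the `requests` package is installed and the microservice is reachable on port 9000; the non-streaming response is the `GeneratedDoc` JSON produced by `llm.py`, whose summary is in the `text` field:

```python
# Minimal Python equivalent of the non-streaming curl example in the README above.
# Assumes the llm-docsum-vllm microservice is reachable on localhost:9000.
import requests

url = "http://localhost:9000/v1/chat/docsum"
payload = {
    "query": "Text Embeddings Inference (TEI) is a toolkit for deploying and serving "
    "open source text embeddings and sequence classification models.",
    "max_tokens": 32,
    "language": "en",
    "streaming": False,  # set to True (the default) for a server-sent-events stream
}

resp = requests.post(url, json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["text"])  # the generated summary
```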
diff --git a/comps/llms/summarization/vllm/langchain/__init__.py b/comps/llms/summarization/vllm/langchain/__init__.py
new file mode 100644
index 000000000..916f3a44b
--- /dev/null
+++ b/comps/llms/summarization/vllm/langchain/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
diff --git a/comps/llms/summarization/vllm/langchain/docker_compose_llm.yaml b/comps/llms/summarization/vllm/langchain/docker_compose_llm.yaml
new file mode 100644
index 000000000..8cc13e318
--- /dev/null
+++ b/comps/llms/summarization/vllm/langchain/docker_compose_llm.yaml
@@ -0,0 +1,44 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+version: "3.8"
+
+services:
+  vllm-service:
+    image: opea/vllm:hpu
+    container_name: vllm-gaudi-server
+    ports:
+      - "8008:80"
+    volumes:
+      - "./data:/data"
+    environment:
+      no_proxy: ${no_proxy}
+      http_proxy: ${http_proxy}
+      https_proxy: ${https_proxy}
+      HF_TOKEN: ${HF_TOKEN}
+      HABANA_VISIBLE_DEVICES: all
+      OMPI_MCA_btl_vader_single_copy_mechanism: none
+      LLM_MODEL_ID: ${LLM_MODEL_ID}
+    runtime: habana
+    cap_add:
+      - SYS_NICE
+    ipc: host
+    command: --enforce-eager --model $LLM_MODEL_ID --tensor-parallel-size 1 --host 0.0.0.0 --port 80
+  llm:
+    image: opea/llm-docsum-vllm:latest
+    container_name: llm-docsum-vllm-server
+    ports:
+      - "9000:9000"
+    ipc: host
+    environment:
+      no_proxy: ${no_proxy}
+      http_proxy: ${http_proxy}
+      https_proxy: ${https_proxy}
+      vLLM_ENDPOINT: ${vLLM_ENDPOINT}
+      HUGGINGFACEHUB_API_TOKEN: ${HF_TOKEN}
+      LLM_MODEL_ID: ${LLM_MODEL_ID}
+    restart: unless-stopped
+
+networks:
+  default:
+    driver: bridge
diff --git a/comps/llms/summarization/vllm/langchain/entrypoint.sh b/comps/llms/summarization/vllm/langchain/entrypoint.sh
new file mode 100644
index 000000000..d60eddd36
--- /dev/null
+++ b/comps/llms/summarization/vllm/langchain/entrypoint.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+pip --no-cache-dir install -r requirements-runtime.txt
+
+python llm.py
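The microservice implemented in the next file builds a LangChain `VLLMOpenAI` client against the vLLM endpoint and, when `TOKEN_URL`, `CLIENTID` and `CLIENT_SECRET` are configured, injects a bearer token through `default_headers`. A stripped-down sketch of just that wiring (the `VLLM_BEARER_TOKEN` variable is a stand-in for the `get_access_token()` exchange, and the endpoint/model values are placeholders):

```python
# Hedged sketch of the authenticated vLLM client wiring used in llm.py below.
import os

from langchain_community.llms import VLLMOpenAI

llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8080")

# Stand-in for get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET);
# when no token is available, no Authorization header is sent at all.
token = os.getenv("VLLM_BEARER_TOKEN")
headers = {"Authorization": f"Bearer {token}"} if token else {}

llm = VLLMOpenAI(
    openai_api_key="EMPTY",  # vLLM's OpenAI-compatible server ignores the key
    openai_api_base=llm_endpoint + "/v1",
    model_name=os.getenv("LLM_MODEL_ID", "meta-llama/Meta-Llama-3-8B-Instruct"),
    default_headers=headers,
    max_tokens=64,
)

print(llm.invoke("Write a concise summary of: vLLM is a fast and easy-to-use library for LLM inference."))
```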
diff --git a/comps/llms/summarization/vllm/langchain/llm.py b/comps/llms/summarization/vllm/langchain/llm.py
new file mode 100644
index 000000000..9f60d8b96
--- /dev/null
+++ b/comps/llms/summarization/vllm/langchain/llm.py
@@ -0,0 +1,118 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+
+from fastapi.responses import StreamingResponse
+from langchain.chains.summarize import load_summarize_chain
+from langchain.docstore.document import Document
+from langchain.prompts import PromptTemplate
+from langchain.text_splitter import CharacterTextSplitter
+from langchain_community.llms import VLLMOpenAI
+
+from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
+from comps.cores.mega.utils import get_access_token
+
+logger = CustomLogger("llm_docsum")
+logflag = os.getenv("LOGFLAG", False)
+
+# Environment variables
+TOKEN_URL = os.getenv("TOKEN_URL")
+CLIENTID = os.getenv("CLIENTID")
+CLIENT_SECRET = os.getenv("CLIENT_SECRET")
+MODEL_ID = os.getenv("LLM_MODEL_ID", None)
+
+templ_en = """Write a concise summary of the following:
+"{text}"
+CONCISE SUMMARY:"""
+
+templ_zh = """请简要概括以下内容:
+"{text}"
+概况:"""
+
+
+def post_process_text(text: str):
+    if text == " ":
+        return "data: @#$\n\n"
+    if text == "\n":
+        return "data: <br/>\n\n"
+    if text.isspace():
+        return None
+    new_text = text.replace(" ", "@#$")
+    return f"data: {new_text}\n\n"
+
+
+@register_microservice(
+    name="opea_service@llm_docsum",
+    service_type=ServiceType.LLM,
+    endpoint="/v1/chat/docsum",
+    host="0.0.0.0",
+    port=9000,
+)
+async def llm_generate(input: LLMParamsDoc):
+    if logflag:
+        logger.info(input)
+    if input.language in ["en", "auto"]:
+        templ = templ_en
+    elif input.language in ["zh"]:
+        templ = templ_zh
+    else:
+        raise NotImplementedError('Please specify the input language in "en", "zh", "auto"')
+
+    PROMPT = PromptTemplate.from_template(templ)
+
+    if logflag:
+        logger.info("After prompting:")
+        logger.info(PROMPT)
+
+    access_token = (
+        get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
+    )
+    headers = {}
+    if access_token:
+        headers = {"Authorization": f"Bearer {access_token}"}
+    llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8080")
+    model = input.model if input.model else os.getenv("LLM_MODEL_ID")
+    llm = VLLMOpenAI(
+        openai_api_key="EMPTY",
+        openai_api_base=llm_endpoint + "/v1",
+        model_name=model,
+        default_headers=headers,
+        max_tokens=input.max_tokens,
+        top_p=input.top_p,
+        streaming=input.streaming,
+        temperature=input.temperature,
+        presence_penalty=input.repetition_penalty,
+    )
+    llm_chain = load_summarize_chain(llm=llm, prompt=PROMPT)
+    texts = text_splitter.split_text(input.query)
+
+    # Create multiple documents
+    docs = [Document(page_content=t) for t in texts]
+
+    if input.streaming:
+
+        async def stream_generator():
+            from langserve.serialization import WellKnownLCSerializer
+
+            _serializer = WellKnownLCSerializer()
+            async for chunk in llm_chain.astream_log(docs):
+                data = _serializer.dumps({"ops": chunk.ops}).decode("utf-8")
+                if logflag:
+                    logger.info(data)
+                yield f"data: {data}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(stream_generator(), media_type="text/event-stream")
+    else:
+        response = await llm_chain.ainvoke(docs)
+        response = response["output_text"]
+        if logflag:
+            logger.info(response)
+        return GeneratedDoc(text=response, prompt=input.query)
+
+
+if __name__ == "__main__":
+    # Split text
+    text_splitter = CharacterTextSplitter()
+    opea_microservices["opea_service@llm_docsum"].start()
diff --git a/comps/llms/summarization/vllm/langchain/requirements-runtime.txt b/comps/llms/summarization/vllm/langchain/requirements-runtime.txt
new file mode 100644
index 000000000..225adde27
--- /dev/null
+++ b/comps/llms/summarization/vllm/langchain/requirements-runtime.txt
@@ -0,0 +1 @@
+langserve
diff --git a/comps/llms/summarization/vllm/langchain/requirements.txt b/comps/llms/summarization/vllm/langchain/requirements.txt
new file mode 100644
index 000000000..e074ba8c8
--- /dev/null
+++ b/comps/llms/summarization/vllm/langchain/requirements.txt
@@ -0,0 +1,15 @@
+docarray[full]
+fastapi
+huggingface_hub
+langchain #==0.1.12
+langchain-huggingface
+langchain-openai
+langchain_community
+langchainhub
+opentelemetry-api
+opentelemetry-exporter-otlp
+opentelemetry-sdk
+prometheus-fastapi-instrumentator
+shortuuid
+transformers
+uvicorn
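When `streaming` is left at its default, the service above emits server-sent events whose `data:` payloads are langserve-serialized run-log ops, terminated by a `data: [DONE]` sentinel (see `stream_generator()` in `llm.py`). A client-side sketch for consuming that stream, assuming the `requests` package and the service on port 9000:

```python
# Hedged sketch: consume the streaming /v1/chat/docsum endpoint.
import json

import requests

url = "http://localhost:9000/v1/chat/docsum"
payload = {"query": "Paste the document to summarize here.", "max_tokens": 64, "language": "en"}

with requests.post(url, json=payload, stream=True, timeout=300) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)  # a dict of langserve run-log "ops"
        print(chunk["ops"])
```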