[for quick UT, not merge] #1056

Closed · wants to merge 8 commits

@@ -16,10 +16,11 @@ USER user
COPY comps /home/user/comps

RUN pip install --no-cache-dir --upgrade pip setuptools && \
    pip install --no-cache-dir -r /home/user/comps/llms/text-generation/tgi/requirements.txt
    pip install --no-cache-dir -r /home/user/comps/llms/src/text-generation/requirements.txt

ENV PYTHONPATH=$PYTHONPATH:/home/user

WORKDIR /home/user/comps/llms/text-generation/tgi
WORKDIR /home/user/comps/llms/src/text-generation

ENTRYPOINT ["bash", "entrypoint.sh"]

@@ -12,24 +12,11 @@ To start the LLM microservice, you need to install python packages first.
pip install -r requirements.txt
```

### 1.2 Start LLM Service
### 1.2 Start 3rd-party TGI Service

```bash
export HF_TOKEN=${your_hf_api_token}
export LLM_MODEL_ID=${your_hf_llm_model}
docker run -p 8008:80 -v ./data:/data --name tgi_service --shm-size 1g ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id $LLM_MODEL_ID
```

### 1.3 Verify the TGI Service

```bash
curl http://${your_ip}:8008/v1/chat/completions \
-X POST \
-d '{"model": ${LLM_MODEL_ID}, "messages": [{"role": "user", "content": "What is Deep Learning?"}], "max_tokens":17}' \
-H 'Content-Type: application/json'
```
Please refer to [3rd-party TGI](../../../3rd_parties/tgi/deployment/docker/) to start an LLM endpoint and verify it.
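For a quick sanity check, the endpoint's OpenAI-compatible `/v1/chat/completions` API can also be queried directly. The snippet below is a minimal sketch; the host, port, and model id are placeholders and should match your TGI deployment.

```python
# Minimal sanity check against an OpenAI-compatible TGI endpoint.
# Host, port, and model id are placeholders; adjust them to your deployment.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8008/v1")

resp = client.chat.completions.create(
    model="your-model-id",  # the model served by TGI, e.g. the value of LLM_MODEL_ID
    messages=[{"role": "user", "content": "What is Deep Learning?"}],
    max_tokens=17,
)
print(resp.choices[0].message.content)
```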

### 1.4 Start LLM Service with Python Script
### 1.3 Start LLM Service with Python Script

```bash
export TGI_LLM_ENDPOINT="http://${your_ip}:8008"
@@ -73,8 +60,8 @@ docker run -d --name="llm-tgi-server" -p 9000:9000 --ipc=host -e http_proxy=$htt
### 2.4 Run Docker with Docker Compose (Option B)

```bash
cd text-generation/tgi
docker compose -f docker_compose_llm.yaml up -d
cd comps/llms/deployment/docker_compose/
docker compose -f text-generation_tgi.yaml up -d
```

## 🚀3. Consume LLM Service
@@ -3,6 +3,4 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

pip --no-cache-dir install -r requirements-runtime.txt

python llm.py
python opea_llm_microservice.py
260 changes: 260 additions & 0 deletions comps/llms/src/text-generation/integrations/openai_llm.py
@@ -0,0 +1,260 @@
# Copyright (C) 2024 Prediction Guard, Inc.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
from typing import Union

from fastapi.responses import StreamingResponse
from langchain_core.prompts import PromptTemplate
from openai import AsyncOpenAI

from comps import CustomLogger, LLMParamsDoc, OpeaComponent, SearchedDoc, ServiceType
from comps.cores.mega.utils import ConfigError, get_access_token, load_model_configs
from comps.cores.proto.api_protocol import ChatCompletionRequest

from .template import ChatTemplate

logger = CustomLogger("openai_llm")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
MODEL_NAME = os.getenv("LLM_MODEL_ID")
MODEL_CONFIGS = os.getenv("MODEL_CONFIGS")
DEFAULT_ENDPOINT = os.getenv("LLM_ENDPOINT", "http://localhost:8080")
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")

# Validate and Load the models config if MODEL_CONFIGS is not null
configs_map = {}
if MODEL_CONFIGS:
    try:
        configs_map = load_model_configs(MODEL_CONFIGS)
    except ConfigError as e:
        logger.error(f"Failed to load model configurations: {e}")
        raise ConfigError(f"Failed to load model configurations: {e}")


def get_llm_endpoint():
    if not MODEL_CONFIGS:
        return DEFAULT_ENDPOINT
    try:
        return configs_map.get(MODEL_NAME).get("endpoint")
    except ConfigError as e:
        logger.error(f"Input model {MODEL_NAME} not present in model_configs. Error {e}")
        raise ConfigError(f"Input model {MODEL_NAME} not present in model_configs")


class OpenAILLM(OpeaComponent):
"""A specialized LLM component derived from OpeaComponent for interacting with TGI/vLLM services based on OpenAI API.

Attributes:
client (TGI/vLLM): An instance of the TGI/vLLM client for text generation.
"""

    def __init__(self, name: str, description: str, config: dict = None):
        super().__init__(name, ServiceType.LLM.name.lower(), description, config)
        self.client = self._initialize_client()

    def _initialize_client(self) -> AsyncOpenAI:
        """Initializes the AsyncOpenAI."""
        access_token = (
            get_access_token(TOKEN_URL, CLIENTID, CLIENT_SECRET) if TOKEN_URL and CLIENTID and CLIENT_SECRET else None
        )
        headers = {}
        if access_token:
            headers = {"Authorization": f"Bearer {access_token}"}
        llm_endpoint = get_llm_endpoint()
        return AsyncOpenAI(api_key="EMPTY", base_url=llm_endpoint + "/v1", timeout=600, default_headers=headers)

    def check_health(self) -> bool:
        """Checks the health of the TGI/vLLM LLM service.

        Returns:
            bool: True if the service is reachable and healthy, False otherwise.
        """

        try:

            async def send_simple_request():
                response = await self.client.completions.create(model=MODEL_NAME, prompt="How are you?", max_tokens=4)
                return response

            response = asyncio.run(send_simple_request())
            return response is not None
        except Exception as e:
            logger.error(e)
            logger.error("Health check failed")
            return False

    def align_input(
        self, input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc], prompt_template, input_variables
    ):
        if isinstance(input, SearchedDoc):
            if logflag:
                logger.info("[ SearchedDoc ] input from retriever microservice")
            prompt = input.initial_query
            if input.retrieved_docs:
                docs = [doc.text for doc in input.retrieved_docs]
                if logflag:
                    logger.info(f"[ SearchedDoc ] combined retrieved docs: {docs}")
                prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs, MODEL_NAME)

            ## use default ChatCompletionRequest parameters
            new_input = ChatCompletionRequest(messages=prompt)

            if logflag:
                logger.info(f"[ SearchedDoc ] final input: {new_input}")

            return prompt, new_input

        elif isinstance(input, LLMParamsDoc):
            if logflag:
                logger.info("[ LLMParamsDoc ] input from rerank microservice")
            prompt = input.query
            if prompt_template:
                if sorted(input_variables) == ["context", "question"]:
                    prompt = prompt_template.format(question=input.query, context="\n".join(input.documents))
                elif input_variables == ["question"]:
                    prompt = prompt_template.format(question=input.query)
                else:
                    logger.info(
                        f"[ LLMParamsDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
                    )
            else:
                if input.documents:
                    # use rag default template
                    prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents, input.model)

            # convert to unified OpenAI /v1/chat/completions format
            new_input = ChatCompletionRequest(
                messages=prompt,
                max_tokens=input.max_tokens,
                top_p=input.top_p,
                stream=input.streaming,
                frequency_penalty=input.frequency_penalty,
                temperature=input.temperature,
            )

            return prompt, new_input

        else:
            if logflag:
                logger.info("[ ChatCompletionRequest ] input in opea format")

            prompt = input.messages
            if prompt_template:
                if sorted(input_variables) == ["context", "question"]:
                    prompt = prompt_template.format(question=input.messages, context="\n".join(input.documents))
                elif input_variables == ["question"]:
                    prompt = prompt_template.format(question=input.messages)
                else:
                    logger.info(
                        f"[ ChatCompletionRequest ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
                    )
            else:
                if input.documents:
                    # use rag default template
                    prompt = ChatTemplate.generate_rag_prompt(input.messages, input.documents, input.model)

            return prompt, input

    async def invoke(self, input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
        """Invokes the TGI/vLLM LLM service to generate output for the provided input.

        Args:
            input (Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]): The input text(s).
        """

        prompt_template = None
        input_variables = None
        if not isinstance(input, SearchedDoc) and input.chat_template:
            prompt_template = PromptTemplate.from_template(input.chat_template)
            input_variables = prompt_template.input_variables

        if isinstance(input, ChatCompletionRequest) and not isinstance(input.messages, str):
            if logflag:
                logger.info("[ ChatCompletionRequest ] input in opea format")

            if input.messages[0]["role"] == "system":
                if "{context}" in input.messages[0]["content"]:
                    if input.documents is None or input.documents == []:
                        input.messages[0]["content"] = input.messages[0]["content"].format(context="")
                    else:
                        input.messages[0]["content"] = input.messages[0]["content"].format(
                            context="\n".join(input.documents)
                        )
            else:
                if prompt_template:
                    system_prompt = prompt_template
                    if input_variables == ["context"]:
                        system_prompt = prompt_template.format(context="\n".join(input.documents))
                    else:
                        logger.info(
                            f"[ ChatCompletionRequest ] {prompt_template} not used, only support 1 input variables ['context']"
                        )

                    input.messages.insert(0, {"role": "system", "content": system_prompt})

            chat_completion = await self.client.chat.completions.create(
                model=MODEL_NAME,
                messages=input.messages,
                frequency_penalty=input.frequency_penalty,
                max_tokens=input.max_tokens,
                n=input.n,
                presence_penalty=input.presence_penalty,
                response_format=input.response_format,
                seed=input.seed,
                stop=input.stop,
                stream=input.stream,
                stream_options=input.stream_options,
                temperature=input.temperature,
                top_p=input.top_p,
                user=input.user,
            )
            """TODO need validate following parameters for vllm
                logit_bias=input.logit_bias,
                logprobs=input.logprobs,
                top_logprobs=input.top_logprobs,
                service_tier=input.service_tier,
                tools=input.tools,
                tool_choice=input.tool_choice,
                parallel_tool_calls=input.parallel_tool_calls,"""
        else:
            prompt, input = self.align_input(input, prompt_template, input_variables)
            chat_completion = await self.client.completions.create(
                model=MODEL_NAME,
                prompt=prompt,
                echo=input.echo,
                frequency_penalty=input.frequency_penalty,
                max_tokens=input.max_tokens,
                n=input.n,
                presence_penalty=input.presence_penalty,
                seed=input.seed,
                stop=input.stop,
                stream=input.stream,
                suffix=input.suffix,
                temperature=input.temperature,
                top_p=input.top_p,
                user=input.user,
            )
            """TODO need validate following parameters for vllm
                best_of=input.best_of,
                logit_bias=input.logit_bias,
                logprobs=input.logprobs,"""

        if input.stream:

            async def stream_generator():
                async for c in chat_completion:
                    if logflag:
                        logger.info(c)
                    chunk = c.model_dump_json()
                    if chunk not in ["<|im_end|>", "<|endoftext|>"]:
                        yield f"data: {chunk}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        else:
            if logflag:
                logger.info(chat_completion)
            return chat_completion
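
For reference, a minimal sketch of driving this component directly (outside the microservice wrapper) could look like the following. It assumes `LLM_MODEL_ID` and `LLM_ENDPOINT` are exported, the `comps` package is importable, the script is run from the `src/text-generation` directory, and an OpenAI-compatible TGI/vLLM server is reachable at the endpoint.

```python
# Illustrative standalone driver for OpenAILLM (a sketch under the assumptions above).
import asyncio

from integrations.openai_llm import OpenAILLM

from comps.cores.proto.api_protocol import ChatCompletionRequest


async def main():
    llm = OpenAILLM(name="OpenAILLM", description="OpenAI LLM Service")
    request = ChatCompletionRequest(
        messages=[{"role": "user", "content": "What is Deep Learning?"}],
        max_tokens=32,
        stream=False,
    )
    response = await llm.invoke(request)
    print(response)


asyncio.run(main())
```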
74 changes: 74 additions & 0 deletions comps/llms/src/text-generation/opea_llm_microservice.py
@@ -0,0 +1,74 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import time
from typing import Union

from integrations.openai_llm import OpenAILLM

from comps import (
    CustomLogger,
    LLMParamsDoc,
    OpeaComponentController,
    SearchedDoc,
    ServiceType,
    opea_microservices,
    register_microservice,
    register_statistics,
    statistics_dict,
)
from comps.cores.proto.api_protocol import ChatCompletionRequest

logger = CustomLogger("llm")
logflag = os.getenv("LOGFLAG", False)

# Initialize OpeaComponentController
controller = OpeaComponentController()

# Register components
try:
    openai_llm = OpenAILLM(
        name="OpenAILLM",
        description="OpenAI LLM Service",
    )

    # Register components with the controller
    controller.register(openai_llm)

    # Discover and activate a healthy component
    controller.discover_and_activate()
except Exception as e:
    logger.error(f"Failed to initialize components: {e}")


@register_microservice(
    name="opea_service@llm",
    service_type=ServiceType.LLM,
    endpoint="/v1/chat/completions",
    host="0.0.0.0",
    port=9000,
)
@register_statistics(names=["opea_service@llm"])
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
    start = time.time()

    # Log the input if logging is enabled
    if logflag:
        logger.info(input)

    try:
        # Use the controller to invoke the active component
        response = await controller.invoke(input)
        # Record statistics
        statistics_dict["opea_service@llm"].append_latency(time.time() - start, None)
        return response

    except Exception as e:
        logger.error(f"Error during LLM invocation: {e}")
        raise


if __name__ == "__main__":
logger.info("OPEA LLM Microservice is starting...")
opea_microservices["opea_service@llm"].start()