Skip to content

Commit

Permalink
Multiple models support for LLM TGI (#835)
Browse files Browse the repository at this point in the history
* Update gateway and docarray from mega and proto services to have model field for ChatQnAGateway and LLMParams respectively

* Add load_model_configs method in utils.py to validate and load the model_configs

* Update llms text-generation tgi file (llm.py) to support multiple models. Uses load_model_configs method from utils

* Update llms text-generation tgi template to add different templates for different models

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed llm_endpoint empty string issue on error scenario

Signed-off-by: sgurunat <[email protected]>

* Function to get llm_endpoint and keep the code clean

Signed-off-by: sgurunat <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: sgurunat <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
sgurunat and pre-commit-ci[bot] authored Nov 13, 2024
1 parent 9e471a9 commit e879366
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 23 deletions.
1 change: 1 addition & 0 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ async def handle_request(self, request: Request):
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
model=chat_request.model if chat_request.model else None,
)
retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
Expand Down
39 changes: 39 additions & 0 deletions comps/cores/mega/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import ipaddress
import json
import multiprocessing
import os
import random
Expand Down Expand Up @@ -187,6 +188,44 @@ def _check_bind(port):
return _random_port()


class ConfigError(Exception):
"""Custom exception for configuration errors."""

pass


def load_model_configs(model_configs: str) -> dict:
"""Load and validate the model configurations .
If valid, return the configuration for the specified model name.
"""
logger = CustomLogger("models_loader")
try:
configs = json.loads(model_configs)
if not isinstance(configs, list) or not configs:
raise ConfigError("MODEL_CONFIGS must be a non-empty JSON array.")
required_keys = {"model_name", "displayName", "endpoint", "minToken", "maxToken"}
configs_map = {}
for config in configs:
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise ConfigError(f"Missing required configuration fields: {missing_keys}")
empty_keys = [key for key in required_keys if not config.get(key)]
if empty_keys:
raise ConfigError(f"Empty values found for configuration fields: {empty_keys}")
model_name = config["model_name"]
configs_map[model_name] = config
if not configs_map:
raise ConfigError("No valid configurations found.")
return configs_map
except json.JSONDecodeError:
logger.error("Error parsing MODEL_CONFIGS environment variable as JSON.")
raise ConfigError("MODEL_CONFIGS is not valid JSON.")
except ConfigError as e:
logger.error(str(e))
raise


def get_access_token(token_url: str, client_id: str, client_secret: str) -> str:
"""Get access token using OAuth client credentials flow."""
logger = CustomLogger("tgi_or_tei_service_auth")
Expand Down
1 change: 1 addition & 0 deletions comps/cores/proto/docarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def chat_template_must_contain_variables(cls, v):


class LLMParams(BaseDoc):
model: Optional[str] = None
max_tokens: int = 1024
max_new_tokens: int = 1024
top_k: int = 10
Expand Down
31 changes: 25 additions & 6 deletions comps/llms/text-generation/tgi/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,37 @@
register_statistics,
statistics_dict,
)
from comps.cores.mega.utils import get_access_token
from comps.cores.mega.utils import ConfigError, get_access_token, load_model_configs
from comps.cores.proto.api_protocol import ChatCompletionRequest

logger = CustomLogger("llm_tgi")
logflag = os.getenv("LOGFLAG", False)

# Environment variables
MODEL_CONFIGS = os.getenv("MODEL_CONFIGS")
DEFAULT_ENDPOINT = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
TOKEN_URL = os.getenv("TOKEN_URL")
CLIENTID = os.getenv("CLIENTID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")

llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
# Validate and Load the models config if MODEL_CONFIGS is not null
configs_map = {}
if MODEL_CONFIGS:
try:
configs_map = load_model_configs(MODEL_CONFIGS)
except ConfigError as e:
logger.error(f"Failed to load model configurations: {e}")
raise ConfigError(f"Failed to load model configurations: {e}")


def get_llm_endpoint(model):
if not MODEL_CONFIGS:
return DEFAULT_ENDPOINT
try:
return configs_map.get(model).get("endpoint")
except ConfigError as e:
logger.error(f"Input model {model} not present in model_configs. Error {e}")
raise ConfigError(f"Input model {model} not present in model_configs")


@register_microservice(
Expand All @@ -54,7 +73,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, Searche
headers = {}
if access_token:
headers = {"Authorization": f"Bearer {access_token}"}

llm_endpoint = get_llm_endpoint(input.model)
llm = AsyncInferenceClient(model=llm_endpoint, timeout=600, headers=headers)

prompt_template = None
Expand All @@ -73,7 +92,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, Searche
docs = [doc.text for doc in input.retrieved_docs]
if logflag:
logger.info(f"[ SearchedDoc ] combined retrieved docs: {docs}")
prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs)
prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs, input.model)
# use default llm parameters for inferencing
new_input = LLMParamsDoc(query=prompt)
if logflag:
Expand Down Expand Up @@ -126,7 +145,7 @@ async def stream_generator():
else:
if input.documents:
# use rag default template
prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents)
prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents, input.model)

text_generation = await llm.text_generation(
prompt=prompt,
Expand Down Expand Up @@ -182,7 +201,7 @@ async def stream_generator():
else:
if input.documents:
# use rag default template
prompt = ChatTemplate.generate_rag_prompt(input.messages, input.documents)
prompt = ChatTemplate.generate_rag_prompt(input.messages, input.documents, input.model)

chat_completion = client.completions.create(
model="tgi",
Expand Down
41 changes: 24 additions & 17 deletions comps/llms/text-generation/tgi/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,31 @@

class ChatTemplate:
@staticmethod
def generate_rag_prompt(question, documents):
def generate_rag_prompt(question, documents, model):
context_str = "\n".join(documents)
if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
# chinese context
if model == "meta-llama/Meta-Llama-3.1-70B-Instruct" or model == "meta-llama/Meta-Llama-3.1-8B-Instruct":
template = """
### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。
### 搜索结果:{context}
### 问题:{question}
### 回答:
"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise <|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Context: {context}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
else:
template = """
### You are a helpful, respectful and honest assistant to help the user with questions. \
Please refer to the search results obtained from the local knowledge base. \
But be careful to not incorporate the information that you think is not relevant to the question. \
If you don't know the answer to a question, please don't share false information. \n
### Search results: {context} \n
### Question: {question} \n
### Answer:
"""
if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
# chinese context
template = """
### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。
### 搜索结果:{context}
### 问题:{question}
### 回答:
"""
else:
template = """
### You are a helpful, respectful and honest assistant to help the user with questions. \
Please refer to the search results obtained from the local knowledge base. \
But be careful to not incorporate the information that you think is not relevant to the question. \
If you don't know the answer to a question, please don't share false information. \n
### Search results: {context} \n
### Question: {question} \n
### Answer:
"""
return template.format(context=context_str, question=question)

0 comments on commit e879366

Please sign in to comment.