From e879366cf8ebfebf8cdce39d36264097184ea01e Mon Sep 17 00:00:00 2001 From: sgurunat Date: Wed, 13 Nov 2024 14:41:43 +0530 Subject: [PATCH] Multiple models support for LLM TGI (#835) * Update gateway and docarray from mega and proto services to have model field for ChatQnAGateway and LLMParams respectively * Add load_model_configs method in utils.py to validate and load the model_configs * Update llms text-generation tgi file (llm.py) to support multiple models. Uses load_model_configs method from utils * Update llms text-generation tgi template to add different templates for different models * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed llm_endpoint empty string issue on error scenario Signed-off-by: sgurunat * Function to get llm_endpoint and keep the code clean Signed-off-by: sgurunat * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: sgurunat Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- comps/cores/mega/gateway.py | 1 + comps/cores/mega/utils.py | 39 ++++++++++++++++++++ comps/cores/proto/docarray.py | 1 + comps/llms/text-generation/tgi/llm.py | 31 ++++++++++++---- comps/llms/text-generation/tgi/template.py | 41 +++++++++++++--------- 5 files changed, 90 insertions(+), 23 deletions(-) diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 1dc94074f..dc6b076bc 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -220,6 +220,7 @@ async def handle_request(self, request: Request): repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, streaming=stream_opt, chat_template=chat_request.chat_template if chat_request.chat_template else None, + model=chat_request.model if chat_request.model else None, ) retriever_parameters = RetrieverParms( search_type=chat_request.search_type if chat_request.search_type else "similarity", diff --git a/comps/cores/mega/utils.py b/comps/cores/mega/utils.py index db23f023a..e5b2df4f5 100644 --- a/comps/cores/mega/utils.py +++ b/comps/cores/mega/utils.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import ipaddress +import json import multiprocessing import os import random @@ -187,6 +188,44 @@ def _check_bind(port): return _random_port() +class ConfigError(Exception): + """Custom exception for configuration errors.""" + + pass + + +def load_model_configs(model_configs: str) -> dict: + """Load and validate the model configurations . + + If valid, return the configuration for the specified model name. 
+ """ + logger = CustomLogger("models_loader") + try: + configs = json.loads(model_configs) + if not isinstance(configs, list) or not configs: + raise ConfigError("MODEL_CONFIGS must be a non-empty JSON array.") + required_keys = {"model_name", "displayName", "endpoint", "minToken", "maxToken"} + configs_map = {} + for config in configs: + missing_keys = [key for key in required_keys if key not in config] + if missing_keys: + raise ConfigError(f"Missing required configuration fields: {missing_keys}") + empty_keys = [key for key in required_keys if not config.get(key)] + if empty_keys: + raise ConfigError(f"Empty values found for configuration fields: {empty_keys}") + model_name = config["model_name"] + configs_map[model_name] = config + if not configs_map: + raise ConfigError("No valid configurations found.") + return configs_map + except json.JSONDecodeError: + logger.error("Error parsing MODEL_CONFIGS environment variable as JSON.") + raise ConfigError("MODEL_CONFIGS is not valid JSON.") + except ConfigError as e: + logger.error(str(e)) + raise + + def get_access_token(token_url: str, client_id: str, client_secret: str) -> str: """Get access token using OAuth client credentials flow.""" logger = CustomLogger("tgi_or_tei_service_auth") diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py index 712b461b2..490e7a9a8 100644 --- a/comps/cores/proto/docarray.py +++ b/comps/cores/proto/docarray.py @@ -213,6 +213,7 @@ def chat_template_must_contain_variables(cls, v): class LLMParams(BaseDoc): + model: Optional[str] = None max_tokens: int = 1024 max_new_tokens: int = 1024 top_k: int = 10 diff --git a/comps/llms/text-generation/tgi/llm.py b/comps/llms/text-generation/tgi/llm.py index c0d4ed311..a825a5d2a 100644 --- a/comps/llms/text-generation/tgi/llm.py +++ b/comps/llms/text-generation/tgi/llm.py @@ -22,18 +22,37 @@ register_statistics, statistics_dict, ) -from comps.cores.mega.utils import get_access_token +from comps.cores.mega.utils import ConfigError, get_access_token, load_model_configs from comps.cores.proto.api_protocol import ChatCompletionRequest logger = CustomLogger("llm_tgi") logflag = os.getenv("LOGFLAG", False) # Environment variables +MODEL_CONFIGS = os.getenv("MODEL_CONFIGS") +DEFAULT_ENDPOINT = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080") TOKEN_URL = os.getenv("TOKEN_URL") CLIENTID = os.getenv("CLIENTID") CLIENT_SECRET = os.getenv("CLIENT_SECRET") -llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080") +# Validate and Load the models config if MODEL_CONFIGS is not null +configs_map = {} +if MODEL_CONFIGS: + try: + configs_map = load_model_configs(MODEL_CONFIGS) + except ConfigError as e: + logger.error(f"Failed to load model configurations: {e}") + raise ConfigError(f"Failed to load model configurations: {e}") + + +def get_llm_endpoint(model): + if not MODEL_CONFIGS: + return DEFAULT_ENDPOINT + try: + return configs_map.get(model).get("endpoint") + except ConfigError as e: + logger.error(f"Input model {model} not present in model_configs. 
Error {e}") + raise ConfigError(f"Input model {model} not present in model_configs") @register_microservice( @@ -54,7 +73,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, Searche headers = {} if access_token: headers = {"Authorization": f"Bearer {access_token}"} - + llm_endpoint = get_llm_endpoint(input.model) llm = AsyncInferenceClient(model=llm_endpoint, timeout=600, headers=headers) prompt_template = None @@ -73,7 +92,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, Searche docs = [doc.text for doc in input.retrieved_docs] if logflag: logger.info(f"[ SearchedDoc ] combined retrieved docs: {docs}") - prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs) + prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs, input.model) # use default llm parameters for inferencing new_input = LLMParamsDoc(query=prompt) if logflag: @@ -126,7 +145,7 @@ async def stream_generator(): else: if input.documents: # use rag default template - prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents) + prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents, input.model) text_generation = await llm.text_generation( prompt=prompt, @@ -182,7 +201,7 @@ async def stream_generator(): else: if input.documents: # use rag default template - prompt = ChatTemplate.generate_rag_prompt(input.messages, input.documents) + prompt = ChatTemplate.generate_rag_prompt(input.messages, input.documents, input.model) chat_completion = client.completions.create( model="tgi", diff --git a/comps/llms/text-generation/tgi/template.py b/comps/llms/text-generation/tgi/template.py index 447efcc67..6d976106a 100644 --- a/comps/llms/text-generation/tgi/template.py +++ b/comps/llms/text-generation/tgi/template.py @@ -6,24 +6,31 @@ class ChatTemplate: @staticmethod - def generate_rag_prompt(question, documents): + def generate_rag_prompt(question, documents, model): context_str = "\n".join(documents) - if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3: - # chinese context + if model == "meta-llama/Meta-Llama-3.1-70B-Instruct" or model == "meta-llama/Meta-Llama-3.1-8B-Instruct": template = """ -### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。 -### 搜索结果:{context} -### 问题:{question} -### 回答: -""" + <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise <|eot_id|><|start_header_id|>user<|end_header_id|> + Question: {question} + Context: {context} + Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""" else: - template = """ -### You are a helpful, respectful and honest assistant to help the user with questions. \ -Please refer to the search results obtained from the local knowledge base. \ -But be careful to not incorporate the information that you think is not relevant to the question. \ -If you don't know the answer to a question, please don't share false information. 
\n -### Search results: {context} \n -### Question: {question} \n -### Answer: -""" + if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3: + # chinese context + template = """ + ### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。 + ### 搜索结果:{context} + ### 问题:{question} + ### 回答: + """ + else: + template = """ + ### You are a helpful, respectful and honest assistant to help the user with questions. \ + Please refer to the search results obtained from the local knowledge base. \ + But be careful to not incorporate the information that you think is not relevant to the question. \ + If you don't know the answer to a question, please don't share false information. \n + ### Search results: {context} \n + ### Question: {question} \n + ### Answer: + """ return template.format(context=context_str, question=question)
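
Below is a minimal sketch of how a deployment can populate the new MODEL_CONFIGS environment variable and how the TGI microservice resolves a per-model endpoint from it. The model names are the ones referenced in template.py; the display names, endpoint URLs, and token limits are illustrative placeholders rather than values defined by this patch.

import json
import os

# Same import path that llm.py uses in this patch.
from comps.cores.mega.utils import ConfigError, load_model_configs

# MODEL_CONFIGS must be a non-empty JSON array; every entry needs a non-empty
# model_name, displayName, endpoint, minToken and maxToken.
os.environ["MODEL_CONFIGS"] = json.dumps(
    [
        {
            "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "displayName": "Llama 3.1 8B",          # placeholder
            "endpoint": "http://tgi-llama-8b:80",   # placeholder
            "minToken": 1,
            "maxToken": 1024,
        },
        {
            "model_name": "meta-llama/Meta-Llama-3.1-70B-Instruct",
            "displayName": "Llama 3.1 70B",         # placeholder
            "endpoint": "http://tgi-llama-70b:80",  # placeholder
            "minToken": 1,
            "maxToken": 2048,
        },
    ]
)

try:
    # Returns a dict keyed by model_name; raises ConfigError on invalid JSON
    # or on missing/empty required fields.
    configs_map = load_model_configs(os.environ["MODEL_CONFIGS"])
except ConfigError as err:
    raise SystemExit(f"Invalid MODEL_CONFIGS: {err}")

# This mirrors get_llm_endpoint() in llm.py: look the requested model up in the
# map, or fall back to TGI_LLM_ENDPOINT (default http://localhost:8080) when
# MODEL_CONFIGS is unset.
endpoint = configs_map["meta-llama/Meta-Llama-3.1-70B-Instruct"]["endpoint"]
print(endpoint)  # http://tgi-llama-70b:80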
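
The template change can be exercised the same way llm.py consumes it. This sketch assumes it is run from comps/llms/text-generation/tgi/ so that the local template module resolves; the question, retrieved document, and the non-Llama model name are made up for illustration.

from template import ChatTemplate

docs = ["TGI can serve several models, each behind its own endpoint."]
question = "How are requests routed to different models?"

# The two Meta-Llama-3.1 Instruct model names select the new Llama 3.1 chat
# format (<|begin_of_text|> / <|start_header_id|> markers).
llama_prompt = ChatTemplate.generate_rag_prompt(
    question, docs, "meta-llama/Meta-Llama-3.1-8B-Instruct"
)

# Any other model keeps the previous behaviour: the generic
# "### Search results / ### Question / ### Answer" template, or the Chinese
# variant when at least 30% of the retrieved context is CJK characters.
default_prompt = ChatTemplate.generate_rag_prompt(
    question, docs, "some-other-org/some-other-model"  # placeholder name
)

print(llama_prompt)
print(default_prompt)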