From f2a21e4bfddf678629f706d5c5b647a6ae820236 Mon Sep 17 00:00:00 2001 From: letonghan Date: Wed, 21 Aug 2024 17:10:24 +0800 Subject: [PATCH 1/2] support no_rerank input in llm Signed-off-by: letonghan --- comps/llms/text-generation/tgi/README.md | 12 ++-- .../tgi/docker_compose_llm.yaml | 8 ++- comps/llms/text-generation/tgi/llm.py | 66 ++++++++++++++++--- 3 files changed, 71 insertions(+), 15 deletions(-) diff --git a/comps/llms/text-generation/tgi/README.md b/comps/llms/text-generation/tgi/README.md index cca4d1fa4b..f34dd0374f 100644 --- a/comps/llms/text-generation/tgi/README.md +++ b/comps/llms/text-generation/tgi/README.md @@ -16,9 +16,6 @@ pip install -r requirements.txt ```bash export HF_TOKEN=${your_hf_api_token} -export LANGCHAIN_TRACING_V2=true -export LANGCHAIN_API_KEY=${your_langchain_api_key} -export LANGCHAIN_PROJECT="opea/gen-ai-comps:llms" docker run -p 8008:80 -v ./data:/data --name tgi_service --shm-size 1g ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id ${your_hf_llm_model} ``` @@ -50,9 +47,6 @@ In order to start TGI and LLM services, you need to setup the following environm export HF_TOKEN=${your_hf_api_token} export TGI_LLM_ENDPOINT="http://${your_ip}:8008" export LLM_MODEL_ID=${your_hf_llm_model} -export LANGCHAIN_TRACING_V2=true -export LANGCHAIN_API_KEY=${your_langchain_api_key} -export LANGCHAIN_PROJECT="opea/llms" ``` ### 2.2 Build Docker Image @@ -116,6 +110,12 @@ curl http://${your_ip}:9000/v1/chat/completions \ -X POST \ -d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \ -H 'Content-Type: application/json' + +# consume with SearchedDoc +curl http://${your_ip}:9000/v1/chat/completions \ + -X POST \ + -d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \ + -H 'Content-Type: application/json' ``` ### 4. 
Validated Model diff --git a/comps/llms/text-generation/tgi/docker_compose_llm.yaml b/comps/llms/text-generation/tgi/docker_compose_llm.yaml index 9551979a73..36269aeeaf 100644 --- a/comps/llms/text-generation/tgi/docker_compose_llm.yaml +++ b/comps/llms/text-generation/tgi/docker_compose_llm.yaml @@ -12,6 +12,13 @@ services: volumes: - "./data:/data" shm_size: 1g + environment: + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + HF_TOKEN: ${HF_TOKEN} + HF_HUB_DISABLE_PROGRESS_BARS: 1 + HF_HUB_ENABLE_HF_TRANSFER: 0 command: --model-id ${LLM_MODEL_ID} llm: image: opea/llm-tgi:latest @@ -25,7 +32,6 @@ services: https_proxy: ${https_proxy} TGI_LLM_ENDPOINT: ${TGI_LLM_ENDPOINT} HF_TOKEN: ${HF_TOKEN} - LANGCHAIN_API_KEY: ${LANGCHAIN_API_KEY} restart: unless-stopped networks: diff --git a/comps/llms/text-generation/tgi/llm.py b/comps/llms/text-generation/tgi/llm.py index dd4d93e322..66d1135a4e 100644 --- a/comps/llms/text-generation/tgi/llm.py +++ b/comps/llms/text-generation/tgi/llm.py @@ -14,6 +14,7 @@ from comps import ( CustomLogger, GeneratedDoc, + SearchedDoc, LLMParamsDoc, ServiceType, opea_microservices, @@ -41,18 +42,65 @@ port=9000, ) @register_statistics(names=["opea_service@llm_tgi"]) -async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]): +async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]): if logflag: logger.info(input) prompt_template = None - if input.chat_template: + if not isinstance(input, SearchedDoc) and input.chat_template: prompt_template = PromptTemplate.from_template(input.chat_template) input_variables = prompt_template.input_variables stream_gen_time = [] start = time.time() - if isinstance(input, LLMParamsDoc): + if isinstance(input, SearchedDoc): + if logflag: + logger.info(f"[ SearchedDoc ] input from retriever microservice") + prompt = input.initial_query + if input.retrieved_docs: + if logflag: + logger.info(f"[ SearchedDoc ] retrieved docs: {input.retrieved_docs}") + for doc in input.retrieved_docs: + logger.info(f"[ SearchedDoc ] {doc}") + prompt = ChatTemplate.generate_rag_prompt(input.initial_query, input.retrieved_docs[0].text) + # use default llm parameters for inferencing + new_input = LLMParamsDoc(query=prompt) + if logflag: + logger.info(f"[ SearchedDoc ] final input: {new_input}") + text_generation = await llm.text_generation( + prompt=prompt, + stream=new_input.streaming, + max_new_tokens=new_input.max_new_tokens, + repetition_penalty=new_input.repetition_penalty, + temperature=new_input.temperature, + top_k=new_input.top_k, + top_p=new_input.top_p, + ) + if new_input.streaming: + async def stream_generator(): + chat_response = "" + async for text in text_generation: + stream_gen_time.append(time.time() - start) + chat_response += text + chunk_repr = repr(text.encode("utf-8")) + if logflag: + logger.info(f"[ SearchedDoc ] chunk:{chunk_repr}") + yield f"data: {chunk_repr}\n\n" + if logflag: + logger.info(f"[ SearchedDoc ] stream response: {chat_response}") + statistics_dict["opea_service@llm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0]) + yield "data: [DONE]\n\n" + + return StreamingResponse(stream_generator(), media_type="text/event-stream") + else: + statistics_dict["opea_service@llm_tgi"].append_latency(time.time() - start, None) + if logflag: + logger.info(text_generation) + return GeneratedDoc(text=text_generation, prompt=new_input.query) + + elif isinstance(input, LLMParamsDoc): + if logflag: + logger.info(f"[ LLMParamsDoc ] input from rerank 
microservice") prompt = input.query if prompt_template: if sorted(input_variables) == ["context", "question"]: @@ -60,7 +108,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]): elif input_variables == ["question"]: prompt = prompt_template.format(question=input.query) else: - logger.info(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']") + logger.info(f"[ LLMParamsDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']") else: if input.documents: # use rag default template @@ -84,10 +132,10 @@ async def stream_generator(): chat_response += text chunk_repr = repr(text.encode("utf-8")) if logflag: - logger.info(f"[llm - chat_stream] chunk:{chunk_repr}") + logger.info(f"[ LLMParamsDoc ] chunk:{chunk_repr}") yield f"data: {chunk_repr}\n\n" if logflag: - logger.info(f"[llm - chat_stream] stream response: {chat_response}") + logger.info(f"[ LLMParamsDoc ] stream response: {chat_response}") statistics_dict["opea_service@llm_tgi"].append_latency(stream_gen_time[-1], stream_gen_time[0]) yield "data: [DONE]\n\n" @@ -99,6 +147,8 @@ async def stream_generator(): return GeneratedDoc(text=text_generation, prompt=input.query) else: + if logflag: + logger.info(f"[ ChatCompletionRequest ] input in opea format") client = OpenAI( api_key="EMPTY", base_url=llm_endpoint + "/v1", @@ -113,7 +163,7 @@ async def stream_generator(): prompt = prompt_template.format(question=input.messages) else: logger.info( - f"{prompt_template} not used, we only support 2 input variables ['question', 'context']" + f"[ ChatCompletionRequest ] {prompt_template} not used, we only support 2 input variables ['question', 'context']" ) else: if input.documents: @@ -152,7 +202,7 @@ async def stream_generator(): if input_variables == ["context"]: system_prompt = prompt_template.format(context="\n".join(input.documents)) else: - logger.info(f"{prompt_template} not used, only support 1 input variables ['context']") + logger.info(f"[ ChatCompletionRequest ] {prompt_template} not used, only support 1 input variables ['context']") input.messages.insert(0, {"role": "system", "content": system_prompt}) From 043f9f7e75e5ba847c7c15f8695cf7e9c1534c70 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 09:17:25 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- comps/llms/text-generation/tgi/llm.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/comps/llms/text-generation/tgi/llm.py b/comps/llms/text-generation/tgi/llm.py index 66d1135a4e..b7ec7b2eae 100644 --- a/comps/llms/text-generation/tgi/llm.py +++ b/comps/llms/text-generation/tgi/llm.py @@ -14,8 +14,8 @@ from comps import ( CustomLogger, GeneratedDoc, - SearchedDoc, LLMParamsDoc, + SearchedDoc, ServiceType, opea_microservices, register_microservice, @@ -55,7 +55,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, Searche if isinstance(input, SearchedDoc): if logflag: - logger.info(f"[ SearchedDoc ] input from retriever microservice") + logger.info("[ SearchedDoc ] input from retriever microservice") prompt = input.initial_query if input.retrieved_docs: if logflag: @@ -77,6 +77,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, Searche top_p=new_input.top_p, ) if new_input.streaming: + async def stream_generator(): chat_response = "" async for text 
in text_generation: @@ -100,7 +101,7 @@ async def stream_generator(): elif isinstance(input, LLMParamsDoc): if logflag: - logger.info(f"[ LLMParamsDoc ] input from rerank microservice") + logger.info("[ LLMParamsDoc ] input from rerank microservice") prompt = input.query if prompt_template: if sorted(input_variables) == ["context", "question"]: @@ -108,7 +109,9 @@ async def stream_generator(): elif input_variables == ["question"]: prompt = prompt_template.format(question=input.query) else: - logger.info(f"[ LLMParamsDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']") + logger.info( + f"[ LLMParamsDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']" + ) else: if input.documents: # use rag default template @@ -148,7 +151,7 @@ async def stream_generator(): else: if logflag: - logger.info(f"[ ChatCompletionRequest ] input in opea format") + logger.info("[ ChatCompletionRequest ] input in opea format") client = OpenAI( api_key="EMPTY", base_url=llm_endpoint + "/v1", @@ -202,7 +205,9 @@ async def stream_generator(): if input_variables == ["context"]: system_prompt = prompt_template.format(context="\n".join(input.documents)) else: - logger.info(f"[ ChatCompletionRequest ] {prompt_template} not used, only support 1 input variables ['context']") + logger.info( + f"[ ChatCompletionRequest ] {prompt_template} not used, only support 1 input variables ['context']" + ) input.messages.insert(0, {"role": "system", "content": system_prompt})
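
To sanity-check the new no-rerank path without running the full pipeline, the sketch below posts a SearchedDoc-shaped payload directly to the LLM microservice, mirroring the curl example this patch adds to the README. It is a minimal sketch, not part of the patch: the endpoint address, timeout, and document texts are assumptions, and it requires the `requests` package.

```python
# Minimal sketch (not part of the patch): exercise the new SearchedDoc branch by
# sending a retriever-style payload straight to the llm-tgi microservice,
# skipping the rerank stage. Host/port and the doc texts are placeholders.
import requests

LLM_ENDPOINT = "http://localhost:9000/v1/chat/completions"  # assumed service address

payload = {
    # Field names follow the SearchedDoc schema handled in llm.py: the
    # retriever's original query plus the documents it returned.
    "initial_query": "What is Deep Learning?",
    "retrieved_docs": [
        {"text": "Deep Learning is a subset of machine learning ..."},
        {"text": "Deep Learning relies on multi-layer neural networks ..."},
    ],
}

# The SearchedDoc branch builds a default LLMParamsDoc, so the response format
# depends on those defaults; with streaming enabled it is a server-sent-event stream.
with requests.post(LLM_ENDPOINT, json=payload, stream=True, timeout=120) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue
        print(line)  # each chunk looks like: data: b'...token...'
        if line.strip() == "data: [DONE]":
            break
```

If the default `LLMParamsDoc` parameters disable streaming, the same request instead returns a single JSON `GeneratedDoc` body, matching the non-streaming branch added in `llm.py`.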