From 45d0002057520de5d12cb549a370a28ba1bb6e93 Mon Sep 17 00:00:00 2001
From: XinyaoWa
Date: Fri, 20 Dec 2024 11:03:54 +0800
Subject: [PATCH] DocSum Long Context add auto mode (#1046)

* docsum refine mode prompt update

Signed-off-by: Xinyao Wang

* docsum vllm requirement update

Signed-off-by: Xinyao Wang

* docsum add auto mode

Signed-off-by: Xinyao Wang

* fix bug

Signed-off-by: Xinyao Wang

* fix bug

Signed-off-by: Xinyao Wang

* fix readme

Signed-off-by: Xinyao Wang

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine

Signed-off-by: Xinyao Wang

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Xinyao Wang
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 comps/cores/proto/docarray.py                 |  2 +-
 .../summarization/tgi/langchain/README.md     | 10 ++-
 comps/llms/summarization/tgi/langchain/llm.py | 73 ++++++++++++++-----
 .../summarization/vllm/langchain/README.md    | 10 ++-
 .../llms/summarization/vllm/langchain/llm.py  | 73 ++++++++++++++-----
 .../vllm/langchain/requirements.txt           |  1 +
 .../test_llms_summarization_tgi_langchain.sh  |  8 +-
 7 files changed, 129 insertions(+), 48 deletions(-)

diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py
index 7bbf02a85..7cb33c2ef 100644
--- a/comps/cores/proto/docarray.py
+++ b/comps/cores/proto/docarray.py
@@ -213,7 +213,7 @@ def chat_template_must_contain_variables(cls, v):
 
 
 class DocSumLLMParams(LLMParamsDoc):
-    summary_type: str = "stuff"  # can be "truncate", "map_reduce", "refine"
+    summary_type: str = "auto"  # can be "auto", "stuff", "truncate", "map_reduce", "refine"
     chunk_size: int = -1
     chunk_overlap: int = -1
 
diff --git a/comps/llms/summarization/tgi/langchain/README.md b/comps/llms/summarization/tgi/langchain/README.md
index 094247834..888b4adce 100644
--- a/comps/llms/summarization/tgi/langchain/README.md
+++ b/comps/llms/summarization/tgi/langchain/README.md
@@ -98,7 +98,7 @@ In DocSum microservice, except for basic LLM parameters, we also support several
 
 If you want to deal with long context, can select suitable summary type, details in section 3.2.2.
 
-- "summary_type": can be "stuff", "truncate", "map_reduce", "refine", default is "stuff"
+- "summary_type": can be "auto", "stuff", "truncate", "map_reduce", "refine", default is "auto"
 - "chunk_size": max token length for each chunk. Set to be different default value according to "summary_type".
 - "chunk_overlap": overlap token length between each chunk, default is 0.1\*chunk_size
 
@@ -126,9 +126,13 @@ curl http://${your_ip}:9000/v1/chat/docsum \
 
 #### 3.2.2 Long context summarization with "summary_type"
 
-"summary_type" is set to be "stuff" by default, which will let LLM generate summary based on complete input text. In this case please carefully set `MAX_INPUT_TOKENS` and `MAX_TOTAL_TOKENS` according to your model and device memory, otherwise it may exceed LLM context limit and raise error when meet long context.
+**summary_type=auto**
 
-When deal with long context, you can set "summary_type" to one of "truncate", "map_reduce" and "refine" for better performance.
+"summary_type" is set to "auto" by default. In this mode the microservice first checks the input token length: if it exceeds `MAX_INPUT_TOKENS`, `summary_type` is automatically switched to `refine` mode, otherwise it falls back to `stuff` mode.
+
+**summary_type=stuff**
+
+In this mode the LLM generates the summary from the complete input text. Please set `MAX_INPUT_TOKENS` and `MAX_TOTAL_TOKENS` carefully according to your model and device memory, otherwise long inputs may exceed the LLM context limit and raise an error.
 
 **summary_type=truncate**
 
diff --git a/comps/llms/summarization/tgi/langchain/llm.py b/comps/llms/summarization/tgi/langchain/llm.py
index 60de26aa8..465e5d26d 100644
--- a/comps/llms/summarization/tgi/langchain/llm.py
+++ b/comps/llms/summarization/tgi/langchain/llm.py
@@ -42,26 +42,43 @@
 概况:"""
 
 
-templ_refine_en = """\
-Your job is to produce a final summary.
-We have provided an existing summary up to a certain point: {existing_answer}
-We have the opportunity to refine the existing summary (only if needed) with some more context below.
-------------
-{text}
-------------
-Given the new context, refine the original summary.
-If the context isn't useful, return the original summary.\
+templ_refine_en = """Your job is to produce a final summary.
+We have provided an existing summary up to a certain point, then we will provide more context.
+You need to refine the existing summary (only if needed) with new context and generate a final summary.
+
+
+Existing Summary:
+"{existing_answer}"
+
+
+
+New Context:
+"{text}"
+
+
+
+Final Summary:
+
 """
 
 templ_refine_zh = """\
 你的任务是生成一个最终摘要。
-我们已经提供了部分摘要:{existing_answer}
-如果有需要的话,可以通过以下更多上下文来完善现有摘要。
-------------
-{text}
-------------
-根据新上下文,完善原始摘要。
-如果上下文无用,则返回原始摘要。\
+我们已经处理好部分文本并生成初始摘要, 并提供了新的未处理文本
+你需要根据新提供的文本,结合初始摘要,生成一个最终摘要。
+
+
+初始摘要:
+"{existing_answer}"
+
+
+
+新的文本:
+"{text}"
+
+
+
+最终摘要:
+
 """
 
 
@@ -76,6 +93,25 @@ async def llm_generate(input: DocSumLLMParams):
     if logflag:
         logger.info(input)
 
+    ### check summary type
+    summary_types = ["auto", "stuff", "truncate", "map_reduce", "refine"]
+    if input.summary_type not in summary_types:
+        raise NotImplementedError(f"Please specify summary_type as one of {summary_types}")
+    if input.summary_type == "auto":  ### Check input token length in auto mode
+        token_len = len(tokenizer.encode(input.query))
+        if token_len > MAX_INPUT_TOKENS + 50:
+            input.summary_type = "refine"
+            if logflag:
+                logger.info(
+                    f"Input token length {token_len} exceeds MAX_INPUT_TOKENS + 50 ({MAX_INPUT_TOKENS+50}), automatically switching to 'refine' mode."
+                )
+        else:
+            input.summary_type = "stuff"
+            if logflag:
+                logger.info(
+                    f"Input token length {token_len} does not exceed MAX_INPUT_TOKENS + 50 ({MAX_INPUT_TOKENS+50}), automatically switching to 'stuff' mode."
+                )
+
     if input.language in ["en", "auto"]:
         templ = templ_en
         templ_refine = templ_refine_en
@@ -98,7 +134,7 @@ async def llm_generate(input: DocSumLLMParams):
     ## Split text
     if input.summary_type == "stuff":
         text_splitter = CharacterTextSplitter()
-    elif input.summary_type in ["truncate", "map_reduce", "refine"]:
+    else:
         if input.summary_type == "refine":
             if MAX_TOTAL_TOKENS <= 2 * input.max_tokens + 128:
                 raise RuntimeError("In Refine mode, Please set MAX_TOTAL_TOKENS larger than (max_tokens * 2 + 128)")
@@ -119,8 +155,7 @@ async def llm_generate(input: DocSumLLMParams):
         if logflag:
             logger.info(f"set chunk size to: {chunk_size}")
             logger.info(f"set chunk overlap to: {chunk_overlap}")
-    else:
-        raise NotImplementedError('Please specify the summary_type in "stuff", "truncate", "map_reduce", "refine"')
+
     texts = text_splitter.split_text(input.query)
     docs = [Document(page_content=t) for t in texts]
     if logflag:
diff --git a/comps/llms/summarization/vllm/langchain/README.md b/comps/llms/summarization/vllm/langchain/README.md
index 435b6e0e2..e0d591b69 100644
--- a/comps/llms/summarization/vllm/langchain/README.md
+++ b/comps/llms/summarization/vllm/langchain/README.md
@@ -97,7 +97,7 @@ In DocSum microservice, except for basic LLM parameters, we also support several
 
 If you want to deal with long context, can select suitable summary type, details in section 3.2.2.
 
-- "summary_type": can be "stuff", "truncate", "map_reduce", "refine", default is "stuff"
+- "summary_type": can be "auto", "stuff", "truncate", "map_reduce", "refine", default is "auto"
 - "chunk_size": max token length for each chunk. Set to be different default value according to "summary_type".
 - "chunk_overlap": overlap token length between each chunk, default is 0.1\*chunk_size
 
@@ -125,9 +125,13 @@ curl http://${your_ip}:9000/v1/chat/docsum \
 
 #### 3.2.2 Long context summarization with "summary_type"
 
-"summary_type" is set to be "stuff" by default, which will let LLM generate summary based on complete input text. In this case please carefully set `MAX_INPUT_TOKENS` and `MAX_TOTAL_TOKENS` according to your model and device memory, otherwise it may exceed LLM context limit and raise error when meet long context.
+**summary_type=auto**
 
-When deal with long context, you can set "summary_type" to one of "truncate", "map_reduce" and "refine" for better performance.
+"summary_type" is set to "auto" by default. In this mode the microservice first checks the input token length: if it exceeds `MAX_INPUT_TOKENS`, `summary_type` is automatically switched to `refine` mode, otherwise it falls back to `stuff` mode.
+
+**summary_type=stuff**
+
+In this mode the LLM generates the summary from the complete input text. Please set `MAX_INPUT_TOKENS` and `MAX_TOTAL_TOKENS` carefully according to your model and device memory, otherwise long inputs may exceed the LLM context limit and raise an error.
 
 **summary_type=truncate**
 
diff --git a/comps/llms/summarization/vllm/langchain/llm.py b/comps/llms/summarization/vllm/langchain/llm.py
index 2a74acb6e..f134a75a5 100644
--- a/comps/llms/summarization/vllm/langchain/llm.py
+++ b/comps/llms/summarization/vllm/langchain/llm.py
@@ -43,26 +43,43 @@
 概况:"""
 
 
-templ_refine_en = """\
-Your job is to produce a final summary.
-We have provided an existing summary up to a certain point: {existing_answer}
-We have the opportunity to refine the existing summary (only if needed) with some more context below.
-------------
-{text}
-------------
-Given the new context, refine the original summary.
-If the context isn't useful, return the original summary.\
+templ_refine_en = """Your job is to produce a final summary.
+We have provided an existing summary up to a certain point, then we will provide more context.
+You need to refine the existing summary (only if needed) with new context and generate a final summary.
+
+
+Existing Summary:
+"{existing_answer}"
+
+
+
+New Context:
+"{text}"
+
+
+
+Final Summary:
+
 """
 
 templ_refine_zh = """\
 你的任务是生成一个最终摘要。
-我们已经提供了部分摘要:{existing_answer}
-如果有需要的话,可以通过以下更多上下文来完善现有摘要。
-------------
-{text}
-------------
-根据新上下文,完善原始摘要。
-如果上下文无用,则返回原始摘要。\
+我们已经处理好部分文本并生成初始摘要, 并提供了新的未处理文本
+你需要根据新提供的文本,结合初始摘要,生成一个最终摘要。
+
+
+初始摘要:
+"{existing_answer}"
+
+
+
+新的文本:
+"{text}"
+
+
+
+最终摘要:
+
 """
 
 
@@ -77,6 +94,25 @@ async def llm_generate(input: DocSumLLMParams):
     if logflag:
         logger.info(input)
 
+    ### check summary type
+    summary_types = ["auto", "stuff", "truncate", "map_reduce", "refine"]
+    if input.summary_type not in summary_types:
+        raise NotImplementedError(f"Please specify summary_type as one of {summary_types}")
+    if input.summary_type == "auto":  ### Check input token length in auto mode
+        token_len = len(tokenizer.encode(input.query))
+        if token_len > MAX_INPUT_TOKENS + 50:
+            input.summary_type = "refine"
+            if logflag:
+                logger.info(
+                    f"Input token length {token_len} exceeds MAX_INPUT_TOKENS + 50 ({MAX_INPUT_TOKENS+50}), automatically switching to 'refine' mode."
+                )
+        else:
+            input.summary_type = "stuff"
+            if logflag:
+                logger.info(
+                    f"Input token length {token_len} does not exceed MAX_INPUT_TOKENS + 50 ({MAX_INPUT_TOKENS+50}), automatically switching to 'stuff' mode."
+                )
+
     if input.language in ["en", "auto"]:
         templ = templ_en
         templ_refine = templ_refine_en
@@ -99,7 +135,7 @@ async def llm_generate(input: DocSumLLMParams):
     ## Split text
     if input.summary_type == "stuff":
         text_splitter = CharacterTextSplitter()
-    elif input.summary_type in ["truncate", "map_reduce", "refine"]:
+    else:
         if input.summary_type == "refine":
             if MAX_TOTAL_TOKENS <= 2 * input.max_tokens + 128:
                 raise RuntimeError("In Refine mode, Please set MAX_TOTAL_TOKENS larger than (max_tokens * 2 + 128)")
@@ -120,8 +156,7 @@ async def llm_generate(input: DocSumLLMParams):
         if logflag:
             logger.info(f"set chunk size to: {chunk_size}")
             logger.info(f"set chunk overlap to: {chunk_overlap}")
-    else:
-        raise NotImplementedError('Please specify the summary_type in "stuff", "truncate", "map_reduce", "refine"')
+
     texts = text_splitter.split_text(input.query)
     docs = [Document(page_content=t) for t in texts]
     if logflag:
diff --git a/comps/llms/summarization/vllm/langchain/requirements.txt b/comps/llms/summarization/vllm/langchain/requirements.txt
index e074ba8c8..169461863 100644
--- a/comps/llms/summarization/vllm/langchain/requirements.txt
+++ b/comps/llms/summarization/vllm/langchain/requirements.txt
@@ -1,5 +1,6 @@
 docarray[full]
 fastapi
+httpx==0.27.2
 huggingface_hub
 langchain #==0.1.12
 langchain-huggingface
diff --git a/tests/llms/test_llms_summarization_tgi_langchain.sh b/tests/llms/test_llms_summarization_tgi_langchain.sh
index ed1fc206e..d805b7361 100644
--- a/tests/llms/test_llms_summarization_tgi_langchain.sh
+++ b/tests/llms/test_llms_summarization_tgi_langchain.sh
@@ -2,7 +2,7 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-set -x
+set -xe
 
 WORKPATH=$(dirname "$PWD")
 ip_address=$(hostname -I | awk '{print $1}')
@@ -30,7 +30,7 @@ function start_service() {
     export TGI_LLM_ENDPOINT="http://${ip_address}:${tgi_endpoint_port}"
 
     sum_port=5076
-    docker run -d --name="test-comps-llm-sum-tgi-server" -p ${sum_port}:9000 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e TGI_LLM_ENDPOINT=$TGI_LLM_ENDPOINT -e LLM_MODEL_ID=$LLM_MODEL_ID -e MAX_INPUT_TOKENS=$MAX_INPUT_TOKENS -e MAX_TOTAL_TOKENS=$MAX_TOTAL_TOKENS -e HUGGINGFACEHUB_API_TOKEN=$HF_TOKEN opea/llm-sum-tgi:comps
+    docker run -d --name="test-comps-llm-sum-tgi-server" -p ${sum_port}:9000 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e TGI_LLM_ENDPOINT=$TGI_LLM_ENDPOINT -e LLM_MODEL_ID=$LLM_MODEL_ID -e MAX_INPUT_TOKENS=$MAX_INPUT_TOKENS -e MAX_TOTAL_TOKENS=$MAX_TOTAL_TOKENS -e HUGGINGFACEHUB_API_TOKEN=$HF_TOKEN -e LOGFLAG=True opea/llm-sum-tgi:comps
 
     # check whether tgi is fully ready
     n=0
@@ -61,10 +61,12 @@ function validate_services() {
 
     local CONTENT=$(curl -s -X POST -d "$INPUT_DATA" -H 'Content-Type: application/json' "$URL" | tee ${LOG_PATH}/${SERVICE_NAME}.log)
 
+    echo "$CONTENT"
+
    if echo "$CONTENT" | grep -q "$EXPECTED_RESULT"; then
        echo "[ $SERVICE_NAME ] Content is as expected."
    else
-        echo "[ $SERVICE_NAME ] Content does not match the expected result: $CONTENT"
+        echo "[ $SERVICE_NAME ] Content does not match the expected result"
         docker logs ${DOCKER_NAME} >> ${LOG_PATH}/${SERVICE_NAME}.log
         exit 1
     fi
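
For reference, the new `auto` mode added by this patch can be exercised with a request like the sketch below. The `${your_ip}:9000` address and the `/v1/chat/docsum` route follow the curl snippets in the patched READMEs; `query` and `summary_type` come from `DocSumLLMParams`, while the sample text and the `max_tokens`/`language` values are illustrative assumptions only.

```bash
# Hedged usage sketch for the new "auto" summary_type (values are examples only).
curl http://${your_ip}:9000/v1/chat/docsum \
  -X POST \
  -H 'Content-Type: application/json' \
  -d '{"query": "Place the document text to summarize here.", "summary_type": "auto", "max_tokens": 32, "language": "en"}'
```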