From b7d9d222d0ad868076632b09fafe6b5521f6cd8c Mon Sep 17 00:00:00 2001 From: jasinliu <939282975@qq.com> Date: Fri, 23 Aug 2024 17:09:05 +0800 Subject: [PATCH 01/18] feat(llm): add reranker --- .../hugegraph_llm/api/models/rag_requests.py | 12 + .../src/hugegraph_llm/api/rag_api.py | 68 +++- .../src/hugegraph_llm/config/config.py | 13 +- .../src/hugegraph_llm/demo/rag_web_demo.py | 355 +++++++++++++++--- .../models/rerankers/__init__.py | 16 + .../hugegraph_llm/models/rerankers/cohere.py | 64 ++++ .../models/rerankers/init_reranker.py | 33 ++ .../models/rerankers/siliconflow.py | 61 +++ .../operators/common_op/merge_dedup_rerank.py | 60 +-- .../hugegraph_llm/operators/graph_rag_task.py | 52 +-- .../operators/llm_op/answer_synthesize.py | 7 +- 11 files changed, 616 insertions(+), 125 deletions(-) create mode 100644 hugegraph-llm/src/hugegraph_llm/models/rerankers/__init__.py create mode 100644 hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py create mode 100644 hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py create mode 100644 hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py index 47610f55..0b8c930c 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py +++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py @@ -51,3 +51,15 @@ class LLMConfigRequest(BaseModel): # ollama-only properties host: str = None port: str = None + +class EmbeddingConfigRequest(BaseModel): + llm_type: str + # The common parameters shared by OpenAI, Qianfan Wenxin, + # and OLLAMA platforms. + api_key: str + +class RerankerConfigRequest(BaseModel): + reranker_model: str + reranker_type: str + api_key: str + api_base: str \ No newline at end of file diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index e5836192..f7d4fc76 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -18,25 +18,43 @@ from fastapi import status, APIRouter from hugegraph_llm.api.exceptions.rag_exceptions import generate_response -from hugegraph_llm.api.models.rag_requests import RAGRequest, GraphConfigRequest, LLMConfigRequest +from hugegraph_llm.api.models.rag_requests import ( + RAGRequest, + GraphConfigRequest, + LLMConfigRequest, + RerankerConfigRequest, +) from hugegraph_llm.api.models.rag_response import RAGResponse from hugegraph_llm.config import settings -def rag_http_api(router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf): +def rag_http_api( + router: APIRouter, + rag_answer_func, + apply_graph_conf, + apply_llm_conf, + apply_embedding_conf, + apply_reranker_conf, +): @router.post("/rag", status_code=status.HTTP_200_OK) def rag_answer_api(req: RAGRequest): - result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector) + result = rag_answer_func( + req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector + ) return { key: value - for key, value in zip(["raw_llm", "vector_only", "graph_only", "graph_vector"], result) + for key, value in zip( + ["raw_llm", "vector_only", "graph_only", "graph_vector"], result + ) if getattr(req, key) } @router.post("/config/graph", status_code=status.HTTP_201_CREATED) def graph_config_api(req: GraphConfigRequest): # Accept status code - res = apply_graph_conf(req.ip, req.port, req.name, req.user, 
req.pwd, req.gs, origin_call="http")
+        res = apply_graph_conf(
+            req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http"
+        )
         return generate_response(RAGResponse(status_code=res, message="Missing Value"))
 
     @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
@@ -45,12 +63,24 @@ def llm_config_api(req: LLMConfigRequest):
 
         if req.llm_type == "openai":
             res = apply_llm_conf(
-                req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http"
+                req.api_key,
+                req.api_base,
+                req.language_model,
+                req.max_tokens,
+                origin_call="http",
             )
         elif req.llm_type == "qianfan_wenxin":
-            res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
+            res = apply_llm_conf(
+                req.api_key,
+                req.secret_key,
+                req.language_model,
+                None,
+                origin_call="http",
+            )
         else:
-            res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http")
+            res = apply_llm_conf(
+                req.host, req.port, req.language_model, None, origin_call="http"
+            )
         return generate_response(RAGResponse(status_code=res, message="Missing Value"))
 
     @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
@@ -58,9 +88,25 @@ def embedding_config_api(req: LLMConfigRequest):
         settings.embedding_type = req.llm_type
 
         if req.llm_type == "openai":
-            res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http")
+            res = apply_embedding_conf(
+                req.api_key, req.api_base, req.language_model, origin_call="http"
+            )
         elif req.llm_type == "qianfan_wenxin":
-            res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
+            res = apply_embedding_conf(
+                req.api_key, req.api_base, None, origin_call="http"
+            )
         else:
-            res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
+            res = apply_embedding_conf(
+                req.host, req.port, req.language_model, origin_call="http"
+            )
+        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
+
+    @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
+    def rerank_config_api(req: RerankerConfigRequest):
+        settings.reranker_type = req.reranker_type
+
+        if req.reranker_type == "cohere":
+            res = apply_reranker_conf(
+                req.api_key, req.reranker_model, req.api_base, origin_call="http"
+            )
+        elif req.reranker_type == "siliconflow":
+            res = apply_reranker_conf(
+                req.api_key, req.reranker_model, None, origin_call="http"
+            )
+        else:
+            res = -1  # unsupported reranker type; avoid returning an unset value
         return generate_response(RAGResponse(status_code=res, message="Missing Value"))
diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py b/hugegraph-llm/src/hugegraph_llm/config/config.py
index 2fd8262b..778e570c 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config.py
@@ -35,22 +35,27 @@ class Config:
     # env_path: Optional[str] = ".env"
     llm_type: Literal["openai", "ollama", "qianfan_wenxin", "zhipu"] = "openai"
     embedding_type: Optional[Literal["openai", "ollama", "qianfan_wenxin", "zhipu"]] = "openai"
+    reranker_type: Optional[Literal["cohere", "siliconflow"]] = "cohere"
     # 1. OpenAI settings
     openai_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
     openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
     openai_language_model: Optional[str] = "gpt-4o-mini"
     openai_embedding_model: Optional[str] = "text-embedding-3-small"
     openai_max_tokens: int = 4096
-    # 2. Ollama settings
+    # 2. Rerank settings
+    cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank")
+    reranker_api_key: Optional[str] = None
+    reranker_model: Optional[str] = "rerank-multilingual-v3.0"
+    # 3. 
Ollama settings ollama_host: Optional[str] = "127.0.0.1" ollama_port: Optional[int] = 11434 ollama_language_model: Optional[str] = None ollama_embedding_model: Optional[str] = None - # 3. QianFan/WenXin settings + # 4. QianFan/WenXin settings qianfan_api_key: Optional[str] = None qianfan_secret_key: Optional[str] = None qianfan_access_token: Optional[str] = None - # 3.1 url settings + # 4.1 url settings qianfan_url_prefix: Optional[str] = ( "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop" ) @@ -59,7 +64,7 @@ class Config: qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/" # https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu qianfan_embedding_model: Optional[str] = "embedding-v1" - # 4. ZhiPu(GLM) settings + # 5. ZhiPu(GLM) settings zhipu_api_key: Optional[str] = None zhipu_language_model: Optional[str] = "glm-4" zhipu_embedding_model: Optional[str] = "embedding-2" diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py index 74704ba2..c7d57789 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py @@ -39,7 +39,11 @@ from hugegraph_llm.operators.kg_construction_task import KgBuilder from hugegraph_llm.operators.llm_op.property_graph_extract import SCHEMA_EXAMPLE_PROMPT from hugegraph_llm.utils.hugegraph_utils import get_hg_client -from hugegraph_llm.utils.hugegraph_utils import init_hg_test_data, run_gremlin_query, clean_hg_data +from hugegraph_llm.utils.hugegraph_utils import ( + init_hg_test_data, + run_gremlin_query, + clean_hg_data, +) from hugegraph_llm.utils.log import log from hugegraph_llm.utils.vector_index_utils import clean_vector_index @@ -50,6 +54,7 @@ def authenticate(credentials: HTTPAuthorizationCredentials = Depends(sec)): correct_token = os.getenv("TOKEN") if credentials.credentials != correct_token: from fastapi import HTTPException + raise HTTPException( status_code=401, detail=f"Invalid token {credentials.credentials}, please contact the admin", @@ -58,7 +63,13 @@ def authenticate(credentials: HTTPAuthorizationCredentials = Depends(sec)): def rag_answer( - text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool, graph_vector_answer: bool + text: str, + raw_answer: bool, + vector_only_answer: bool, + graph_only_answer: bool, + graph_vector_answer: bool, + graph_ratio: float, + rerank_method: str, ) -> tuple: vector_search = vector_only_answer or graph_vector_answer graph_search = graph_only_answer or graph_vector_answer @@ -71,7 +82,7 @@ def rag_answer( searcher.query_vector_index_for_rag() if graph_search: searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag() - searcher.merge_dedup_rerank().synthesize_answer( + searcher.merge_dedup_rerank(graph_ratio, rerank_method).synthesize_answer( raw_answer=raw_answer, vector_only_answer=vector_only_answer, graph_only_answer=graph_only_answer, @@ -79,7 +90,12 @@ def rag_answer( ) try: - context = searcher.run(verbose=True, query=text) + context = searcher.run( + verbose=True, + query=text, + vector_search=vector_search, + graph_search=graph_search, + ) return ( context.get("raw_answer", ""), context.get("vector_only_answer", ""), @@ -95,10 +111,10 @@ def rag_answer( def build_kg( # pylint: disable=too-many-branches - files: Union[NamedString, List[NamedString]], - schema: str, - example_prompt: str, - build_mode: str + files: Union[NamedString, List[NamedString]], + schema: str, + example_prompt: str, + build_mode: str, ) 
-> str: if isinstance(files, NamedString): files = [files] @@ -158,15 +174,20 @@ def build_kg( # pylint: disable=too-many-branches raise gr.Error(str(e)) -def test_api_connection(url, method="GET", - headers=None, params=None, body=None, auth=None, origin_call=None) -> int: +def test_api_connection( + url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None +) -> int: # TODO: use fastapi.request / starlette instead? log.debug("Request URL: %s", url) try: if method.upper() == "GET": - resp = requests.get(url, headers=headers, params=params, timeout=5, auth=auth) + resp = requests.get( + url, headers=headers, params=params, timeout=5, auth=auth + ) elif method.upper() == "POST": - resp = requests.post(url, headers=headers, params=params, json=body, timeout=5, auth=auth) + resp = requests.post( + url, headers=headers, params=params, json=body, timeout=5, auth=auth + ) else: raise ValueError("Unsupported HTTP method, please use GET/POST instead") except requests.exceptions.RequestException as e: @@ -185,7 +206,10 @@ def test_api_connection(url, method="GET", log.error(msg) # TODO: Only the message returned by rag can be processed, and the other return values can't be processed if origin_call is None: - raise gr.Error(json.loads(resp.text).get("message", msg)) + try: + raise gr.Error(json.loads(resp.text).get("message", msg)) + except json.decoder.JSONDecodeError: + raise gr.Error(resp.text) return resp.status_code @@ -196,10 +220,14 @@ def config_qianfan_model(arg1, arg2, arg3=None, origin_call=None) -> int: params = { "grant_type": "client_credentials", "client_id": arg1, - "client_secret": arg2 + "client_secret": arg2, } - status_code = test_api_connection("https://aip.baidubce.com/oauth/2.0/token", "POST", params=params, - origin_call=origin_call) + status_code = test_api_connection( + "https://aip.baidubce.com/oauth/2.0/token", + "POST", + params=params, + origin_call=origin_call, + ) return status_code @@ -212,7 +240,9 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int: settings.openai_embedding_model = arg3 test_url = settings.openai_api_base + "/models" headers = {"Authorization": f"Bearer {arg1}"} - status_code = test_api_connection(test_url, headers=headers, origin_call=origin_call) + status_code = test_api_connection( + test_url, headers=headers, origin_call=origin_call + ) elif embedding_option == "qianfan_wenxin": status_code = config_qianfan_model(arg1, arg2, origin_call=origin_call) settings.qianfan_embedding_model = arg3 @@ -220,7 +250,40 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int: settings.ollama_host = arg1 settings.ollama_port = int(arg2) settings.ollama_embedding_model = arg3 - status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call) + status_code = test_api_connection( + f"http://{arg1}:{arg2}", origin_call=origin_call + ) + settings.update_env() + gr.Info("Configured!") + return status_code + + +def apply_reranker_config(arg1, arg2, arg3: str | None = None, origin_call=None) -> int: + status_code = -1 + reranker_option = settings.reranker_type + if reranker_option == "cohere": + settings.reranker_api_key = arg1 + settings.reranker_model = arg2 + settings.cohere_base_url = arg3 + headers = {"Authorization": f"Bearer {arg1}"} + status_code = test_api_connection( + arg3.rsplit("/", 1)[0] + "/check-api-key", + method="POST", + headers=headers, + origin_call=origin_call, + ) + elif reranker_option == "siliconflow": + settings.reranker_api_key = arg1 + 
settings.reranker_model = arg2 + headers = { + "accept": "application/json", + "authorization": f"Bearer {arg1}", + } + status_code = test_api_connection( + "https://api.siliconflow.cn/v1/user/info", + headers=headers, + origin_call=origin_call, + ) settings.update_env() gr.Info("Configured!") return status_code @@ -257,23 +320,29 @@ def apply_llm_config(arg1, arg2, arg3, arg4, origin_call=None) -> int: settings.openai_max_tokens = int(arg4) test_url = settings.openai_api_base + "/models" headers = {"Authorization": f"Bearer {arg1}"} - status_code = test_api_connection(test_url, headers=headers, origin_call=origin_call) + status_code = test_api_connection( + test_url, headers=headers, origin_call=origin_call + ) elif llm_option == "qianfan_wenxin": status_code = config_qianfan_model(arg1, arg2, arg3, origin_call) elif llm_option == "ollama": settings.ollama_host = arg1 settings.ollama_port = int(arg2) settings.ollama_language_model = arg3 - status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call) + status_code = test_api_connection( + f"http://{arg1}:{arg2}", origin_call=origin_call + ) gr.Info("Configured!") settings.update_env() return status_code def init_rag_ui() -> gr.Interface: - with gr.Blocks(theme='default', - title="HugeGraph RAG Platform", - css="footer {visibility: hidden}") as hugegraph_llm_ui: + with gr.Blocks( + theme="default", + title="HugeGraph RAG Platform", + css="footer {visibility: hidden}", + ) as hugegraph_llm_ui: gr.Markdown( """# HugeGraph LLM RAG Demo 1. Set up the HugeGraph server.""" @@ -289,10 +358,16 @@ def init_rag_ui() -> gr.Interface: ] graph_config_button = gr.Button("apply configuration") - graph_config_button.click(apply_graph_config, inputs=graph_config_input) # pylint: disable=no-member + graph_config_button.click( + apply_graph_config, inputs=graph_config_input + ) # pylint: disable=no-member gr.Markdown("2. 
Set up the LLM.") - llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type, label="LLM") + llm_dropdown = gr.Dropdown( + choices=["openai", "qianfan_wenxin", "ollama"], + value=settings.llm_type, + label="LLM", + ) @gr.render(inputs=[llm_dropdown]) def llm_settings(llm_type): @@ -300,9 +375,15 @@ def llm_settings(llm_type): if llm_type == "openai": with gr.Row(): llm_config_input = [ - gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"), + gr.Textbox( + value=settings.openai_api_key, + label="api_key", + type="password", + ), gr.Textbox(value=settings.openai_api_base, label="api_base"), - gr.Textbox(value=settings.openai_language_model, label="model_name"), + gr.Textbox( + value=settings.openai_language_model, label="model_name" + ), gr.Textbox(value=settings.openai_max_tokens, label="max_token"), ] elif llm_type == "ollama": @@ -310,15 +391,27 @@ def llm_settings(llm_type): llm_config_input = [ gr.Textbox(value=settings.ollama_host, label="host"), gr.Textbox(value=str(settings.ollama_port), label="port"), - gr.Textbox(value=settings.ollama_language_model, label="model_name"), + gr.Textbox( + value=settings.ollama_language_model, label="model_name" + ), gr.Textbox(value="", visible=False), ] elif llm_type == "qianfan_wenxin": with gr.Row(): llm_config_input = [ - gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"), - gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"), - gr.Textbox(value=settings.qianfan_language_model, label="model_name"), + gr.Textbox( + value=settings.qianfan_api_key, + label="api_key", + type="password", + ), + gr.Textbox( + value=settings.qianfan_secret_key, + label="secret_key", + type="password", + ), + gr.Textbox( + value=settings.qianfan_language_model, label="model_name" + ), gr.Textbox(value="", visible=False), ] log.debug(llm_config_input) @@ -326,11 +419,15 @@ def llm_settings(llm_type): llm_config_input = [] llm_config_button = gr.Button("apply configuration") - llm_config_button.click(apply_llm_config, inputs=llm_config_input) # pylint: disable=no-member + llm_config_button.click( + apply_llm_config, inputs=llm_config_input + ) # pylint: disable=no-member gr.Markdown("3. 
Set up the Embedding.") embedding_dropdown = gr.Dropdown( - choices=["openai", "qianfan_wenxin", "ollama"], value=settings.embedding_type, label="Embedding" + choices=["openai", "qianfan_wenxin", "ollama"], + value=settings.embedding_type, + label="Embedding", ) @gr.render(inputs=[embedding_dropdown]) @@ -339,23 +436,41 @@ def embedding_settings(embedding_type): if embedding_type == "openai": with gr.Row(): embedding_config_input = [ - gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"), + gr.Textbox( + value=settings.openai_api_key, + label="api_key", + type="password", + ), gr.Textbox(value=settings.openai_api_base, label="api_base"), - gr.Textbox(value=settings.openai_embedding_model, label="model_name"), + gr.Textbox( + value=settings.openai_embedding_model, label="model_name" + ), ] elif embedding_type == "qianfan_wenxin": with gr.Row(): embedding_config_input = [ - gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"), - gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"), - gr.Textbox(value=settings.qianfan_embedding_model, label="model_name"), + gr.Textbox( + value=settings.qianfan_api_key, + label="api_key", + type="password", + ), + gr.Textbox( + value=settings.qianfan_secret_key, + label="secret_key", + type="password", + ), + gr.Textbox( + value=settings.qianfan_embedding_model, label="model_name" + ), ] elif embedding_type == "ollama": with gr.Row(): embedding_config_input = [ gr.Textbox(value=settings.ollama_host, label="host"), gr.Textbox(value=str(settings.ollama_port), label="port"), - gr.Textbox(value=settings.ollama_embedding_model, label="model_name"), + gr.Textbox( + value=settings.ollama_embedding_model, label="model_name" + ), ] else: embedding_config_input = [] @@ -364,7 +479,54 @@ def embedding_settings(embedding_type): # Call the separate apply_embedding_configuration function here embedding_config_button.click( # pylint: disable=no-member - apply_embedding_config, inputs=embedding_config_input # pylint: disable=no-member + apply_embedding_config, + inputs=embedding_config_input, # pylint: disable=no-member + ) + + gr.Markdown("4. 
Set up the Reranker(Optional).") + reranker_dropdown = gr.Dropdown( + choices=["cohere", "siliconflow"], + value=settings.reranker_type, + label="Reranker", + ) + + @gr.render(inputs=[reranker_dropdown]) + def reranker_settings(reranker_type): + settings.reranker_type = reranker_type + if reranker_type == "cohere": + with gr.Row(): + reranker_config_input = [ + gr.Textbox( + value=settings.reranker_api_key, + label="api_key", + type="password", + ), + gr.Textbox(value=settings.reranker_model, label="model"), + gr.Textbox(value=settings.cohere_base_url, label="base_url"), + ] + elif reranker_type == "siliconflow": + with gr.Row(): + reranker_config_input = [ + gr.Textbox( + value=settings.reranker_api_key, + label="api_key", + type="password", + ), + gr.Textbox( + value="BAAI/bge-reranker-v2-m3", + label="model", + info="Please refer to https://siliconflow.cn/pricing", + ), + ] + else: + reranker_config_input = [] + + reranker_config_button = gr.Button("apply configuration") + + # Call the separate apply_reranker_configuration function here + reranker_config_button.click( # pylint: disable=no-member + apply_reranker_config, + inputs=reranker_config_input, # pylint: disable=no-member ) gr.Markdown( @@ -422,12 +584,20 @@ def embedding_settings(embedding_type): input_file = gr.File( value=[os.path.join(resource_path, "demo", "test.txt")], label="Doc(s) (multi-files can be selected together)", - file_count="multiple") + file_count="multiple", + ) input_schema = gr.Textbox(value=schema, label="Schema") - info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head") + info_extract_template = gr.Textbox( + value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head" + ) with gr.Column(): mode = gr.Radio( - choices=["Test Mode", "Import Mode", "Clear and Import", "Rebuild Vector"], + choices=[ + "Test Mode", + "Import Mode", + "Clear and Import", + "Rebuild Vector", + ], value="Test Mode", label="Build mode", ) @@ -435,22 +605,71 @@ def embedding_settings(embedding_type): with gr.Row(): out = gr.Textbox(label="Output", show_copy_button=True) btn.click( # pylint: disable=no-member - fn=build_kg, inputs=[input_file, input_schema, info_extract_template, mode], outputs=out + fn=build_kg, + inputs=[input_file, input_schema, info_extract_template, mode], + outputs=out, ) gr.Markdown("""## 2. 
RAG with HugeGraph 📖""") with gr.Row(): with gr.Column(scale=2): - inp = gr.Textbox(value="Tell me about Sarah.", label="Question", show_copy_button=True) + inp = gr.Textbox( + value="Tell me about Sarah.", + label="Question", + show_copy_button=True, + ) raw_out = gr.Textbox(label="Basic LLM Answer", show_copy_button=True) - vector_only_out = gr.Textbox(label="Vector-only Answer", show_copy_button=True) - graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True) - graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True) + vector_only_out = gr.Textbox( + label="Vector-only Answer", show_copy_button=True + ) + graph_only_out = gr.Textbox( + label="Graph-only Answer", show_copy_button=True + ) + graph_vector_out = gr.Textbox( + label="Graph-Vector Answer", show_copy_button=True + ) with gr.Column(scale=1): - raw_radio = gr.Radio(choices=[True, False], value=True, label="Basic LLM Answer") - vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer") - graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer") - graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer") + raw_radio = gr.Radio( + choices=[True, False], value=True, label="Basic LLM Answer" + ) + vector_only_radio = gr.Radio( + choices=[True, False], value=False, label="Vector-only Answer" + ) + graph_only_radio = gr.Radio( + choices=[True, False], value=False, label="Graph-only Answer" + ) + with gr.Row(): + + def toggle_slider(enable): + return gr.update(interactive=enable) + + graph_vector_radio = gr.Radio( + choices=[True, False], value=False, label="Graph-Vector Answer" + ) + graph_ratio = gr.Slider( + 0, + 1, + 0.5, + label="Graph Ratio", + step=0.1, + interactive=False, + ) + graph_vector_radio.change( + toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio + ) + with gr.Column(): + rerank_method = gr.Dropdown( + choices=["bleu", "reranker"], + value="bleu", + label="Rerank method", + ) + graph_strategy = gr.Checkbox( + value=False, label="Near neighbor first(Optional)", info="One-depth neighbors > two-depth neighbors" + ) + custom_related_information = gr.Text( + "", + label="Custom related information(Optional)", + ) btn = gr.Button("Answer Question") btn.click( # pylint: disable=no-member fn=rag_answer, @@ -460,6 +679,8 @@ def embedding_settings(embedding_type): vector_only_radio, graph_only_radio, graph_vector_radio, + graph_ratio, + rerank_method, ], outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out], ) @@ -467,17 +688,25 @@ def embedding_settings(embedding_type): gr.Markdown("""## 3. 
Others (🚧) """) with gr.Row(): with gr.Column(): - inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query", show_copy_button=True) + inp = gr.Textbox( + value="g.V().limit(10)", + label="Gremlin query", + show_copy_button=True, + ) fmt = gr.Checkbox(label="Format JSON", value=True) out = gr.Textbox(label="Output", show_copy_button=True) btn = gr.Button("Run gremlin query on HugeGraph") - btn.click(fn=run_gremlin_query, inputs=[inp, fmt], outputs=out) # pylint: disable=no-member + btn.click( + fn=run_gremlin_query, inputs=[inp, fmt], outputs=out + ) # pylint: disable=no-member with gr.Row(): inp = [] out = gr.Textbox(label="Output", show_copy_button=True) btn = gr.Button("(BETA) Init HugeGraph test data (🚧WIP)") - btn.click(fn=init_hg_test_data, inputs=inp, outputs=out) # pylint: disable=no-member + btn.click( + fn=init_hg_test_data, inputs=inp, outputs=out + ) # pylint: disable=no-member return hugegraph_llm_ui @@ -490,13 +719,25 @@ def embedding_settings(embedding_type): app_auth = APIRouter(dependencies=[Depends(authenticate)]) hugegraph_llm = init_rag_ui() - rag_http_api(app_auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config) + rag_http_api( + app_auth, + rag_answer, + apply_graph_config, + apply_llm_config, + apply_embedding_config, + apply_reranker_config, + ) app.include_router(app_auth) auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true" log.info("Authentication is %s.", "enabled" if auth_enabled else "disabled") # TODO: support multi-user login when need - app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag", os.getenv("TOKEN")) if auth_enabled else None) + app = gr.mount_gradio_app( + app, + hugegraph_llm, + path="/", + auth=("rag", os.getenv("TOKEN")) if auth_enabled else None, + ) # Note: set reload to False in production environment uvicorn.run(app, host=args.host, port=args.port) diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/__init__.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/__init__.py new file mode 100644 index 00000000..13a83393 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py new file mode 100644 index 00000000..001fddfd --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import requests +from typing import Optional +from hugegraph_llm.config import settings + + +class CohereReranker: + def __init__( + self, + api_key: str = settings.reranker_api_key, + base_url: str = settings.cohere_base_url, + model: str = settings.reranker_model, + ): + self.api_key = api_key + self.base_url = base_url + self.model = model + + def get_rerank_lists( + self, + query: str, + documents: list[str], + top_n: Optional[int] = None + ) -> list: + if not top_n: + top_n = len(documents) + assert top_n <= len( + documents + ), "'top_n' should be less than or equal to the number of documents" + + + url = self.base_url + headers = { + "accept": "application/json", + "content-type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + payload = { + "model": self.model, + "query": query, + "top_n": top_n, + "documents": documents + } + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() # Raise an error for bad status codes + results = response.json()["results"] + sorted_docs = [documents[item["index"]] for item in results] + return sorted_docs diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py new file mode 100644 index 00000000..d9397ed8 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
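# A minimal usage sketch for the CohereReranker defined in cohere.py above
# (the API key, documents, and query are illustrative placeholders; the
# base_url and model values are the defaults added to config.py):
#
#     from hugegraph_llm.models.rerankers.cohere import CohereReranker
#
#     reranker = CohereReranker(
#         api_key="<your-cohere-api-key>",              # placeholder
#         base_url="https://api.cohere.com/v1/rerank",
#         model="rerank-multilingual-v3.0",
#     )
#     docs = ["Sarah is a lawyer.", "Tom plays football."]
#     ranked = reranker.get_rerank_lists("Tell me about Sarah.", docs, top_n=1)
#     # e.g. ["Sarah is a lawyer."] -- documents sorted by relevance, top_n kept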
+
+from hugegraph_llm.models.rerankers.cohere import CohereReranker
+from hugegraph_llm.models.rerankers.siliconflow import SiliconReranker
+from hugegraph_llm.config import settings
+
+class Rerankers:
+    def __init__(self):
+        self.reranker_type = settings.reranker_type
+
+    def get_reranker(self):
+        if self.reranker_type == "cohere":
+            return CohereReranker()
+
+        if self.reranker_type == "siliconflow":
+            return SiliconReranker()
+
+        raise ValueError(f"Reranker type '{self.reranker_type}' is not supported!")
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
new file mode 100644
index 00000000..9fff28a2
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import requests
+from typing import Optional
+from hugegraph_llm.config import settings
+
+
+class SiliconReranker:
+    def __init__(
+        self,
+        api_key: str = settings.reranker_api_key,
+        model: str = settings.reranker_model,
+    ):
+        self.api_key = api_key
+        self.model = model
+
+    def get_rerank_lists(
+        self, query: str, documents: list[str], top_n: Optional[int] = None
+    ) -> list:
+        if not top_n:
+            top_n = len(documents)
+        assert top_n <= len(
+            documents
+        ), "'top_n' should be less than or equal to the number of documents"
+
+        url = "https://api.siliconflow.cn/v1/rerank"
+        payload = {
+            "model": self.model,
+            "query": query,
+            "documents": documents,
+            "return_documents": False,
+            "max_chunks_per_doc": 1024,
+            "overlap_tokens": 80,
+            "top_n": top_n,
+        }
+        headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "authorization": f"Bearer {self.api_key}",
+        }
+        response = requests.post(url, json=payload, headers=headers)
+        response.raise_for_status()  # Raise an error for bad status codes
+        results = response.json()["results"]
+        sorted_docs = [documents[item["index"]] for item in results]
+        return sorted_docs
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index e0124797..0530b066 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -16,14 +16,15 @@
 #  under the License.
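# A worked sketch of the BLEU scoring this module applies below (standalone;
# the query and content strings are illustrative):
#
#     import jieba
#     from nltk.translate.bleu_score import sentence_bleu
#
#     query_tokens = jieba.lcut("Tell me about Sarah")
#     content_tokens = jieba.lcut("Sarah is a lawyer")
#     score = sentence_bleu([query_tokens], content_tokens)
#     # a higher score means the candidate text overlaps the query more closely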
-from typing import Dict, Any, List, Literal +from typing import Literal, Dict, Any, List, Optional import jieba from hugegraph_llm.models.embeddings.base import BaseEmbedding +from hugegraph_llm.models.rerankers.init_reranker import Rerankers from nltk.translate.bleu_score import sentence_bleu -def get_score(query: str, content: str) -> float: +def get_blue_score(query: str, content: str) -> float: query_tokens = jieba.lcut(query) content_tokens = jieba.lcut(content) return sentence_bleu([query_tokens], content_tokens) @@ -31,44 +32,51 @@ def get_score(query: str, content: str) -> float: class MergeDedupRerank: def __init__( - self, - embedding: BaseEmbedding, - topk: int = 10, - policy: Literal["bleu", "priority"] = "bleu" + self, + embedding: BaseEmbedding, + topk: int = 20, + graph_ratio: float = 0.5, + method: Literal["bleu", "reranker"] = "bleu", + prior_vertex: Optional[str] = None, ): self.embedding = embedding + self.graph_ratio = graph_ratio self.topk = topk - if policy == "bleu": - self.rerank_func = self._bleu_rerank - elif policy == "priority": - self.rerank_func = self._priority_rerank - else: - raise ValueError(f"Unimplemented policy {policy}.") + self.method = method + self.priority_vertex = prior_vertex def run(self, context: Dict[str, Any]) -> Dict[str, Any]: query = context.get("query") + context["graph_ratio"] = self.graph_ratio + vector_search = context.get("vector_search", False) + graph_search = context.get("graph_search", False) + if graph_search and vector_search: + graph_length = int(self.topk * self.graph_ratio) + vector_length = self.topk - graph_length + else: + graph_length = self.topk + vector_length = self.topk + print(f"graph length {graph_length}") vector_result = context.get("vector_result", []) - vector_result = self.rerank_func(query, vector_result)[:self.topk] + vector_length = min(len(vector_result), vector_length) + vector_result = self._dedup_and_rerank(query, vector_result, vector_length) graph_result = context.get("graph_result", []) - graph_result = self.rerank_func(query, graph_result)[:self.topk] + graph_length = min(len(graph_result), graph_length) + graph_result = self._dedup_and_rerank(query, graph_result, graph_length) context["vector_result"] = vector_result context["graph_result"] = graph_result return context - def _bleu_rerank(self, query: str, results: List[str]): + def _dedup_and_rerank(self, query: str, results: List[str], topn: int) -> List[str]: results = list(set(results)) - result_score_list = [[res, get_score(query, res)] for res in results] - result_score_list.sort(key=lambda x: x[1], reverse=True) - return [res[0] for res in result_score_list] - - def _priority_rerank(self, query: str, results: List[str]): - # TODO: implement - # 1. Precise recall > Fuzzy recall - # 2. 1-degree neighbors > 2-degree neighbors - # 3. 
The priority of a certain type of point is higher than others, - # such as Law being higher than vehicles/people/locations - raise NotImplementedError() + if self.method == "bleu": + result_score_list = [[res, get_blue_score(query, res)] for res in results] + result_score_list.sort(key=lambda x: x[1], reverse=True) + return [res[0] for res in result_score_list][:topn] + if self.method == "reranker": + reranker = Rerankers().get_reranker() + return reranker.get_rerank_lists(query, results, topn) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py index cacac618..0fdb3cb5 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py @@ -35,7 +35,9 @@ class GraphRAG: - def __init__(self, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedding] = None): + def __init__( + self, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedding] = None + ): self._llm = llm or LLMs().get_llm() self._embedding = embedding or Embeddings().get_embedding() self._operators: List[Any] = [] @@ -73,26 +75,26 @@ def extract_keyword( return self def match_keyword_to_id( - self, - by: Literal["query", "keywords"] = "keywords", - topk_per_keyword: int = 1, - topk_per_query: int = 10 + self, + by: Literal["query", "keywords"] = "keywords", + topk_per_keyword: int = 1, + topk_per_query: int = 10, ): self._operators.append( SemanticIdQuery( embedding=self._embedding, by=by, topk_per_keyword=topk_per_keyword, - topk_per_query=topk_per_query + topk_per_query=topk_per_query, ) ) return self def query_graph_for_rag( - self, - max_deep: int = 2, - max_items: int = 30, - prop_to_match: Optional[str] = None, + self, + max_deep: int = 2, + max_items: int = 30, + prop_to_match: Optional[str] = None, ): self._operators.append( GraphRAGQuery( @@ -103,10 +105,7 @@ def query_graph_for_rag( ) return self - def query_vector_index_for_rag( - self, - max_items: int = 3 - ): + def query_vector_index_for_rag(self, max_items: int = 3): self._operators.append( VectorIndexQuery( embedding=self._embedding, @@ -115,11 +114,13 @@ def query_vector_index_for_rag( ) return self - def merge_dedup_rerank(self): + def merge_dedup_rerank( + self, + graph_ratio: float = 0.5, + rerank_method: Literal["bleu", "reranker"] = "bleu", + ): self._operators.append( - MergeDedupRerank( - embedding=self._embedding, - ) + MergeDedupRerank(embedding=self._embedding, graph_ratio=graph_ratio, method=rerank_method) ) return self @@ -133,10 +134,10 @@ def synthesize_answer( ): self._operators.append( AnswerSynthesize( - raw_answer = raw_answer, - vector_only_answer = vector_only_answer, - graph_only_answer = graph_only_answer, - graph_vector_answer = graph_vector_answer, + raw_answer=raw_answer, + vector_only_answer=vector_only_answer, + graph_only_answer=graph_only_answer, + graph_vector_answer=graph_vector_answer, prompt_template=prompt_template, ) ) @@ -156,7 +157,10 @@ def run(self, **kwargs) -> Dict[str, Any]: log.debug("Running operator: %s", operator.__class__.__name__) start = time.time() context = operator.run(context) - log.debug("Operator %s finished in %s seconds", operator.__class__.__name__, - time.time() - start) + log.debug( + "Operator %s finished in %s seconds", + operator.__class__.__name__, + time.time() - start, + ) log.debug("Context:\n%s", context) return context diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index f3803c7e..ae21fa07 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -29,9 +29,8 @@ "{context_str}\n" "---------------------\n" "You need to refer to the context based on the following priority:\n" - "1. Graph recall > vector recall\n" - "2. Exact recall > Fuzzy recall\n" - "3. Independent vertex > 1-depth neighbor> 2-depth neighbors\n" + "1. Exact recall > Fuzzy recall\n" + "2. Independent vertex > 1-depth neighbor> 2-depth neighbors\n" "Given the context information and not prior knowledge, answer the query.\n" "Query: {query_str}\n" "Answer: " @@ -138,6 +137,8 @@ async def async_generate(self, context: Dict[str, Any], context_head_str: str, task_cache["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt)) if self._graph_vector_answer: context_body_str = f"{vector_result_context}\n{graph_result_context}" + if context.get("graph_ratio", 0.5) < 0.5: + context_body_str = f"{graph_result_context}\n{vector_result_context}" context_str = (f"{context_head_str}\n" f"{context_body_str}\n" f"{context_tail_str}".strip("\n")) From e692cf4e2cefa72661b303e2253fe9984c2fc2bf Mon Sep 17 00:00:00 2001 From: jasinliu <939282975@qq.com> Date: Sat, 24 Aug 2024 12:49:59 +0800 Subject: [PATCH 02/18] fix graph context_head --- .../hugegraph_llm/operators/hugegraph_op/graph_rag_query.py | 2 +- .../src/hugegraph_llm/operators/llm_op/answer_synthesize.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index 5f18d3cc..9dc70b38 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -151,7 +151,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: knowledge.update(self._format_knowledge_from_query_result(query_result=result)) context["graph_result"] = list(knowledge) - context["synthesize_context_head"] = ( + context["graph_context_head"] = ( f"The following are knowledge sequence in max depth {self._max_deep} " f"in the form of directed graph like:\n" "`subject -[predicate]-> object <-[predicate_next_hop]- object_next_hop ...` " diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index ae21fa07..05f81a45 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -99,8 +99,10 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: graph_result_context = "There are no knowledge from HugeGraph related to the query." else: graph_result_context = ( - "The following are knowledge from HugeGraph related to the query:\n" - + "\n".join([f"{i + 1}. {res}" + context.get( + "graph_context_head", + "The following are knowledge from HugeGraph related to the query:\n" + ) + "\n".join([f"{i + 1}. 
{res}" for i, res in enumerate(graph_result)])) context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str, vector_result_context, graph_result_context)) From 204bea206d1a88fce5fcba3b4fe4e12ac7824ec7 Mon Sep 17 00:00:00 2001 From: jasinliu <939282975@qq.com> Date: Sat, 24 Aug 2024 12:57:03 +0800 Subject: [PATCH 03/18] fix ci --- hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py | 8 ++------ .../operators/hugegraph_op/graph_rag_query.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py index c7d57789..3897d0ee 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py @@ -181,13 +181,9 @@ def test_api_connection( log.debug("Request URL: %s", url) try: if method.upper() == "GET": - resp = requests.get( - url, headers=headers, params=params, timeout=5, auth=auth - ) + resp = requests.get(url, headers=headers, params=params, timeout=5, auth=auth) elif method.upper() == "POST": - resp = requests.post( - url, headers=headers, params=params, json=body, timeout=5, auth=auth - ) + resp = requests.post(url, headers=headers, params=params, json=body, timeout=5, auth=auth) else: raise ValueError("Unsupported HTTP method, please use GET/POST instead") except requests.exceptions.RequestException as e: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index 9dc70b38..750ade69 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -155,7 +155,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: f"The following are knowledge sequence in max depth {self._max_deep} " f"in the form of directed graph like:\n" "`subject -[predicate]-> object <-[predicate_next_hop]- object_next_hop ...` " - "extracted based on key entities as subject:" + "extracted based on key entities as subject:\n" ) # TODO: replace print to log From 42bc9d65b2521cef279902b1e9dd52e25f1c5882 Mon Sep 17 00:00:00 2001 From: jasinliu <939282975@qq.com> Date: Sun, 25 Aug 2024 11:32:05 +0800 Subject: [PATCH 04/18] Add near neighbor first and custom related information --- .../src/hugegraph_llm/demo/rag_web_demo.py | 123 ++++++------------ .../hugegraph_llm/models/rerankers/cohere.py | 25 ++-- .../models/rerankers/init_reranker.py | 9 +- .../models/rerankers/siliconflow.py | 18 +-- .../operators/common_op/merge_dedup_rerank.py | 66 +++++++++- .../hugegraph_llm/operators/graph_rag_task.py | 14 +- .../operators/hugegraph_op/graph_rag_query.py | 63 +++++---- .../operators/llm_op/answer_synthesize.py | 1 + 8 files changed, 167 insertions(+), 152 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py index 3897d0ee..f3ad4a4c 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py @@ -39,11 +39,7 @@ from hugegraph_llm.operators.kg_construction_task import KgBuilder from hugegraph_llm.operators.llm_op.property_graph_extract import SCHEMA_EXAMPLE_PROMPT from hugegraph_llm.utils.hugegraph_utils import get_hg_client -from hugegraph_llm.utils.hugegraph_utils import ( - init_hg_test_data, - run_gremlin_query, - clean_hg_data, -) +from 
hugegraph_llm.utils.hugegraph_utils import init_hg_test_data, run_gremlin_query, clean_hg_data from hugegraph_llm.utils.log import log from hugegraph_llm.utils.vector_index_utils import clean_vector_index @@ -70,6 +66,8 @@ def rag_answer( graph_vector_answer: bool, graph_ratio: float, rerank_method: str, + near_neighbor_first: bool, + custom_related_information: str, ) -> tuple: vector_search = vector_only_answer or graph_vector_answer graph_search = graph_only_answer or graph_vector_answer @@ -82,7 +80,9 @@ def rag_answer( searcher.query_vector_index_for_rag() if graph_search: searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag() - searcher.merge_dedup_rerank(graph_ratio, rerank_method).synthesize_answer( + searcher.merge_dedup_rerank( + graph_ratio, rerank_method, near_neighbor_first, custom_related_information + ).synthesize_answer( raw_answer=raw_answer, vector_only_answer=vector_only_answer, graph_only_answer=graph_only_answer, @@ -90,12 +90,7 @@ def rag_answer( ) try: - context = searcher.run( - verbose=True, - query=text, - vector_search=vector_search, - graph_search=graph_search, - ) + context = searcher.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search) return ( context.get("raw_answer", ""), context.get("vector_only_answer", ""), @@ -174,9 +169,7 @@ def build_kg( # pylint: disable=too-many-branches raise gr.Error(str(e)) -def test_api_connection( - url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None -) -> int: +def test_api_connection(url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None) -> int: # TODO: use fastapi.request / starlette instead? log.debug("Request URL: %s", url) try: @@ -236,9 +229,7 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int: settings.openai_embedding_model = arg3 test_url = settings.openai_api_base + "/models" headers = {"Authorization": f"Bearer {arg1}"} - status_code = test_api_connection( - test_url, headers=headers, origin_call=origin_call - ) + status_code = test_api_connection(test_url, headers=headers, origin_call=origin_call) elif embedding_option == "qianfan_wenxin": status_code = config_qianfan_model(arg1, arg2, origin_call=origin_call) settings.qianfan_embedding_model = arg3 @@ -246,9 +237,7 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int: settings.ollama_host = arg1 settings.ollama_port = int(arg2) settings.ollama_embedding_model = arg3 - status_code = test_api_connection( - f"http://{arg1}:{arg2}", origin_call=origin_call - ) + status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call) settings.update_env() gr.Info("Configured!") return status_code @@ -316,18 +305,14 @@ def apply_llm_config(arg1, arg2, arg3, arg4, origin_call=None) -> int: settings.openai_max_tokens = int(arg4) test_url = settings.openai_api_base + "/models" headers = {"Authorization": f"Bearer {arg1}"} - status_code = test_api_connection( - test_url, headers=headers, origin_call=origin_call - ) + status_code = test_api_connection(test_url, headers=headers, origin_call=origin_call) elif llm_option == "qianfan_wenxin": status_code = config_qianfan_model(arg1, arg2, arg3, origin_call) elif llm_option == "ollama": settings.ollama_host = arg1 settings.ollama_port = int(arg2) settings.ollama_language_model = arg3 - status_code = test_api_connection( - f"http://{arg1}:{arg2}", origin_call=origin_call - ) + status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call) 
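# A quick sketch of what the probe above returns (illustrative; uses the
# Ollama defaults from config.py):
#
#     status_code = test_api_connection("http://127.0.0.1:11434")
#     # a 2xx code means the endpoint is reachable; request failures are
#     # logged via log.error and surfaced to the UI through gr.Error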
gr.Info("Configured!") settings.update_env() return status_code @@ -354,9 +339,7 @@ def init_rag_ui() -> gr.Interface: ] graph_config_button = gr.Button("apply configuration") - graph_config_button.click( - apply_graph_config, inputs=graph_config_input - ) # pylint: disable=no-member + graph_config_button.click(apply_graph_config, inputs=graph_config_input) # pylint: disable=no-member gr.Markdown("2. Set up the LLM.") llm_dropdown = gr.Dropdown( @@ -377,9 +360,7 @@ def llm_settings(llm_type): type="password", ), gr.Textbox(value=settings.openai_api_base, label="api_base"), - gr.Textbox( - value=settings.openai_language_model, label="model_name" - ), + gr.Textbox(value=settings.openai_language_model, label="model_name"), gr.Textbox(value=settings.openai_max_tokens, label="max_token"), ] elif llm_type == "ollama": @@ -387,9 +368,7 @@ def llm_settings(llm_type): llm_config_input = [ gr.Textbox(value=settings.ollama_host, label="host"), gr.Textbox(value=str(settings.ollama_port), label="port"), - gr.Textbox( - value=settings.ollama_language_model, label="model_name" - ), + gr.Textbox(value=settings.ollama_language_model, label="model_name"), gr.Textbox(value="", visible=False), ] elif llm_type == "qianfan_wenxin": @@ -405,9 +384,7 @@ def llm_settings(llm_type): label="secret_key", type="password", ), - gr.Textbox( - value=settings.qianfan_language_model, label="model_name" - ), + gr.Textbox(value=settings.qianfan_language_model, label="model_name"), gr.Textbox(value="", visible=False), ] log.debug(llm_config_input) @@ -415,9 +392,7 @@ def llm_settings(llm_type): llm_config_input = [] llm_config_button = gr.Button("apply configuration") - llm_config_button.click( - apply_llm_config, inputs=llm_config_input - ) # pylint: disable=no-member + llm_config_button.click(apply_llm_config, inputs=llm_config_input) # pylint: disable=no-member gr.Markdown("3. 
Set up the Embedding.") embedding_dropdown = gr.Dropdown( @@ -438,9 +413,7 @@ def embedding_settings(embedding_type): type="password", ), gr.Textbox(value=settings.openai_api_base, label="api_base"), - gr.Textbox( - value=settings.openai_embedding_model, label="model_name" - ), + gr.Textbox(value=settings.openai_embedding_model, label="model_name"), ] elif embedding_type == "qianfan_wenxin": with gr.Row(): @@ -455,18 +428,14 @@ def embedding_settings(embedding_type): label="secret_key", type="password", ), - gr.Textbox( - value=settings.qianfan_embedding_model, label="model_name" - ), + gr.Textbox(value=settings.qianfan_embedding_model, label="model_name"), ] elif embedding_type == "ollama": with gr.Row(): embedding_config_input = [ gr.Textbox(value=settings.ollama_host, label="host"), gr.Textbox(value=str(settings.ollama_port), label="port"), - gr.Textbox( - value=settings.ollama_embedding_model, label="model_name" - ), + gr.Textbox(value=settings.ollama_embedding_model, label="model_name"), ] else: embedding_config_input = [] @@ -583,9 +552,7 @@ def reranker_settings(reranker_type): file_count="multiple", ) input_schema = gr.Textbox(value=schema, label="Schema") - info_extract_template = gr.Textbox( - value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head" - ) + info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head") with gr.Column(): mode = gr.Radio( choices=[ @@ -615,33 +582,19 @@ def reranker_settings(reranker_type): show_copy_button=True, ) raw_out = gr.Textbox(label="Basic LLM Answer", show_copy_button=True) - vector_only_out = gr.Textbox( - label="Vector-only Answer", show_copy_button=True - ) - graph_only_out = gr.Textbox( - label="Graph-only Answer", show_copy_button=True - ) - graph_vector_out = gr.Textbox( - label="Graph-Vector Answer", show_copy_button=True - ) + vector_only_out = gr.Textbox(label="Vector-only Answer", show_copy_button=True) + graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True) + graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True) with gr.Column(scale=1): - raw_radio = gr.Radio( - choices=[True, False], value=True, label="Basic LLM Answer" - ) - vector_only_radio = gr.Radio( - choices=[True, False], value=False, label="Vector-only Answer" - ) - graph_only_radio = gr.Radio( - choices=[True, False], value=False, label="Graph-only Answer" - ) + raw_radio = gr.Radio(choices=[True, False], value=True, label="Basic LLM Answer") + vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer") + graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer") with gr.Row(): def toggle_slider(enable): return gr.update(interactive=enable) - graph_vector_radio = gr.Radio( - choices=[True, False], value=False, label="Graph-Vector Answer" - ) + graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer") graph_ratio = gr.Slider( 0, 1, @@ -650,17 +603,17 @@ def toggle_slider(enable): step=0.1, interactive=False, ) - graph_vector_radio.change( - toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio - ) + graph_vector_radio.change(toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio) with gr.Column(): rerank_method = gr.Dropdown( choices=["bleu", "reranker"], value="bleu", label="Rerank method", ) - graph_strategy = gr.Checkbox( - value=False, label="Near neighbor first(Optional)", info="One-depth neighbors > two-depth neighbors" + near_neighbor_first = gr.Checkbox( + value=False, + 
label="Near neighbor first(Optional)", + info="One-depth neighbors > two-depth neighbors", ) custom_related_information = gr.Text( "", @@ -677,6 +630,8 @@ def toggle_slider(enable): graph_vector_radio, graph_ratio, rerank_method, + near_neighbor_first, + custom_related_information, ], outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out], ) @@ -692,17 +647,13 @@ def toggle_slider(enable): fmt = gr.Checkbox(label="Format JSON", value=True) out = gr.Textbox(label="Output", show_copy_button=True) btn = gr.Button("Run gremlin query on HugeGraph") - btn.click( - fn=run_gremlin_query, inputs=[inp, fmt], outputs=out - ) # pylint: disable=no-member + btn.click(fn=run_gremlin_query, inputs=[inp, fmt], outputs=out) # pylint: disable=no-member with gr.Row(): inp = [] out = gr.Textbox(label="Output", show_copy_button=True) btn = gr.Button("(BETA) Init HugeGraph test data (🚧WIP)") - btn.click( - fn=init_hg_test_data, inputs=inp, outputs=out - ) # pylint: disable=no-member + btn.click(fn=init_hg_test_data, inputs=inp, outputs=out) # pylint: disable=no-member return hugegraph_llm_ui diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py index 001fddfd..d41f5997 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py @@ -15,47 +15,38 @@ # specific language governing permissions and limitations # under the License. +from typing import Optional, List import requests -from typing import Optional -from hugegraph_llm.config import settings class CohereReranker: def __init__( self, - api_key: str = settings.reranker_api_key, - base_url: str = settings.cohere_base_url, - model: str = settings.reranker_model, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: Optional[str] = None, ): self.api_key = api_key self.base_url = base_url self.model = model - def get_rerank_lists( - self, - query: str, - documents: list[str], - top_n: Optional[int] = None - ) -> list: + def get_rerank_lists(self, query: str, documents: List[str], top_n: Optional[int] = None) -> List[str]: if not top_n: top_n = len(documents) - assert top_n <= len( - documents - ), "'top_n' should be less than or equal to the number of documents" - + assert top_n <= len(documents), "'top_n' should be less than or equal to the number of documents" url = self.base_url headers = { "accept": "application/json", "content-type": "application/json", - "Authorization": f"Bearer {self.api_key}" + "Authorization": f"Bearer {self.api_key}", } payload = { "model": self.model, "query": query, "top_n": top_n, - "documents": documents + "documents": documents, } response = requests.post(url, headers=headers, json=payload) response.raise_for_status() # Raise an error for bad status codes diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py index d9397ed8..4b119b60 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py @@ -19,15 +19,18 @@ from hugegraph_llm.models.rerankers.siliconflow import SiliconReranker from hugegraph_llm.config import settings + class Rerankers: def __init__(self): self.reranker_type = settings.reranker_type def get_reranker(self): if self.reranker_type == "cohere": - return CohereReranker() + return CohereReranker( + api_key=settings.reranker_api_key, 
base_url=settings.cohere_base_url, model=settings.reranker_model
+            )
 
         if self.reranker_type == "siliconflow":
-            return SiliconReranker()
-
+            return SiliconReranker(api_key=settings.reranker_api_key, model=settings.reranker_model)
+        raise Exception("Reranker type is not supported!")
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
index 9fff28a2..ea96244a 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
@@ -15,30 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 
-
+from typing import Optional, List
 import requests
-from typing import Optional
-from hugegraph_llm.config import settings
 
 
 class SiliconReranker:
     def __init__(
         self,
-        api_key: str = settings.reranker_api_key,
-        model: str = settings.reranker_model,
+        api_key: Optional[str] = None,
+        model: Optional[str] = None,
    ):
         self.api_key = api_key
         self.model = model
 
-    def get_rerank_lists(
-        self, query: str, documents: list[str], top_n: Optional[int] = None
-    ) -> list:
+    def get_rerank_lists(self, query: str, documents: List[str], top_n: Optional[int] = None) -> List[str]:
         if not top_n:
             top_n = len(documents)
-        assert top_n <= len(
-            documents
-        ), "'top_n' should be less than or equal to the number of documents"
-
+        assert top_n <= len(documents), "'top_n' should be less than or equal to the number of documents"
+
         url = "https://api.siliconflow.cn/v1/rerank"
 
         payload = {
             "model": self.model,
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index 0530b066..9816217f 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -16,7 +16,7 @@
 # under the License.
 
-from typing import Literal, Dict, Any, List, Optional +from typing import Literal, Dict, Any, List, Optional, Tuple import jieba from hugegraph_llm.models.embeddings.base import BaseEmbedding @@ -24,7 +24,7 @@ from nltk.translate.bleu_score import sentence_bleu -def get_blue_score(query: str, content: str) -> float: +def get_bleu_score(query: str, content: str) -> float: query_tokens = jieba.lcut(query) content_tokens = jieba.lcut(content) return sentence_bleu([query_tokens], content_tokens) @@ -37,16 +37,24 @@ def __init__( topk: int = 20, graph_ratio: float = 0.5, method: Literal["bleu", "reranker"] = "bleu", - prior_vertex: Optional[str] = None, + near_neighbor_first: bool = False, + custom_related_information: Optional[str] = None, ): + assert method in [ + "bleu", + "reranker", + ], "rerank method should be 'bleu' or 'reranker'" self.embedding = embedding self.graph_ratio = graph_ratio self.topk = topk self.method = method - self.priority_vertex = prior_vertex + self.near_neighbor_first = near_neighbor_first + self.custom_related_information = custom_related_information def run(self, context: Dict[str, Any]) -> Dict[str, Any]: query = context.get("query") + if self.custom_related_information: + query = query + self.custom_related_information context["graph_ratio"] = self.graph_ratio vector_search = context.get("vector_search", False) graph_search = context.get("graph_search", False) @@ -56,7 +64,6 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: else: graph_length = self.topk vector_length = self.topk - print(f"graph length {graph_length}") vector_result = context.get("vector_result", []) vector_length = min(len(vector_result), vector_length) @@ -64,7 +71,16 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: graph_result = context.get("graph_result", []) graph_length = min(len(graph_result), graph_length) - graph_result = self._dedup_and_rerank(query, graph_result, graph_length) + if self.near_neighbor_first: + graph_result = self._rerank_with_vertex_degree( + query, + graph_result, + graph_length, + context.get("vertex_degree_list"), + context.get("knowledge_with_degree"), + ) + else: + graph_result = self._dedup_and_rerank(query, graph_result, graph_length) context["vector_result"] = vector_result context["graph_result"] = graph_result @@ -74,9 +90,45 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: def _dedup_and_rerank(self, query: str, results: List[str], topn: int) -> List[str]: results = list(set(results)) if self.method == "bleu": - result_score_list = [[res, get_blue_score(query, res)] for res in results] + result_score_list = [[res, get_bleu_score(query, res)] for res in results] result_score_list.sort(key=lambda x: x[1], reverse=True) return [res[0] for res in result_score_list][:topn] if self.method == "reranker": reranker = Rerankers().get_reranker() return reranker.get_rerank_lists(query, results, topn) + + def _rerank_with_vertex_degree( + self, + query: str, + results: List[str], + topn: int, + vertex_degree_list: List[List[str]] | None, + knowledge_with_degree: Dict[str, List[str]] | None, + ) -> List[str]: + if vertex_degree_list is None or len(vertex_degree_list) == 0: + return self._dedup_and_rerank(query, results, topn) + if self.method == "bleu": + vertex_degree_rerank_result: List[List[str]] = [] + for vertex_degree in vertex_degree_list: + vertex_degree_score_list = [[res, get_bleu_score(query, res)] for res in vertex_degree] + vertex_degree_score_list.sort(key=lambda x: x[1], reverse=True) + vertex_degree = [res[0] for res in 
vertex_degree_score_list] + [""] + vertex_degree_rerank_result.append(vertex_degree) + + if self.method == "reranker": + reranker = Rerankers().get_reranker() + vertex_degree_rerank_result = [ + reranker.get_rerank_lists(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list + ] + depth = len(vertex_degree_list) + for result in results: + if result not in knowledge_with_degree: + knowledge_with_degree[result] = [result] + [""] * (depth - 1) + if len(knowledge_with_degree[result]) < depth: + knowledge_with_degree[result] += [""] * (depth - len(knowledge_with_degree[result])) + + def sort_key(result: str) -> Tuple[int, ...]: + return tuple(vertex_degree_rerank_result[i].index(knowledge_with_degree[result][i]) for i in range(depth)) + + sorted_results = sorted(results, key=sort_key) + return sorted_results[:topn] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py index 0fdb3cb5..dfbbcd5d 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py @@ -35,9 +35,7 @@ class GraphRAG: - def __init__( - self, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedding] = None - ): + def __init__(self, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedding] = None): self._llm = llm or LLMs().get_llm() self._embedding = embedding or Embeddings().get_embedding() self._operators: List[Any] = [] @@ -118,9 +116,17 @@ def merge_dedup_rerank( self, graph_ratio: float = 0.5, rerank_method: Literal["bleu", "reranker"] = "bleu", + near_neighbor_first: bool = False, + custom_related_information: str = "", ): self._operators.append( - MergeDedupRerank(embedding=self._embedding, graph_ratio=graph_ratio, method=rerank_method) + MergeDedupRerank( + embedding=self._embedding, + graph_ratio=graph_ratio, + method=rerank_method, + near_neighbor_first=near_neighbor_first, + custom_related_information=custom_related_information, + ) ) return self diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index 750ade69..fe225c27 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -24,9 +24,7 @@ class GraphRAGQuery: - VERTEX_GREMLIN_QUERY_TEMPL = ( - "g.V().hasId({keywords}).as('subj').toList()" - ) + VERTEX_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').toList()" # ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV( # ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by( # unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap( @@ -75,10 +73,10 @@ class GraphRAGQuery: """ def __init__( - self, - max_deep: int = 2, - max_items: int = 30, - prop_to_match: Optional[str] = None, + self, + max_deep: int = 2, + max_items: int = 30, + prop_to_match: Optional[str] = None, ): self._client = PyHugeClient( settings.graph_ip, @@ -133,14 +131,16 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: edge_labels=edge_labels_str, ) result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"] - knowledge: Set[str] = self._format_knowledge_from_query_result(query_result=result) + graph_chain_knowledge, vertex_degree_list, 
knowledge_with_degree = self._format_knowledge_from_query_result( + query_result=result + ) else: assert entrance_vids is not None, "No entrance vertices for query." rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format( keywords=entrance_vids, ) result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"] - knowledge: Set[str] = self._format_knowledge_from_vertex(query_result=result) + vertex_knowledge = self._format_knowledge_from_vertex(query_result=result) rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format( keywords=entrance_vids, max_deep=self._max_deep, @@ -148,13 +148,19 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: edge_labels=edge_labels_str, ) result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"] - knowledge.update(self._format_knowledge_from_query_result(query_result=result)) + graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result( + query_result=result + ) + graph_chain_knowledge.update(vertex_knowledge) + vertex_degree_list[0].update(vertex_knowledge) - context["graph_result"] = list(knowledge) + context["graph_result"] = list(graph_chain_knowledge) + context["vertex_degree_list"] = [list(vertex_degree) for vertex_degree in vertex_degree_list] + context["knowledge_with_degree"] = knowledge_with_degree context["graph_context_head"] = ( f"The following are knowledge sequence in max depth {self._max_deep} " f"in the form of directed graph like:\n" - "`subject -[predicate]-> object <-[predicate_next_hop]- object_next_hop ...` " + "`subject -[predicate]-> object <-[predicate_next_hop]- object_next_hop ...`" "extracted based on key entities as subject:\n" ) @@ -162,7 +168,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: verbose = context.get("verbose") or False if verbose: print("\033[93mKnowledge from Graph:") - print("\n".join(rel for rel in context["graph_result"]) + "\033[0m") + print("\n".join(chain for chain in context["graph_result"]) + "\033[0m") return context @@ -174,20 +180,24 @@ def _format_knowledge_from_vertex(self, query_result: List[Any]) -> Set[str]: knowledge.add(node_str) return knowledge - def _format_knowledge_from_query_result(self, query_result: List[Any]) -> Set[str]: + def _format_knowledge_from_query_result( + self, query_result: List[Any] + ) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]: use_id_to_match = self._prop_to_match is None knowledge = set() + knowledge_with_degree = {} + vertex_degree_list: List[Set[str]] = [] for line in query_result: flat_rel = "" raw_flat_rel = line["objects"] assert len(raw_flat_rel) % 2 == 1 node_cache = set() prior_edge_str_len = 0 + depth = 0 + nodes_with_degree = [] for i, item in enumerate(raw_flat_rel): if i % 2 == 0: - matched_str = ( - item["id"] if use_id_to_match else item["props"][self._prop_to_match] - ) + matched_str = item["id"] if use_id_to_match else item["props"][self._prop_to_match] if matched_str in node_cache: flat_rel = flat_rel[:-prior_edge_str_len] break @@ -195,8 +205,14 @@ def _format_knowledge_from_query_result(self, query_result: List[Any]) -> Set[st props_str = ", ".join(f"{k}: {v}" for k, v in item["props"].items()) node_str = f"{item['id']}{{{props_str}}}" flat_rel += node_str + nodes_with_degree.append(node_str) if flat_rel in knowledge: knowledge.remove(flat_rel) + knowledge_with_degree.pop(flat_rel) + if depth >= len(vertex_degree_list): + vertex_degree_list.append(set()) + vertex_degree_list[depth].add(node_str) + depth += 1 else: 
props_str = ", ".join(f"{k}: {v}" for k, v in item["props"].items()) props_str = f"{{{props_str}}}" if len(props_str) > 0 else "" @@ -212,22 +228,23 @@ def _format_knowledge_from_query_result(self, query_result: List[Any]) -> Set[st flat_rel += edge_str prior_edge_str_len = len(edge_str) knowledge.add(flat_rel) - return knowledge + knowledge_with_degree[flat_rel] = nodes_with_degree + return knowledge, vertex_degree_list, knowledge_with_degree def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]: schema = self._get_graph_schema() node_props_str, edge_props_str = schema.split("\n")[:2] - node_props_str = node_props_str[len("Node properties: "):].strip("[").strip("]") - edge_props_str = edge_props_str[len("Edge properties: "):].strip("[").strip("]") + node_props_str = node_props_str[len("Node properties: ") :].strip("[").strip("]") + edge_props_str = edge_props_str[len("Edge properties: ") :].strip("[").strip("]") node_labels = self._extract_label_names(node_props_str) edge_labels = self._extract_label_names(edge_props_str) return node_labels, edge_labels @staticmethod def _extract_label_names( - source: str, - head: str = "name: ", - tail: str = ", ", + source: str, + head: str = "name: ", + tail: str = ", ", ) -> List[str]: result = [] for s in source.split(head): diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index 05f81a45..65c68c25 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -32,6 +32,7 @@ "1. Exact recall > Fuzzy recall\n" "2. Independent vertex > 1-depth neighbor> 2-depth neighbors\n" "Given the context information and not prior knowledge, answer the query.\n" + "Answer should include as much context as possible.\n" "Query: {query_str}\n" "Answer: " ) From 029f4bff80b88acd9340a8d74ab3558d960cf795 Mon Sep 17 00:00:00 2001 From: imbajin Date: Tue, 27 Aug 2024 11:20:41 +0800 Subject: [PATCH 05/18] tiny fix --- hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py | 7 ++++--- hugegraph-llm/src/hugegraph_llm/config/config.py | 6 ++---- .../operators/common_op/merge_dedup_rerank.py | 4 ++-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py index 0b8c930c..eec10d93 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py +++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py @@ -52,14 +52,15 @@ class LLMConfigRequest(BaseModel): host: str = None port: str = None + class EmbeddingConfigRequest(BaseModel): llm_type: str - # The common parameters shared by OpenAI, Qianfan Wenxin, - # and OLLAMA platforms. + # The common parameters shared by OpenAI, Qianfan(Wenxin) & OLLAMA platforms. 
api_key: str + class RerankerConfigRequest(BaseModel): reranker_model: str reranker_type: str api_key: str - api_base: str \ No newline at end of file + api_base: str diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py b/hugegraph-llm/src/hugegraph_llm/config/config.py index 778e570c..2f581f2a 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/config.py @@ -55,10 +55,8 @@ class Config: qianfan_api_key: Optional[str] = None qianfan_secret_key: Optional[str] = None qianfan_access_token: Optional[str] = None - # 4.1 url settings - qianfan_url_prefix: Optional[str] = ( - "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop" - ) + # 4.1 URL settings + qianfan_url_prefix: Optional[str] = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop" qianfan_chat_url: Optional[str] = qianfan_url_prefix + "/chat/" qianfan_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K" qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/" diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py index 045a7251..2b7ca77a 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py @@ -127,8 +127,8 @@ def _rerank_with_vertex_degree( if len(knowledge_with_degree[result]) < depth: knowledge_with_degree[result] += [""] * (depth - len(knowledge_with_degree[result])) - def sort_key(result: str) -> Tuple[int, ...]: - return tuple(vertex_degree_rerank_result[i].index(knowledge_with_degree[result][i]) for i in range(depth)) + def sort_key(res: str) -> Tuple[int, ...]: + return tuple(vertex_degree_rerank_result[i].index(knowledge_with_degree[res][i]) for i in range(depth)) sorted_results = sorted(results, key=sort_key) return sorted_results[:topn] From 28fed504b48da11df43bd25c98b3f7e490d1999b Mon Sep 17 00:00:00 2001 From: jasinliu <939282975@qq.com> Date: Tue, 27 Aug 2024 12:18:44 +0800 Subject: [PATCH 06/18] fix format --- .../hugegraph_llm/api/models/rag_requests.py | 13 -- .../src/hugegraph_llm/api/rag_api.py | 44 +------ .../src/hugegraph_llm/demo/rag_web_demo.py | 118 ++++++------------ 3 files changed, 40 insertions(+), 135 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py index eec10d93..47610f55 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py +++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py @@ -51,16 +51,3 @@ class LLMConfigRequest(BaseModel): # ollama-only properties host: str = None port: str = None - - -class EmbeddingConfigRequest(BaseModel): - llm_type: str - # The common parameters shared by OpenAI, Qianfan(Wenxin) & OLLAMA platforms. 
- api_key: str - - -class RerankerConfigRequest(BaseModel): - reranker_model: str - reranker_type: str - api_key: str - api_base: str diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index f7d4fc76..4ec53992 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -38,23 +38,17 @@ def rag_http_api( ): @router.post("/rag", status_code=status.HTTP_200_OK) def rag_answer_api(req: RAGRequest): - result = rag_answer_func( - req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector - ) + result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector) return { key: value - for key, value in zip( - ["raw_llm", "vector_only", "graph_only", "graph_vector"], result - ) + for key, value in zip(["raw_llm", "vector_only", "graph_only", "graph_vector"], result) if getattr(req, key) } @router.post("/config/graph", status_code=status.HTTP_201_CREATED) def graph_config_api(req: GraphConfigRequest): # Accept status code - res = apply_graph_conf( - req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http" - ) + res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http") return generate_response(RAGResponse(status_code=res, message="Missing Value")) @router.post("/config/llm", status_code=status.HTTP_201_CREATED) @@ -78,35 +72,5 @@ def llm_config_api(req: LLMConfigRequest): origin_call="http", ) else: - res = apply_llm_conf( - req.host, req.port, req.language_model, None, origin_call="http" - ) - return generate_response(RAGResponse(status_code=res, message="Missing Value")) - - @router.post("/config/embedding", status_code=status.HTTP_201_CREATED) - def embedding_config_api(req: LLMConfigRequest): - settings.embedding_type = req.llm_type - - if req.llm_type == "openai": - res = apply_embedding_conf( - req.api_key, req.api_base, req.language_model, origin_call="http" - ) - elif req.llm_type == "qianfan_wenxin": - res = apply_embedding_conf( - req.api_key, req.api_base, None, origin_call="http" - ) - else: - res = apply_embedding_conf( - req.host, req.port, req.language_model, origin_call="http" - ) - return generate_response(RAGResponse(status_code=res, message="Missing Value")) - - @router.post("/config/rerank", status_code=status.HTTP_201_CREATED) - def rerank_config_api(req: RerankerConfigRequest): - settings.reranker_type = req.reranker_type - - if req.reranker_type == "cohere": - res = apply_reranker_conf( - req.api_key, req.reranker_model, req.api_base, origin_call="http" - ) + res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http") return generate_response(RAGResponse(status_code=res, message="Missing Value")) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py index 2e13dcd8..d5f9acc8 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py @@ -19,7 +19,7 @@ import argparse import json import os -from typing import List, Union, Tuple +from typing import List, Union, Tuple, Literal import docx import gradio as gr @@ -65,10 +65,10 @@ def rag_answer( graph_only_answer: bool, graph_vector_answer: bool, graph_ratio: float, - rerank_method: str, + rerank_method: Literal["bleu", "reranker"], near_neighbor_first: bool, custom_related_information: str, - answer_prompt: str + answer_prompt: str, ) -> Tuple: vector_search = 
vector_only_answer or graph_vector_answer graph_search = graph_only_answer or graph_vector_answer @@ -89,7 +89,7 @@ def rag_answer( vector_only_answer=vector_only_answer, graph_only_answer=graph_only_answer, graph_vector_answer=graph_vector_answer, - answer_prompt=answer_prompt + answer_prompt=answer_prompt, ) try: @@ -215,10 +215,7 @@ def config_qianfan_model(arg1, arg2, arg3=None, origin_call=None) -> int: "client_secret": arg2, } status_code = test_api_connection( - "https://aip.baidubce.com/oauth/2.0/token", - "POST", - params=params, - origin_call=origin_call, + "https://aip.baidubce.com/oauth/2.0/token", "POST", params=params, origin_call=origin_call ) return status_code @@ -246,26 +243,28 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int: return status_code -def apply_reranker_config(arg1, arg2, arg3: str | None = None, origin_call=None) -> int: +def apply_reranker_config( + reranker_api_key, reranker_model, cohere_base_url: str | None = None, origin_call=None +) -> int: status_code = -1 reranker_option = settings.reranker_type if reranker_option == "cohere": - settings.reranker_api_key = arg1 - settings.reranker_model = arg2 - settings.cohere_base_url = arg3 - headers = {"Authorization": f"Bearer {arg1}"} + settings.reranker_api_key = reranker_api_key + settings.reranker_model = reranker_model + settings.cohere_base_url = cohere_base_url + headers = {"Authorization": f"Bearer {reranker_api_key}"} status_code = test_api_connection( - arg3.rsplit("/", 1)[0] + "/check-api-key", + cohere_base_url.rsplit("/", 1)[0] + "/check-api-key", method="POST", headers=headers, origin_call=origin_call, ) elif reranker_option == "siliconflow": - settings.reranker_api_key = arg1 - settings.reranker_model = arg2 + settings.reranker_api_key = reranker_api_key + settings.reranker_model = reranker_model headers = { "accept": "application/json", - "authorization": f"Bearer {arg1}", + "authorization": f"Bearer {reranker_api_key}", } status_code = test_api_connection( "https://api.siliconflow.cn/v1/user/info", @@ -345,11 +344,7 @@ def init_rag_ui() -> gr.Interface: graph_config_button.click(apply_graph_config, inputs=graph_config_input) # pylint: disable=no-member gr.Markdown("2. 
Set up the LLM.") - llm_dropdown = gr.Dropdown( - choices=["openai", "qianfan_wenxin", "ollama"], - value=settings.llm_type, - label="LLM", - ) + llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type, label="LLM") @gr.render(inputs=[llm_dropdown]) def llm_settings(llm_type): @@ -357,11 +352,7 @@ def llm_settings(llm_type): if llm_type == "openai": with gr.Row(): llm_config_input = [ - gr.Textbox( - value=settings.openai_api_key, - label="api_key", - type="password", - ), + gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"), gr.Textbox(value=settings.openai_api_base, label="api_base"), gr.Textbox(value=settings.openai_language_model, label="model_name"), gr.Textbox(value=settings.openai_max_tokens, label="max_token"), @@ -377,16 +368,8 @@ def llm_settings(llm_type): elif llm_type == "qianfan_wenxin": with gr.Row(): llm_config_input = [ - gr.Textbox( - value=settings.qianfan_api_key, - label="api_key", - type="password", - ), - gr.Textbox( - value=settings.qianfan_secret_key, - label="secret_key", - type="password", - ), + gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"), + gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"), gr.Textbox(value=settings.qianfan_language_model, label="model_name"), gr.Textbox(value="", visible=False), ] @@ -399,9 +382,7 @@ def llm_settings(llm_type): gr.Markdown("3. Set up the Embedding.") embedding_dropdown = gr.Dropdown( - choices=["openai", "qianfan_wenxin", "ollama"], - value=settings.embedding_type, - label="Embedding", + choices=["openai", "qianfan_wenxin", "ollama"], value=settings.embedding_type, label="Embedding" ) @gr.render(inputs=[embedding_dropdown]) @@ -410,27 +391,15 @@ def embedding_settings(embedding_type): if embedding_type == "openai": with gr.Row(): embedding_config_input = [ - gr.Textbox( - value=settings.openai_api_key, - label="api_key", - type="password", - ), + gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"), gr.Textbox(value=settings.openai_api_base, label="api_base"), gr.Textbox(value=settings.openai_embedding_model, label="model_name"), ] elif embedding_type == "qianfan_wenxin": with gr.Row(): embedding_config_input = [ - gr.Textbox( - value=settings.qianfan_api_key, - label="api_key", - type="password", - ), - gr.Textbox( - value=settings.qianfan_secret_key, - label="secret_key", - type="password", - ), + gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"), + gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"), gr.Textbox(value=settings.qianfan_embedding_model, label="model_name"), ] elif embedding_type == "ollama": @@ -453,9 +422,7 @@ def embedding_settings(embedding_type): gr.Markdown("4. 
Set up the Reranker(Optional).") reranker_dropdown = gr.Dropdown( - choices=["cohere", "siliconflow"], - value=settings.reranker_type, - label="Reranker", + choices=["cohere", "siliconflow"], value=settings.reranker_type, label="Reranker" ) @gr.render(inputs=[reranker_dropdown]) @@ -464,22 +431,14 @@ def reranker_settings(reranker_type): if reranker_type == "cohere": with gr.Row(): reranker_config_input = [ - gr.Textbox( - value=settings.reranker_api_key, - label="api_key", - type="password", - ), + gr.Textbox(value=settings.reranker_api_key, label="api_key", type="password"), gr.Textbox(value=settings.reranker_model, label="model"), gr.Textbox(value=settings.cohere_base_url, label="base_url"), ] elif reranker_type == "siliconflow": with gr.Row(): reranker_config_input = [ - gr.Textbox( - value=settings.reranker_api_key, - label="api_key", - type="password", - ), + gr.Textbox(value=settings.reranker_api_key, label="api_key", type="password"), gr.Textbox( value="BAAI/bge-reranker-v2-m3", label="model", @@ -552,7 +511,8 @@ def reranker_settings(reranker_type): input_file = gr.File( value=[os.path.join(resource_path, "demo", "test.txt")], label="Docs (multi-files can be selected together)", - file_count="multiple") + file_count="multiple", + ) input_schema = gr.Textbox(value=schema, label="Schema") info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head") with gr.Column(): @@ -573,11 +533,7 @@ def reranker_settings(reranker_type): gr.Markdown("""## 2. RAG with HugeGraph 📖""") with gr.Row(): with gr.Column(scale=2): - inp = gr.Textbox( - value="Tell me about Sarah.", - label="Question", - show_copy_button=True, - ) + inp = gr.Textbox(value="Tell me about Sarah.", label="Question", show_copy_button=True) raw_out = gr.Textbox(label="Basic LLM Answer", show_copy_button=True) vector_only_out = gr.Textbox(label="Vector-only Answer", show_copy_button=True) graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True) @@ -618,8 +574,10 @@ def toggle_slider(enable): ) btn = gr.Button("Answer Question") from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE - answer_prompt_input = gr.Textbox(value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt", - show_copy_button=True) + + answer_prompt_input = gr.Textbox( + value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt", show_copy_button=True + ) btn.click( # pylint: disable=no-member fn=rag_answer, inputs=[ @@ -640,11 +598,7 @@ def toggle_slider(enable): gr.Markdown("""## 3. 
Others (🚧) """) with gr.Row(): with gr.Column(): - inp = gr.Textbox( - value="g.V().limit(10)", - label="Gremlin query", - show_copy_button=True, - ) + inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query", show_copy_button=True) fmt = gr.Checkbox(label="Format JSON", value=True) out = gr.Textbox(label="Output", show_copy_button=True) btn = gr.Button("Run gremlin query on HugeGraph") @@ -681,7 +635,7 @@ def toggle_slider(enable): log.info("Authentication is %s.", "enabled" if auth_enabled else "disabled") # TODO: support multi-user login when need app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag", os.getenv("TOKEN")) if auth_enabled else None) - + # TODO: we can't use reload now due to the config 'app' of uvicorn.run # ❎:f'{__name__}:app' / rag_web_demo:app / hugegraph_llm.demo.rag_web_demo:app uvicorn.run(app, host=args.host, port=args.port, reload=False) From 46c072fd611e3f867312be82bff613dda6df1a8e Mon Sep 17 00:00:00 2001 From: jasinliu <939282975@qq.com> Date: Tue, 27 Aug 2024 12:36:51 +0800 Subject: [PATCH 07/18] fix api --- .../hugegraph_llm/api/models/rag_requests.py | 6 +++ .../src/hugegraph_llm/api/rag_api.py | 47 +++++++++++-------- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py index 47610f55..5c857a1c 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py +++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py @@ -51,3 +51,9 @@ class LLMConfigRequest(BaseModel): # ollama-only properties host: str = None port: str = None + +class RerankerConfigRequest(BaseModel): + reranker_model: str + reranker_type: str + api_key: str + cohere_base_url: Optional[str] = None diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index 4ec53992..a0bbdc89 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -29,12 +29,7 @@ def rag_http_api( - router: APIRouter, - rag_answer_func, - apply_graph_conf, - apply_llm_conf, - apply_embedding_conf, - apply_reranker_conf, + router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf ): @router.post("/rag", status_code=status.HTTP_200_OK) def rag_answer_api(req: RAGRequest): @@ -56,21 +51,33 @@ def llm_config_api(req: LLMConfigRequest): settings.llm_type = req.llm_type if req.llm_type == "openai": - res = apply_llm_conf( - req.api_key, - req.api_base, - req.language_model, - req.max_tokens, - origin_call="http", - ) + res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http") elif req.llm_type == "qianfan_wenxin": - res = apply_llm_conf( - req.api_key, - req.secret_key, - req.language_model, - None, - origin_call="http", - ) + res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http") else: res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http") return generate_response(RAGResponse(status_code=res, message="Missing Value")) + + @router.post("/config/embedding", status_code=status.HTTP_201_CREATED) + def embedding_config_api(req: LLMConfigRequest): + settings.embedding_type = req.llm_type + + if req.llm_type == "openai": + res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http") + elif req.llm_type == "qianfan_wenxin": + res = 
apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
        else:
            res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
    def rerank_config_api(req: RerankerConfigRequest):
        settings.reranker_type = req.reranker_type

        if req.reranker_type == "cohere":
            res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http")
        elif req.reranker_type == "siliconflow":
            res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http")
        else:
            res = status.HTTP_501_NOT_IMPLEMENTED
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
From 2f75822aec578d64234d01a084e71cfe9382705a Mon Sep 17 00:00:00 2001
From: jasinliu <939282975@qq.com>
Date: Tue, 27 Aug 2024 14:56:43 +0800
Subject: [PATCH 08/18] reserve priority

---
 .../hugegraph_llm/operators/common_op/merge_dedup_rerank.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index 2b7ca77a..ea32240a 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -39,6 +39,7 @@ def __init__(
         method: Literal["bleu", "reranker"] = "bleu",
         near_neighbor_first: bool = False,
         custom_related_information: Optional[str] = None,
+        priority: bool = False,  # TODO: implement priority
     ):
         assert method in [
             "bleu",
@@ -50,6 +51,8 @@ def __init__(
         self.method = method
         self.near_neighbor_first = near_neighbor_first
         self.custom_related_information = custom_related_information
+        if priority:
+            raise ValueError("Unimplemented rerank strategy: priority.")
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         query = context.get("query")
From 86090e53b3b04fa46c60aa386e8a4b97d007b817 Mon Sep 17 00:00:00 2001
From: jasinliu <939282975@qq.com>
Date: Tue, 27 Aug 2024 17:11:45 +0800
Subject: [PATCH 09/18] fix prompt

---
 .../src/hugegraph_llm/operators/llm_op/answer_synthesize.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index f94dbaaa..f8109f1f 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -22,7 +22,7 @@
 from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.models.llms.init_llm import LLMs
 
-# TODO: we need enhance the template to answer the question
+# TODO: we need to enhance the template to answer the question (put it in a separate file)
 DEFAULT_ANSWER_TEMPLATE = f"""
 You are an expert in knowledge graphs and natural language processing.
 Your task is to provide a precise and accurate answer based on the given context.
 ---------------------
 {{context_str}}
 ---------------------
-Please refer to the context based on the following priority:
-1. Precise data > fuzzy data
 Given the context information and without using fictive knowledge,
 answer the following query in a concise and professional manner.
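A minimal smoke test for the reranker plumbing wired up through PATCH 05-09, once a rag_web_demo server is running with these changes applied. The sketch below is illustrative only: the host/port and the API key are assumptions, not values fixed by this series; the paths and field names come from rag_api.py and RerankerConfigRequest above.

    import requests

    BASE = "http://127.0.0.1:8001"  # assumed --host/--port of the local demo server

    # Register a Cohere reranker through the new /config/rerank endpoint.
    rerank_conf = {
        "reranker_type": "cohere",  # "siliconflow" is the other supported type
        "reranker_model": "rerank-multilingual-v3.0",
        "api_key": "<your-cohere-api-key>",  # placeholder, not a real key
        "cohere_base_url": "https://api.cohere.com/v1/rerank",
    }
    resp = requests.post(f"{BASE}/config/rerank", json=rerank_conf, timeout=10)
    print(resp.status_code, resp.text)

    # Ask a question; graph_vector=True exercises the merge/dedup/rerank path.
    question = {
        "query": "Tell me about Sarah.",
        "raw_llm": False,
        "vector_only": False,
        "graph_only": False,
        "graph_vector": True,
    }
    print(requests.post(f"{BASE}/rag", json=question, timeout=60).json())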
From f14cbb039aeb8944818371bfa4ce8c0cf2284f67 Mon Sep 17 00:00:00 2001 From: jasinliu <939282975@qq.com> Date: Fri, 30 Aug 2024 22:02:13 +0800 Subject: [PATCH 10/18] fix ui --- .../src/hugegraph_llm/config/config.py | 2 +- .../src/hugegraph_llm/demo/rag_web_demo.py | 53 ++++++++++--------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py b/hugegraph-llm/src/hugegraph_llm/config/config.py index 2f581f2a..b9d917f9 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/config.py @@ -35,7 +35,7 @@ class Config: # env_path: Optional[str] = ".env" llm_type: Literal["openai", "ollama", "qianfan_wenxin", "zhipu"] = "openai" embedding_type: Optional[Literal["openai", "ollama", "qianfan_wenxin", "zhipu"]] = "openai" - reranker_type: Optional[Literal["cohere", "siliconflow"]] = "cohere" + reranker_type: Optional[Literal["cohere", "siliconflow"]] = None # 1. OpenAI settings openai_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py index 266e3f0a..2aff21c0 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py @@ -539,31 +539,38 @@ def reranker_settings(reranker_type): vector_only_out = gr.Textbox(label="Vector-only Answer", show_copy_button=True) graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True) graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True) + from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE + + answer_prompt_input = gr.Textbox( + value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt", show_copy_button=True + ) with gr.Column(scale=1): - raw_radio = gr.Radio(choices=[True, False], value=True, label="Basic LLM Answer") - vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer") - graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer") with gr.Row(): + raw_radio = gr.Radio(choices=[True, False], value=True, label="Basic LLM Answer") + vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer") + with gr.Row(): + graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer") + graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer") - def toggle_slider(enable): - return gr.update(interactive=enable) + def toggle_slider(enable): + return gr.update(interactive=enable) - graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer") - graph_ratio = gr.Slider( - 0, - 1, - 0.5, - label="Graph Ratio", - step=0.1, - interactive=False, - ) - graph_vector_radio.change(toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio) with gr.Column(): - rerank_method = gr.Dropdown( - choices=["bleu", "reranker"], - value="bleu", - label="Rerank method", - ) + with gr.Row(): + rerank_method = gr.Dropdown( + choices=["bleu", "reranker"] if settings.reranker_type else ["bleu"], + value="bleu", + label="Rerank method", + ) + graph_ratio = gr.Slider( + 0, + 1, + 0.5, + label="Graph Ratio", + step=0.1, + interactive=False, + ) + graph_vector_radio.change(toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio) 
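+                # toggle_slider returns gr.update(interactive=enable), so the Graph
+                # Ratio slider is editable only while the Graph-Vector answer is on.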
near_neighbor_first = gr.Checkbox(
                 value=False,
                 label="Near neighbor first(Optional)",
@@ -573,12 +580,8 @@ def toggle_slider(enable):
                 "",
                 label="Custom related information(Optional)",
             )
-            btn = gr.Button("Answer Question")
-            from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE
+            btn = gr.Button("Answer Question")
 
-            answer_prompt_input = gr.Textbox(
-                value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt", show_copy_button=True
-            )
             btn.click(  # pylint: disable=no-member
                 fn=rag_answer,
                 inputs=[
From 5ed2dac47addc35105849b35aff45ebe56431c35 Mon Sep 17 00:00:00 2001
From: jasinliu <939282975@qq.com>
Date: Sun, 1 Sep 2024 23:29:16 +0800
Subject: [PATCH 11/18] answer button primary

---
 hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 2aff21c0..9a0e7e8b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -580,7 +580,7 @@ def toggle_slider(enable):
                 "",
                 label="Custom related information(Optional)",
             )
-            btn = gr.Button("Answer Question")
+            btn = gr.Button("Answer Question", variant="primary")
 
             btn.click(  # pylint: disable=no-member
                 fn=rag_answer,
From 3100c2b60c545ec8db146e07f06911596533d204 Mon Sep 17 00:00:00 2001
From: jasinliu <939282975@qq.com>
Date: Tue, 3 Sep 2024 21:20:19 +0800
Subject: [PATCH 12/18] auto switch to bleu

---
 .../src/hugegraph_llm/demo/rag_web_demo.py    |  2 ++
 .../operators/common_op/merge_dedup_rerank.py | 24 ++++++++++++++-----
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 9a0e7e8b..52535a55 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -94,6 +94,8 @@ def rag_answer(
 
     try:
         context = searcher.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
+        if context.get("switch_to_bleu"):
+            gr.Warning("Online reranker failed, automatically switching to local bleu method.")
         return (
             context.get("raw_answer", ""),
             context.get("vector_only_answer", ""),
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index ea32240a..e4fa0d3e 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -21,7 +21,9 @@
 import jieba
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.models.rerankers.init_reranker import Rerankers
+from hugegraph_llm.utils.log import log
 from nltk.translate.bleu_score import sentence_bleu
+import requests
 
 
 def get_bleu_score(query: str, content: str) -> float:
@@ -53,6 +55,7 @@ def __init__(
         self.custom_related_information = custom_related_information
         if priority:
             raise ValueError("Unimplemented rerank strategy: priority.")
+        self.switch_to_bleu = False
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         query = context.get("query")
@@ -82,6 +85,8 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
                 context.get("vertex_degree_list"),
                 context.get("knowledge_with_degree"),
             )
+            if self.switch_to_bleu:
+                context["switch_to_bleu"] = True
         else:
             graph_result = self._dedup_and_rerank(query,
graph_result, graph_length) @@ -106,10 +111,22 @@ def _rerank_with_vertex_degree( results: List[str], topn: int, vertex_degree_list: List[List[str]] | None, - knowledge_with_degree: Dict[str, List[str]] | None, + knowledge_with_degree: Dict[str, List[str]], ) -> List[str]: if vertex_degree_list is None or len(vertex_degree_list) == 0: return self._dedup_and_rerank(query, results, topn) + + if self.method == "reranker": + reranker = Rerankers().get_reranker() + try: + vertex_degree_rerank_result = [ + reranker.get_rerank_lists(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list + ] + except requests.exceptions.RequestException as e: + log.warning(f"Online reranker fails, automatically switches to local bleu method: {e}") + self.method = "bleu" + self.switch_to_bleu = True + if self.method == "bleu": vertex_degree_rerank_result: List[List[str]] = [] for vertex_degree in vertex_degree_list: @@ -118,11 +135,6 @@ def _rerank_with_vertex_degree( vertex_degree = [res[0] for res in vertex_degree_score_list] + [""] vertex_degree_rerank_result.append(vertex_degree) - if self.method == "reranker": - reranker = Rerankers().get_reranker() - vertex_degree_rerank_result = [ - reranker.get_rerank_lists(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list - ] depth = len(vertex_degree_list) for result in results: if result not in knowledge_with_degree: From 611eb5ede6d2074faca47247af8d9813520b3597 Mon Sep 17 00:00:00 2001 From: imbajin Date: Wed, 4 Sep 2024 15:07:29 +0800 Subject: [PATCH 13/18] tiny format --- hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py | 1 + hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 3 ++- hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py | 6 +++--- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py index ce3b464e..a211bb8b 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py +++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py @@ -53,6 +53,7 @@ class LLMConfigRequest(BaseModel): host: str = None port: str = None + class RerankerConfigRequest(BaseModel): reranker_model: str reranker_type: str diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index f5c0a599..64daf707 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -33,7 +33,8 @@ def rag_http_api( ): @router.post("/rag", status_code=status.HTTP_200_OK) def rag_answer_api(req: RAGRequest): - result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector, req.answer_prompt) + result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector, + req.answer_prompt) return { key: value for key, value in zip(["raw_llm", "vector_only", "graph_only", "graph_vector"], result) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py index 52535a55..ecd09c73 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py @@ -376,7 +376,7 @@ def llm_settings(llm_type): gr.Textbox(value=settings.qianfan_language_model, label="model_name"), gr.Textbox(value="", visible=False), ] - log.debug(llm_config_input) + # log.debug(llm_config_input) else: llm_config_input = [] llm_config_button = gr.Button("apply configuration") @@ 
-419,7 +419,7 @@ def embedding_settings(embedding_type):
 
         # Call the separate apply_embedding_configuration function here
         embedding_config_button.click(  # pylint: disable=no-member
-            apply_embedding_config,
+            fn=apply_embedding_config,
             inputs=embedding_config_input,  # pylint: disable=no-member
         )
 
@@ -455,7 +455,7 @@ def reranker_settings(reranker_type):
 
         # Call the separate apply_reranker_configuration function here
         reranker_config_button.click(  # pylint: disable=no-member
-            apply_reranker_config,
+            fn=apply_reranker_config,
             inputs=reranker_config_input,  # pylint: disable=no-member
         )
 
From 5ed2dac47addc35105849b35aff45ebe56431c35 Mon Sep 17 00:00:00 2001
From: jasinliu <939282975@qq.com>
Date: Wed, 4 Sep 2024 15:22:04 +0800
Subject: [PATCH 14/18] fix config

---
 hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index ecd09c73..f8d1e17b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -250,7 +250,7 @@ def apply_reranker_config(
     reranker_api_key, reranker_model, cohere_base_url: str | None = None, origin_call=None
 ) -> int:
     status_code = -1
-    reranker_option = settings.reranker_type
+    reranker_option = settings.reranker_type if settings.reranker_type else "cohere"
     if reranker_option == "cohere":
         settings.reranker_api_key = reranker_api_key
         settings.reranker_model = reranker_model
From b5c170c0c57f191807655ddb7545a59158ebb028 Mon Sep 17 00:00:00 2001
From: jasinliu <939282975@qq.com>
Date: Wed, 4 Sep 2024 16:19:47 +0800
Subject: [PATCH 15/18] fix user warning about config

---
 hugegraph-llm/src/hugegraph_llm/config/config.py      |  2 +-
 .../src/hugegraph_llm/demo/rag_web_demo.py            | 15 ++++++++++-----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py b/hugegraph-llm/src/hugegraph_llm/config/config.py
index b9d917f9..9df29818 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config.py
@@ -45,7 +45,7 @@ class Config:
     cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank")
     reranker_api_key: Optional[str] = None
-    reranker_model: Optional[str] = "rerank-multilingual-v3.0"
+    reranker_model: Optional[str] = None
     # 3. Ollama settings
     ollama_host: Optional[str] = "127.0.0.1"
     ollama_port: Optional[int] = 11434
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index f8d1e17b..5f5ce09d 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -203,7 +203,7 @@ def test_api_connection(url, method="GET", headers=None, params=None, body=None,
     if origin_call is None:
         try:
             raise gr.Error(json.loads(resp.text).get("message", msg))
-        except json.decoder.JSONDecodeError:
+        except (json.decoder.JSONDecodeError, AttributeError):
             raise gr.Error(resp.text)
     return resp.status_code
 
@@ -247,10 +247,13 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
 
 
 def apply_reranker_config(
-    reranker_api_key, reranker_model, cohere_base_url: str | None = None, origin_call=None
+    reranker_api_key: str | None = None,
+    reranker_model: str | None = None,
+    cohere_base_url: str | None = None,
+    origin_call=None,
 ) -> int:
     status_code = -1
-    reranker_option = settings.reranker_type if settings.reranker_type else "cohere"
+    reranker_option = settings.reranker_type
     if reranker_option == "cohere":
         settings.reranker_api_key = reranker_api_key
         settings.reranker_model = reranker_model
@@ -425,12 +428,14 @@ def embedding_settings(embedding_type):
 
     gr.Markdown("4. Set up the Reranker(Optional).")
     reranker_dropdown = gr.Dropdown(
-        choices=["cohere", "siliconflow"], value=settings.reranker_type, label="Reranker"
+        choices=["cohere", "siliconflow", "None"],
+        value=settings.reranker_type if settings.reranker_type else "None",
+        label="Reranker",
     )
 
     @gr.render(inputs=[reranker_dropdown])
     def reranker_settings(reranker_type):
-        settings.reranker_type = reranker_type
+        settings.reranker_type = reranker_type if reranker_type != "None" else None
         if reranker_type == "cohere":
             with gr.Row():
                 reranker_config_input = [
From 6c3176916203773808afa199918bf9d8a7498438 Mon Sep 17 00:00:00 2001
From: jasinliu <939282975@qq.com>
Date: Wed, 4 Sep 2024 17:09:43 +0800
Subject: [PATCH 16/18] fix type error

---
 hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py | 8 ++++----
 .../operators/common_op/merge_dedup_rerank.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 5f5ce09d..dc95485b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -19,7 +19,7 @@
 import argparse
 import json
 import os
-from typing import List, Union, Tuple, Literal
+from typing import List, Union, Tuple, Literal, Optional
 
 import docx
 import gradio as gr
@@ -247,9 +247,9 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int:
 
 
 def apply_reranker_config(
-    reranker_api_key: str | None = None,
-    reranker_model: str | None = None,
-    cohere_base_url: str | None = None,
+    reranker_api_key: Optional[str] = None,
+    reranker_model: Optional[str] = None,
+    cohere_base_url: Optional[str] = None,
     origin_call=None,
 ) -> int:
     status_code = -1
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index e4fa0d3e..780cfb4d 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -110,7 +110,7 @@ def
_rerank_with_vertex_degree( query: str, results: List[str], topn: int, - vertex_degree_list: List[List[str]] | None, + vertex_degree_list: Optional[List[List[str]]], knowledge_with_degree: Dict[str, List[str]], ) -> List[str]: if vertex_degree_list is None or len(vertex_degree_list) == 0: From 4b4cb59f2a6af1e438328339a660b8ccb71b986d Mon Sep 17 00:00:00 2001 From: imbajin Date: Wed, 4 Sep 2024 19:54:50 +0800 Subject: [PATCH 17/18] refact bleu rerank & support rank ui mapping in gradio --- .../src/hugegraph_llm/config/__init__.py | 2 +- .../src/hugegraph_llm/config/config.py | 2 +- .../src/hugegraph_llm/demo/rag_web_demo.py | 21 +++++-------- .../models/rerankers/init_reranker.py | 9 +++--- .../models/rerankers/siliconflow.py | 1 + .../operators/common_op/merge_dedup_rerank.py | 31 +++++++++---------- 6 files changed, 29 insertions(+), 37 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/config/__init__.py b/hugegraph-llm/src/hugegraph_llm/config/__init__.py index f801b887..3e6c9e97 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/__init__.py +++ b/hugegraph-llm/src/hugegraph_llm/config/__init__.py @@ -22,8 +22,8 @@ ] import os -from .config import Config +from .config import Config settings = Config() settings.from_env() diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py b/hugegraph-llm/src/hugegraph_llm/config/config.py index 9df29818..2a73b622 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/config.py @@ -60,7 +60,7 @@ class Config: qianfan_chat_url: Optional[str] = qianfan_url_prefix + "/chat/" qianfan_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K" qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/" - # https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu + # refer https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu to get more details qianfan_embedding_model: Optional[str] = "embedding-v1" # 5. ZhiPu(GLM) settings zhipu_api_key: Optional[str] = None diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py index dc95485b..2414c431 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py @@ -426,10 +426,10 @@ def embedding_settings(embedding_type): inputs=embedding_config_input, # pylint: disable=no-member ) - gr.Markdown("4. Set up the Reranker(Optional).") + gr.Markdown("4. 
Set up the Reranker (Optional).")
             reranker_dropdown = gr.Dropdown(
-                choices=["cohere", "siliconflow", "None"],
-                value=settings.reranker_type if settings.reranker_type else "None",
+                choices=["cohere", "siliconflow", ("default/offline", "None")],
+                value=os.getenv("reranker_type") or "None",
                 label="Reranker",
             )
 
@@ -564,19 +564,14 @@ def toggle_slider(enable):
 
         with gr.Column():
             with gr.Row():
+                online_rerank = os.getenv("reranker_type")
                 rerank_method = gr.Dropdown(
-                    choices=["bleu", "reranker"] if settings.reranker_type else ["bleu"],
-                    value="bleu",
+                    choices=["bleu", ("rerank (online)", "reranker")] if online_rerank else ["bleu"],
+                    value="reranker" if online_rerank else "bleu",
                     label="Rerank method",
                 )
-                graph_ratio = gr.Slider(
-                    0,
-                    1,
-                    0.5,
-                    label="Graph Ratio",
-                    step=0.1,
-                    interactive=False,
-                )
+                graph_ratio = gr.Slider(0, 1, 0.5, label="Graph Ratio", step=0.1, interactive=False)
+
                 graph_vector_radio.change(toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio)
                 near_neighbor_first = gr.Checkbox(
                     value=False,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py
index 4b119b60..541f4130 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py
@@ -15,9 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from hugegraph_llm.config import settings
 from hugegraph_llm.models.rerankers.cohere import CohereReranker
 from hugegraph_llm.models.rerankers.siliconflow import SiliconReranker
-from hugegraph_llm.config import settings
 
 
 class Rerankers:
@@ -29,8 +29,7 @@ def get_reranker(self):
             return CohereReranker(
                 api_key=settings.reranker_api_key, base_url=settings.cohere_base_url, model=settings.reranker_model
             )
-
-        if self.reranker_type == "siliconflow":
+        elif self.reranker_type == "siliconflow":
             return SiliconReranker(api_key=settings.reranker_api_key, model=settings.reranker_model)
-
-        raise Exception(f"reranker type is not supported !")
+        else:
+            raise Exception(f"Reranker type '{self.reranker_type}' is not supported!")
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
index 6e9c5be5..a860a842 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
@@ -16,6 +16,7 @@
 # under the License.
from typing import Optional, List + import requests diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py index 780cfb4d..6e356e25 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py @@ -19,11 +19,12 @@ from typing import Literal, Dict, Any, List, Optional, Tuple import jieba +import requests +from nltk.translate.bleu_score import sentence_bleu + from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.models.rerankers.init_reranker import Rerankers from hugegraph_llm.utils.log import log -from nltk.translate.bleu_score import sentence_bleu -import requests def get_bleu_score(query: str, content: str) -> float: @@ -32,6 +33,12 @@ def get_bleu_score(query: str, content: str) -> float: return sentence_bleu([query_tokens], content_tokens) +def _bleu_rerank(query: str, results: List[str]) -> List[str]: + result_score_list = [[res, get_bleu_score(query, res)] for res in results] + result_score_list.sort(key=lambda x: x[1], reverse=True) + return [res[0] for res in result_score_list] + + class MergeDedupRerank: def __init__( self, @@ -43,10 +50,7 @@ def __init__( custom_related_information: Optional[str] = None, priority: bool = False, # TODO: implement priority ): - assert method in [ - "bleu", - "reranker", - ], f"Unimplemented rerank method '{method}'." + assert method in ["bleu", "reranker"], f"Unimplemented rerank method '{method}'." self.embedding = embedding self.graph_ratio = graph_ratio self.topk = topk @@ -98,9 +102,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: def _dedup_and_rerank(self, query: str, results: List[str], topn: int) -> List[str]: results = list(set(results)) if self.method == "bleu": - result_score_list = [[res, get_bleu_score(query, res)] for res in results] - result_score_list.sort(key=lambda x: x[1], reverse=True) - return [res[0] for res in result_score_list][:topn] + return _bleu_rerank(query, results)[:topn] if self.method == "reranker": reranker = Rerankers().get_reranker() return reranker.get_rerank_lists(query, results, topn) @@ -119,7 +121,7 @@ def _rerank_with_vertex_degree( if self.method == "reranker": reranker = Rerankers().get_reranker() try: - vertex_degree_rerank_result = [ + vertex_rerank_res = [ reranker.get_rerank_lists(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list ] except requests.exceptions.RequestException as e: @@ -128,12 +130,7 @@ def _rerank_with_vertex_degree( self.switch_to_bleu = True if self.method == "bleu": - vertex_degree_rerank_result: List[List[str]] = [] - for vertex_degree in vertex_degree_list: - vertex_degree_score_list = [[res, get_bleu_score(query, res)] for res in vertex_degree] - vertex_degree_score_list.sort(key=lambda x: x[1], reverse=True) - vertex_degree = [res[0] for res in vertex_degree_score_list] + [""] - vertex_degree_rerank_result.append(vertex_degree) + vertex_rerank_res = [_bleu_rerank(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list] depth = len(vertex_degree_list) for result in results: @@ -143,7 +140,7 @@ def _rerank_with_vertex_degree( knowledge_with_degree[result] += [""] * (depth - len(knowledge_with_degree[result])) def sort_key(res: str) -> Tuple[int, ...]: - return tuple(vertex_degree_rerank_result[i].index(knowledge_with_degree[res][i]) for i in range(depth)) + return 
tuple(vertex_rerank_res[i].index(knowledge_with_degree[res][i]) for i in range(depth))
 
         sorted_results = sorted(results, key=sort_key)
         return sorted_results[:topn]

From 034c80d28e5e5651c8d5d9b0d2467b259e821f0c Mon Sep 17 00:00:00 2001
From: imbajin
Date: Wed, 4 Sep 2024 23:49:57 +0800
Subject: [PATCH 18/18] refact code

---
 .../src/hugegraph_llm/demo/rag_web_demo.py    |  1 +
 .../operators/llm_op/answer_synthesize.py     | 37 +++++++------------
 2 files changed, 15 insertions(+), 23 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 2414c431..20ed5b01 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -458,6 +458,7 @@ def reranker_settings(reranker_type):
 
             reranker_config_button = gr.Button("apply configuration")
 
+            # TODO: use "gr.update()" or other way to update the config in time (refactor the click event)
             # Call the separate apply_reranker_configuration function here
             reranker_config_button.click(  # pylint: disable=no-member
                 fn=apply_reranker_config,
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 60485112..2d05160b 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -91,21 +91,20 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
 
         vector_result = context.get("vector_result", [])
         if len(vector_result) == 0:
-            vector_result_context = "There are no paragraphs related to the query."
+            vector_result_context = "No (vector) phrases related to the query."
         else:
-            vector_result_context = ("The following are paragraphs related to the query:\n"
-                                     + "\n".join([f"{i + 1}. {res}"
-                                                  for i, res in enumerate(vector_result)]))
+            vector_result_context = "Phrases related to the query:\n" + "\n".join(
+                f"{i + 1}. {res}" for i, res in enumerate(vector_result)
+            )
 
         graph_result = context.get("graph_result", [])
         if len(graph_result) == 0:
-            graph_result_context = "There are no knowledge from HugeGraph related to the query."
+            graph_result_context = "No knowledge found in HugeGraph for the query."
         else:
-            graph_result_context = (
-                context.get(
-                    "graph_context_head",
-                    "The following are knowledge from HugeGraph related to the query:\n"
-                ) + "\n".join([f"{i + 1}. {res}"
-                               for i, res in enumerate(graph_result)]))
+            graph_context_head = context.get("graph_context_head",
+                                             "The following are knowledge from HugeGraph related to the query:\n")
+            graph_result_context = graph_context_head + "\n".join(
+                f"{i + 1}. 
{res}" for i, res in enumerate(graph_result) + ) context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str, vector_result_context, graph_result_context)) @@ -115,6 +114,7 @@ async def async_generate(self, context: Dict[str, Any], context_head_str: str, context_tail_str: str, vector_result_context: str, graph_result_context: str): verbose = context.get("verbose") or False + # TODO: replace task_cache with a better name task_cache = {} if self._raw_answer: prompt = self._question @@ -124,20 +124,14 @@ async def async_generate(self, context: Dict[str, Any], context_head_str: str, f"{vector_result_context}\n" f"{context_tail_str}".strip("\n")) - prompt = self._prompt_template.format( - context_str=context_str, - query_str=self._question, - ) + prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) task_cache["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt)) if self._graph_only_answer: context_str = (f"{context_head_str}\n" f"{graph_result_context}\n" f"{context_tail_str}".strip("\n")) - prompt = self._prompt_template.format( - context_str=context_str, - query_str=self._question, - ) + prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) task_cache["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt)) if self._graph_vector_answer: context_body_str = f"{vector_result_context}\n{graph_result_context}" @@ -147,10 +141,7 @@ async def async_generate(self, context: Dict[str, Any], context_head_str: str, f"{context_body_str}\n" f"{context_tail_str}".strip("\n")) - prompt = self._prompt_template.format( - context_str=context_str, - query_str=self._question, - ) + prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) task_cache["graph_vector_task"] = asyncio.create_task( self._llm.agenerate(prompt=prompt) )