From cb6e9ce1645c0a0e36875d94c15b1f6061c4ebc3 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Tue, 17 Dec 2024 09:48:03 +0800 Subject: [PATCH] Cache the result from llm for graphrag and raptor (#4051) ### What problem does this PR solve? #4045 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/db/services/task_service.py | 8 +++- graphrag/__init__.py | 0 graphrag/claim_extractor.py | 10 ++--- graphrag/community_reports_extractor.py | 6 +-- graphrag/description_summary.py | 6 +-- graphrag/entity_resolution.py | 7 ++-- graphrag/extractor.py | 34 ++++++++++++++++ graphrag/graph_extractor.py | 15 +++---- graphrag/mind_map_extractor.py | 6 +-- graphrag/utils.py | 52 +++++++++++++++++++++++++ rag/raptor.py | 28 +++++++++++-- rag/svr/task_executor.py | 27 ++++++++++--- 12 files changed, 161 insertions(+), 38 deletions(-) create mode 100644 graphrag/__init__.py create mode 100644 graphrag/extractor.py diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index a39a4e84f86..4f4ec30ec8a 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -271,7 +271,7 @@ def new_task(): def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict): - idx = bisect.bisect_left(prev_tasks, task["from_page"], key=lambda x: x["from_page"]) + idx = bisect.bisect_left(prev_tasks, task.get("from_page", 0), key=lambda x: x.get("from_page",0)) if idx >= len(prev_tasks): return 0 prev_task = prev_tasks[idx] @@ -279,7 +279,11 @@ def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: return 0 task["chunk_ids"] = prev_task["chunk_ids"] task["progress"] = 1.0 - task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): reused previous task's chunks" + if "from_page" in task and "to_page" in task: + task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): " + else: + task["progress_msg"] = "" + task["progress_msg"] += "reused previous task's chunks." 
prev_task["chunk_ids"] = "" return len(task["chunk_ids"].split()) \ No newline at end of file diff --git a/graphrag/__init__.py b/graphrag/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/graphrag/claim_extractor.py b/graphrag/claim_extractor.py index b202e50a221..8b5c8d65def 100644 --- a/graphrag/claim_extractor.py +++ b/graphrag/claim_extractor.py @@ -16,6 +16,7 @@ import tiktoken from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT +from graphrag.extractor import Extractor from rag.llm.chat_model import Base as CompletionLLM from graphrag.utils import ErrorHandlerFn, perform_variable_replacements @@ -33,10 +34,9 @@ class ClaimExtractorResult: source_docs: dict[str, Any] -class ClaimExtractor: +class ClaimExtractor(Extractor): """Claim extractor class definition.""" - _llm: CompletionLLM _extraction_prompt: str _summary_prompt: str _output_formatter_prompt: str @@ -169,7 +169,7 @@ def _process_document( } text = perform_variable_replacements(self._extraction_prompt, variables=variables) gen_conf = {"temperature": 0.5} - results = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf) + results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) claims = results.strip().removesuffix(completion_delimiter) history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}] @@ -177,7 +177,7 @@ def _process_document( for i in range(self._max_gleanings): text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables) history.append({"role": "user", "content": text}) - extension = self._llm.chat("", history, gen_conf) + extension = self._chat("", history, gen_conf) claims += record_delimiter + extension.strip().removesuffix( completion_delimiter ) @@ -188,7 +188,7 @@ def _process_document( history.append({"role": "assistant", "content": extension}) history.append({"role": "user", "content": LOOP_PROMPT}) - continuation = self._llm.chat("", history, self._loop_args) + continuation = self._chat("", history, self._loop_args) if continuation != "YES": break diff --git a/graphrag/community_reports_extractor.py b/graphrag/community_reports_extractor.py index 756a7811eb9..19ed994d263 100644 --- a/graphrag/community_reports_extractor.py +++ b/graphrag/community_reports_extractor.py @@ -15,6 +15,7 @@ import pandas as pd from graphrag import leiden from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT +from graphrag.extractor import Extractor from graphrag.leiden import add_community_info2graph from rag.llm.chat_model import Base as CompletionLLM from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types @@ -30,10 +31,9 @@ class CommunityReportsResult: structured_output: list[dict] -class CommunityReportsExtractor: +class CommunityReportsExtractor(Extractor): """Community reports extractor class definition.""" - _llm: CompletionLLM _extraction_prompt: str _output_formatter_prompt: str _on_error: ErrorHandlerFn @@ -74,7 +74,7 @@ def __call__(self, graph: nx.Graph, callback: Callable | None = None): text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables) gen_conf = {"temperature": 0.3} try: - response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf) + response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) token_count += num_tokens_from_string(text + response) response = re.sub(r"^[^\{]*", "", response) response = 
re.sub(r"[^\}]*$", "", response) diff --git a/graphrag/description_summary.py b/graphrag/description_summary.py index f5537c95447..e226f26fb05 100644 --- a/graphrag/description_summary.py +++ b/graphrag/description_summary.py @@ -8,6 +8,7 @@ import json from dataclasses import dataclass +from graphrag.extractor import Extractor from graphrag.utils import ErrorHandlerFn, perform_variable_replacements from rag.llm.chat_model import Base as CompletionLLM @@ -42,10 +43,9 @@ class SummarizationResult: description: str -class SummarizeExtractor: +class SummarizeExtractor(Extractor): """Unipartite graph extractor class definition.""" - _llm: CompletionLLM _entity_name_key: str _input_descriptions_key: str _summarization_prompt: str @@ -143,4 +143,4 @@ def _summarize_descriptions_with_llm( self._input_descriptions_key: json.dumps(sorted(descriptions)), } text = perform_variable_replacements(self._summarization_prompt, variables=variables) - return self._llm.chat("", [{"role": "user", "content": text}]) + return self._chat("", [{"role": "user", "content": text}]) diff --git a/graphrag/entity_resolution.py b/graphrag/entity_resolution.py index 1c8a4b4e0fe..78295823480 100644 --- a/graphrag/entity_resolution.py +++ b/graphrag/entity_resolution.py @@ -21,6 +21,8 @@ from typing import Any import networkx as nx + +from graphrag.extractor import Extractor from rag.nlp import is_english import editdistance from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT @@ -39,10 +41,9 @@ class EntityResolutionResult: output: nx.Graph -class EntityResolution: +class EntityResolution(Extractor): """Entity resolution class definition.""" - _llm: CompletionLLM _resolution_prompt: str _output_formatter_prompt: str _on_error: ErrorHandlerFn @@ -117,7 +118,7 @@ def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = No } text = perform_variable_replacements(self._resolution_prompt, variables=variables) - response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf) + response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) result = self._process_results(len(candidate_resolution_i[1]), response, prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER), diff --git a/graphrag/extractor.py b/graphrag/extractor.py new file mode 100644 index 00000000000..552225c6447 --- /dev/null +++ b/graphrag/extractor.py @@ -0,0 +1,34 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from graphrag.utils import get_llm_cache, set_llm_cache +from rag.llm.chat_model import Base as CompletionLLM + + +class Extractor: + _llm: CompletionLLM + + def __init__(self, llm_invoker: CompletionLLM): + self._llm = llm_invoker + + def _chat(self, system, history, gen_conf): + response = get_llm_cache(self._llm.llm_name, system, history, gen_conf) + if response: + return response + response = self._llm.chat(system, history, gen_conf) + if response.find("**ERROR**") >= 0: + raise Exception(response) + set_llm_cache(self._llm.llm_name, system, response, history, gen_conf) + return response diff --git a/graphrag/graph_extractor.py b/graphrag/graph_extractor.py index 290390ac9c4..e20d22c910e 100644 --- a/graphrag/graph_extractor.py +++ b/graphrag/graph_extractor.py @@ -12,6 +12,8 @@ from typing import Any, Callable, Mapping from dataclasses import dataclass import tiktoken + +from graphrag.extractor import Extractor from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str from rag.llm.chat_model import Base as CompletionLLM @@ -34,10 +36,9 @@ class GraphExtractionResult: source_docs: dict[Any, Any] -class GraphExtractor: +class GraphExtractor(Extractor): """Unipartite graph extractor class definition.""" - _llm: CompletionLLM _join_descriptions: bool _tuple_delimiter_key: str _record_delimiter_key: str @@ -165,9 +166,7 @@ def _process_document( token_count = 0 text = perform_variable_replacements(self._extraction_prompt, variables=variables) gen_conf = {"temperature": 0.3} - response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf) - if response.find("**ERROR**") >= 0: - raise Exception(response) + response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) token_count = num_tokens_from_string(text + response) results = response or "" @@ -177,9 +176,7 @@ def _process_document( for i in range(self._max_gleanings): text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables) history.append({"role": "user", "content": text}) - response = self._llm.chat("", history, gen_conf) - if response.find("**ERROR**") >=0: - raise Exception(response) + response = self._chat("", history, gen_conf) results += response or "" # if this is the final glean, don't bother updating the continuation flag @@ -187,7 +184,7 @@ def _process_document( break history.append({"role": "assistant", "content": response}) history.append({"role": "user", "content": LOOP_PROMPT}) - continuation = self._llm.chat("", history, self._loop_args) + continuation = self._chat("", history, self._loop_args) if continuation != "YES": break diff --git a/graphrag/mind_map_extractor.py b/graphrag/mind_map_extractor.py index 74d396dbf55..f3c427d8dbf 100644 --- a/graphrag/mind_map_extractor.py +++ b/graphrag/mind_map_extractor.py @@ -23,6 +23,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from graphrag.extractor import Extractor from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT from graphrag.utils import ErrorHandlerFn, perform_variable_replacements from rag.llm.chat_model import Base as CompletionLLM @@ -37,8 +38,7 @@ class MindMapResult: output: dict -class MindMapExtractor: - _llm: CompletionLLM +class MindMapExtractor(Extractor): _input_text_key: str _mind_map_prompt: str _on_error: ErrorHandlerFn @@ -190,7 +190,7 @@ def _process_document( } text = 
perform_variable_replacements(self._mind_map_prompt, variables=variables) gen_conf = {"temperature": 0.5} - response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf) + response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf) response = re.sub(r"```[^\n]*", "", response) logging.debug(response) logging.debug(self._todict(markdown_to_json.dictify(response))) diff --git a/graphrag/utils.py b/graphrag/utils.py index 3a8c5253f5f..bed0dcdae74 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -6,9 +6,15 @@ """ import html +import json import re from typing import Any, Callable +import numpy as np +import xxhash + +from rag.utils.redis_conn import REDIS_CONN + ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] @@ -60,3 +66,49 @@ def dict_has_keys_with_types( return False return True + +def get_llm_cache(llmnm, txt, history, genconf): + hasher = xxhash.xxh64() + hasher.update(str(llmnm).encode("utf-8")) + hasher.update(str(txt).encode("utf-8")) + hasher.update(str(history).encode("utf-8")) + hasher.update(str(genconf).encode("utf-8")) + + k = hasher.hexdigest() + bin = REDIS_CONN.get(k) + if not bin: + return + return bin.decode("utf-8") + + +def set_llm_cache(llmnm, txt, v: str, history, genconf): + hasher = xxhash.xxh64() + hasher.update(str(llmnm).encode("utf-8")) + hasher.update(str(txt).encode("utf-8")) + hasher.update(str(history).encode("utf-8")) + hasher.update(str(genconf).encode("utf-8")) + + k = hasher.hexdigest() + REDIS_CONN.set(k, v.encode("utf-8"), 24*3600) + + +def get_embed_cache(llmnm, txt): + hasher = xxhash.xxh64() + hasher.update(str(llmnm).encode("utf-8")) + hasher.update(str(txt).encode("utf-8")) + + k = hasher.hexdigest() + bin = REDIS_CONN.get(k) + if not bin: + return + return np.array(json.loads(bin.decode("utf-8"))) + + +def set_embed_cache(llmnm, txt, arr): + hasher = xxhash.xxh64() + hasher.update(str(llmnm).encode("utf-8")) + hasher.update(str(txt).encode("utf-8")) + + k = hasher.hexdigest() + arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr) + REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600) \ No newline at end of file diff --git a/rag/raptor.py b/rag/raptor.py index 6e11cf683e5..5da3d0864b3 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -21,6 +21,7 @@ import numpy as np from sklearn.mixture import GaussianMixture +from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache from rag.utils import truncate @@ -33,6 +34,27 @@ def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, th self._prompt = prompt self._max_token = max_token + def _chat(self, system, history, gen_conf): + response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) + if response: + return response + response = self._llm_model.chat(system, history, gen_conf) + if response.find("**ERROR**") >= 0: + raise Exception(response) + set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf) + return response + + def _embedding_encode(self, txt): + response = get_embed_cache(self._embd_model.llm_name, txt) + if response: + return response + embds, _ = self._embd_model.encode([txt]) + if len(embds) < 1 or len(embds[0]) < 1: + raise Exception("Embedding error: ") + embds = embds[0] + set_embed_cache(self._embd_model.llm_name, txt, embds) + return embds + def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int): max_clusters = min(self._max_cluster, len(embeddings)) n_clusters = np.arange(1, 
max_clusters) @@ -57,7 +79,7 @@ def summarize(ck_idx, lock): texts = [chunks[i][0] for i in ck_idx] len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) - cnt = self._llm_model.chat("You're a helpful assistant.", + cnt = self._chat("You're a helpful assistant.", [{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}], {"temperature": 0.3, "max_tokens": self._max_token} @@ -67,9 +89,7 @@ def summarize(ck_idx, lock): logging.debug(f"SUM: {cnt}") embds, _ = self._embd_model.encode([cnt]) with lock: - if not len(embds[0]): - return - chunks.append((cnt, embds[0])) + chunks.append((cnt, self._embedding_encode(cnt))) except Exception as e: logging.exception("summarize got exception") return e diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 2b893cf5687..58c0e7a697d 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -19,6 +19,8 @@ import sys from api.utils.log_utils import initRootLogger +from graphrag.utils import get_llm_cache, set_llm_cache + CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] CONSUMER_NAME = "task_executor_" + CONSUMER_NO initRootLogger(CONSUMER_NAME) @@ -232,9 +234,6 @@ def build_chunks(task, progress_callback): if not d.get("image"): _ = d.pop("image", None) d["img_id"] = "" - d["page_num_int"] = [] - d["position_int"] = [] - d["top_int"] = [] docs.append(d) continue @@ -262,8 +261,16 @@ def build_chunks(task, progress_callback): progress_callback(msg="Start to generate keywords for every chunk ...") chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"]) for d in docs: - d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"], - task["parser_config"]["auto_keywords"]).split(",") + cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", + {"topn": task["parser_config"]["auto_keywords"]}) + if not cached: + cached = keyword_extraction(chat_mdl, d["content_with_weight"], + task["parser_config"]["auto_keywords"]) + if cached: + set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", + {"topn": task["parser_config"]["auto_keywords"]}) + + d["important_kwd"] = cached.split(",") d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) progress_callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st)) @@ -272,7 +279,15 @@ def build_chunks(task, progress_callback): progress_callback(msg="Start to generate questions for every chunk ...") chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"]) for d in docs: - d["question_kwd"] = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"]).split("\n") + cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", + {"topn": task["parser_config"]["auto_questions"]}) + if not cached: + cached = question_proposal(chat_mdl, d["content_with_weight"], task["parser_config"]["auto_questions"]) + if cached: + set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", + {"topn": task["parser_config"]["auto_questions"]}) + + d["question_kwd"] = cached.split("\n") d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
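
The core of this patch is a Redis-backed LLM cache: `graphrag/utils.py` derives a key with `xxhash.xxh64` over the model name, prompt, chat history and generation config, stores responses with a 24-hour TTL, and the new `Extractor._chat` (and `RecursiveAbstractiveProcessing4TreeOrganizedRetrieval._chat` in `rag/raptor.py`) consults that cache before calling the model. Below is a minimal standalone sketch of that flow, not part of the patch: it assumes an in-memory dict in place of `REDIS_CONN` and a stub `fake_chat` in place of a real `CompletionLLM`, so it can run without the repository; both names are hypothetical.

```python
import xxhash

# Stand-in for the Redis store; the real helpers set a 24h expiry on each key.
CACHE: dict[str, str] = {}


def _cache_key(llm_name: str, system: str, history, gen_conf) -> str:
    # Key = xxhash64 over the same fields hashed in graphrag/utils.py:
    # model name, system prompt, chat history, generation config.
    hasher = xxhash.xxh64()
    for part in (llm_name, system, history, gen_conf):
        hasher.update(str(part).encode("utf-8"))
    return hasher.hexdigest()


def fake_chat(system: str, history, gen_conf) -> str:
    # Hypothetical stub; the patched Extractor._chat delegates to CompletionLLM.chat().
    return "stub response for: " + system[:30]


def cached_chat(llm_name: str, system: str, history, gen_conf) -> str:
    # Same control flow as the new Extractor._chat: return a cache hit if present,
    # otherwise call the model, raise on an **ERROR** marker, then store the result.
    key = _cache_key(llm_name, system, history, gen_conf)
    if key in CACHE:
        return CACHE[key]
    response = fake_chat(system, history, gen_conf)
    if "**ERROR**" in response:
        raise Exception(response)
    CACHE[key] = response
    return response


if __name__ == "__main__":
    history = [{"role": "user", "content": "Output:"}]
    gen_conf = {"temperature": 0.3}
    first = cached_chat("some-llm", "Extract entities ...", history, gen_conf)
    second = cached_chat("some-llm", "Extract entities ...", history, gen_conf)  # served from cache
    assert first == second
```

Because the generation config is part of the hash, the same prompt issued with a different temperature or `max_tokens` produces a different key, so re-runs of graphrag/raptor tasks only reuse responses generated under identical settings.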