From dc2a45004a38b7a8e4003a11ff81cae93e6d78a6 Mon Sep 17 00:00:00 2001 From: Magic_yuan <317617749@qq.com> Date: Sun, 8 Dec 2024 10:37:55 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E7=BC=93=E5=AD=98=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E8=BF=81=E7=A7=BB=E5=88=B0=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/utils.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index 4c8d7996..0fcb437f 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from functools import wraps from hashlib import md5 -from typing import Any, Union, List +from typing import Any, Union, List, Optional import xml.etree.ElementTree as ET import numpy as np @@ -390,3 +390,71 @@ def dequantize_embedding( """Restore quantized embedding""" scale = (max_val - min_val) / (2**bits - 1) return (quantized * scale + min_val).astype(np.float32) + +async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): + """Generic cache handling function""" + if hashing_kv is None: + return None, None, None, None + + # Get embedding cache configuration + embedding_cache_config = hashing_kv.global_config.get( + "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} + ) + is_embedding_cache_enabled = embedding_cache_config["enabled"] + + quantized = min_val = max_val = None + if is_embedding_cache_enabled: + # Use embedding cache + embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] + current_embedding = await embedding_model_func([prompt]) + quantized, min_val, max_val = quantize_embedding(current_embedding[0]) + best_cached_response = await get_best_cached_response( + hashing_kv, + current_embedding[0], + similarity_threshold=embedding_cache_config["similarity_threshold"], + mode=mode, + ) + if best_cached_response is not None: + return best_cached_response, None, None, None + else: + # Use regular cache + mode_cache = await hashing_kv.get_by_id(mode) or {} + if args_hash in mode_cache: + return mode_cache[args_hash]["return"], None, None, None + + return None, quantized, min_val, max_val + + +@dataclass +class CacheData: + args_hash: str + content: str + model: str + prompt: str + quantized: Optional[np.ndarray] = None + min_val: Optional[float] = None + max_val: Optional[float] = None + mode: str = "default" + + +async def save_to_cache(hashing_kv, cache_data: CacheData): + if hashing_kv is None: + return + + mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} + + mode_cache[cache_data.args_hash] = { + "return": cache_data.content, + "model": cache_data.model, + "embedding": cache_data.quantized.tobytes().hex() + if cache_data.quantized is not None + else None, + "embedding_shape": cache_data.quantized.shape + if cache_data.quantized is not None + else None, + "embedding_min": cache_data.min_val, + "embedding_max": cache_data.max_val, + "original_prompt": cache_data.prompt, + } + + await hashing_kv.upsert({cache_data.mode: mode_cache}) From ccf44dc334531383af060821c54a48f9075bb9da Mon Sep 17 00:00:00 2001 From: Magic_yuan <317617749@qq.com> Date: Sun, 8 Dec 2024 17:35:52 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat(cache):=20=E5=A2=9E=E5=8A=A0=20LLM=20?= =?UTF-8?q?=E7=9B=B8=E4=BC=BC=E6=80=A7=E6=A3=80=E6=9F=A5=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E5=B9=B6=E4=BC=98=E5=8C=96=E7=BC=93=E5=AD=98=E6=9C=BA=E5=88=B6?= 
=?UTF-8?q?=20-=20=E5=9C=A8=20embedding=20=E7=BC=93=E5=AD=98=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E4=B8=AD=E6=B7=BB=E5=8A=A0=20use=5Fllm=5Fcheck=20?= =?UTF-8?q?=E5=8F=82=E6=95=B0=20-=20=E5=AE=9E=E7=8E=B0=20LLM=20=E7=9B=B8?= =?UTF-8?q?=E4=BC=BC=E6=80=A7=E6=A3=80=E6=9F=A5=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E4=BD=9C=E4=B8=BA=E7=BC=93=E5=AD=98=E5=91=BD=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E4=BA=8C=E6=AC=A1=E9=AA=8C=E8=AF=81-=20=E4=BC=98=E5=8C=96=20na?= =?UTF-8?q?ive=20=E6=A8=A1=E5=BC=8F=E7=9A=84=E7=BC=93=E5=AD=98=E5=A4=84?= =?UTF-8?q?=E7=90=86=E6=B5=81=E7=A8=8B=20-=20=E8=B0=83=E6=95=B4=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E6=95=B0=E6=8D=AE=E7=BB=93=E6=9E=84=EF=BC=8C=E7=A7=BB?= =?UTF-8?q?=E9=99=A4=E4=B8=8D=E5=BF=85=E8=A6=81=E7=9A=84=20model=20?= =?UTF-8?q?=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 6 +- lightrag/lightrag.py | 9 +- lightrag/llm.py | 265 ++----------------------------------------- lightrag/operate.py | 59 ++++++++-- lightrag/prompt.py | 19 ++++ lightrag/utils.py | 55 ++++++++- 6 files changed, 138 insertions(+), 275 deletions(-) diff --git a/README.md b/README.md index a2cbb217..a1454792 100644 --- a/README.md +++ b/README.md @@ -596,11 +596,7 @@ if __name__ == "__main__": | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` | | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` | | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | -| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters: -- `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached. -- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM. - -Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` | +| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` | ## API Server Implementation diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 0a44187e..0eb1b27e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -87,7 +87,11 @@ class LightRAG: ) # Default not to use embedding cache embedding_cache_config: dict = field( - default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95} + default_factory=lambda: { + "enabled": False, + "similarity_threshold": 0.95, + "use_llm_check": False, + } ) kv_storage: str = field(default="JsonKVStorage") vector_storage: str = field(default="NanoVectorDBStorage") @@ -174,7 +178,6 @@ def __post_init__(self): if self.enable_llm_cache else None ) - self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) @@ -481,6 +484,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()): self.text_chunks, param, asdict(self), + hashing_kv=self.llm_response_cache, ) elif param.mode == "naive": response = await naive_query( @@ -489,6 +493,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()): self.text_chunks, param, asdict(self), + hashing_kv=self.llm_response_cache, ) else: raise ValueError(f"Unknown mode {param.mode}") diff --git a/lightrag/llm.py b/lightrag/llm.py index 63913c90..507753f4 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -4,8 +4,7 @@ import os import struct from functools import lru_cache -from typing import List, Dict, Callable, Any, Union, Optional -from dataclasses import dataclass +from typing import List, Dict, Callable, Any, Union import aioboto3 import aiohttp import numpy as np @@ -27,13 +26,9 @@ ) from transformers import AutoTokenizer, AutoModelForCausalLM -from .base import BaseKVStorage from .utils import ( - compute_args_hash, wrap_embedding_func_with_attrs, locate_json_string_body_from_string, - quantize_embedding, - get_best_cached_response, ) import sys @@ -66,23 +61,13 @@ async def openai_complete_if_cache( openai_async_client = ( AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) ) - + kwargs.pop("hashing_kv", None) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - if "response_format" in kwargs: response = await openai_async_client.beta.chat.completions.parse( model=model, messages=messages, **kwargs @@ -95,21 +80,6 @@ async def openai_complete_if_cache( if r"\u" in content: content = content.encode("utf-8").decode("unicode_escape") - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=content, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return content @@ -140,10 +110,7 @@ async def azure_openai_complete_if_cache( api_key=os.getenv("AZURE_OPENAI_API_KEY"), 
api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) - - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - mode = kwargs.pop("mode", "default") - + kwargs.pop("hashing_kv", None) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) @@ -151,34 +118,11 @@ async def azure_openai_complete_if_cache( if prompt is not None: messages.append({"role": "user", "content": prompt}) - # Handle cache - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - response = await openai_async_client.chat.completions.create( model=model, messages=messages, **kwargs ) content = response.choices[0].message.content - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=content, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return content @@ -210,7 +154,7 @@ async def bedrock_complete_if_cache( os.environ["AWS_SESSION_TOKEN"] = os.environ.get( "AWS_SESSION_TOKEN", aws_session_token ) - + kwargs.pop("hashing_kv", None) # Fix message history format messages = [] for history_message in history_messages: @@ -220,15 +164,6 @@ async def bedrock_complete_if_cache( # Add user prompt messages.append({"role": "user", "content": [{"text": prompt}]}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response # Initialize Converse API arguments args = {"modelId": model, "messages": messages} @@ -251,15 +186,6 @@ async def bedrock_complete_if_cache( args["inferenceConfig"][inference_params_map.get(param, param)] = ( kwargs.pop(param) ) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response # Call model via Converse API session = aioboto3.Session() @@ -269,21 +195,6 @@ async def bedrock_complete_if_cache( except Exception as e: raise BedrockError(e) - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=response["output"]["message"]["content"][0]["text"], - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return response["output"]["message"]["content"][0]["text"] @@ -315,22 +226,12 @@ async def hf_model_if_cache( ) -> str: model_name = model hf_model, hf_tokenizer = initialize_hf_model(model_name) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - + kwargs.pop("hashing_kv", None) 
input_prompt = "" try: input_prompt = hf_tokenizer.apply_chat_template( @@ -375,21 +276,6 @@ async def hf_model_if_cache( output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True ) - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=response_text, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return response_text @@ -410,25 +296,14 @@ async def ollama_model_if_cache( # kwargs.pop("response_format", None) # allow json host = kwargs.pop("host", None) timeout = kwargs.pop("timeout", None) - + kwargs.pop("hashing_kv", None) ollama_client = ollama.AsyncClient(host=host, timeout=timeout) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) - - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - response = await ollama_client.chat(model=model, messages=messages, **kwargs) if stream: """ cannot cache stream response """ @@ -441,38 +316,7 @@ async def inner(): else: result = response["message"]["content"] # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=result, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) return result - result = response["message"]["content"] - - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=result, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - - return result @lru_cache(maxsize=1) @@ -547,7 +391,7 @@ async def lmdeploy_model_if_cache( from lmdeploy import version_info, GenerationConfig except Exception: raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") - + kwargs.pop("hashing_kv", None) kwargs.pop("response_format", None) max_new_tokens = kwargs.pop("max_tokens", 512) tp = kwargs.pop("tp", 1) @@ -579,19 +423,9 @@ async def lmdeploy_model_if_cache( if system_prompt: messages.append({"role": "system", "content": system_prompt}) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - # Handle cache - mode = kwargs.pop("mode", "default") - args_hash = compute_args_hash(model, messages) - cached_response, quantized, min_val, max_val = await handle_cache( - hashing_kv, args_hash, prompt, mode - ) - if cached_response is not None: - return cached_response - gen_config = GenerationConfig( skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, @@ -607,22 +441,6 @@ async def lmdeploy_model_if_cache( session_id=1, ): response += res.response - - # Save to cache - await save_to_cache( - hashing_kv, - CacheData( - args_hash=args_hash, - content=response, - model=model, - prompt=prompt, - quantized=quantized, - min_val=min_val, - max_val=max_val, - mode=mode, - ), - ) - return response @@ -1052,75 +870,6 @@ async def llm_model_func( return await next_model.gen_func(**args) -async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): - """Generic cache handling function""" - 
if hashing_kv is None: - return None, None, None, None - - # Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - - quantized = min_val = max_val = None - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - mode=mode, - ) - if best_cached_response is not None: - return best_cached_response, None, None, None - else: - # Use regular cache - mode_cache = await hashing_kv.get_by_id(mode) or {} - if args_hash in mode_cache: - return mode_cache[args_hash]["return"], None, None, None - - return None, quantized, min_val, max_val - - -@dataclass -class CacheData: - args_hash: str - content: str - model: str - prompt: str - quantized: Optional[np.ndarray] = None - min_val: Optional[float] = None - max_val: Optional[float] = None - mode: str = "default" - - -async def save_to_cache(hashing_kv, cache_data: CacheData): - if hashing_kv is None: - return - - mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} - - mode_cache[cache_data.args_hash] = { - "return": cache_data.content, - "model": cache_data.model, - "embedding": cache_data.quantized.tobytes().hex() - if cache_data.quantized is not None - else None, - "embedding_shape": cache_data.quantized.shape - if cache_data.quantized is not None - else None, - "embedding_min": cache_data.min_val, - "embedding_max": cache_data.max_val, - "original_prompt": cache_data.prompt, - } - - await hashing_kv.upsert({cache_data.mode: mode_cache}) - - if __name__ == "__main__": import asyncio diff --git a/lightrag/operate.py b/lightrag/operate.py index acbdf072..feaec27d 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -17,6 +17,10 @@ split_string_by_multi_markers, truncate_list_by_token_size, process_combine_contexts, + compute_args_hash, + handle_cache, + save_to_cache, + CacheData, ) from .base import ( BaseGraphStorage, @@ -452,8 +456,17 @@ async def kg_query( text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, global_config: dict, + hashing_kv: BaseKVStorage = None, ) -> str: - context = None + # Handle cache + use_model_func = global_config["llm_model_func"] + args_hash = compute_args_hash(query_param.mode, query) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, query, query_param.mode + ) + if cached_response is not None: + return cached_response + example_number = global_config["addon_params"].get("example_number", None) if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): examples = "\n".join( @@ -471,12 +484,9 @@ async def kg_query( return PROMPTS["fail_response"] # LLM generate keywords - use_model_func = global_config["llm_model_func"] kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language) - result = await use_model_func( - kw_prompt, keyword_extraction=True, mode=query_param.mode - ) + result = await use_model_func(kw_prompt, keyword_extraction=True) logger.info("kw_prompt result:") print(result) try: @@ 
-537,7 +547,6 @@ async def kg_query( query, system_prompt=sys_prompt, stream=query_param.stream, - mode=query_param.mode, ) if isinstance(response, str) and len(response) > len(sys_prompt): response = ( @@ -550,6 +559,20 @@ async def kg_query( .strip() ) + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=query_param.mode, + ), + ) + return response @@ -1013,8 +1036,17 @@ async def naive_query( text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, global_config: dict, + hashing_kv: BaseKVStorage = None, ): + # Handle cache use_model_func = global_config["llm_model_func"] + args_hash = compute_args_hash(query_param.mode, query) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, query, query_param.mode + ) + if cached_response is not None: + return cached_response + results = await chunks_vdb.query(query, top_k=query_param.top_k) if not len(results): return PROMPTS["fail_response"] @@ -1039,7 +1071,6 @@ async def naive_query( response = await use_model_func( query, system_prompt=sys_prompt, - mode=query_param.mode, ) if len(response) > len(sys_prompt): @@ -1054,4 +1085,18 @@ async def naive_query( .strip() ) + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=query_param.mode, + ), + ) + return response diff --git a/lightrag/prompt.py b/lightrag/prompt.py index d758397b..863d38dc 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -261,3 +261,22 @@ Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. """ + +PROMPTS[ + "similarity_check" +] = """Please analyze the similarity between these two questions: + +Question 1: {original_prompt} +Question 2: {cached_prompt} + +Please evaluate: +1. Whether these two questions are semantically similar +2. Whether the answer to Question 2 can be used to answer Question 1 + +Please provide a similarity score between 0 and 1, where: +0: Completely unrelated or answer cannot be reused +1: Identical and answer can be directly reused +0.5: Partially related and answer needs modification to be used + +Return only a number between 0-1, without any additional content. 
+""" diff --git a/lightrag/utils.py b/lightrag/utils.py index 0fcb437f..32d5c87f 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -15,6 +15,8 @@ import numpy as np import tiktoken +from lightrag.prompt import PROMPTS + ENCODER = None logger = logging.getLogger("lightrag") @@ -314,6 +316,9 @@ async def get_best_cached_response( current_embedding, similarity_threshold=0.95, mode="default", + use_llm_check=False, + llm_func=None, + original_prompt=None, ) -> Union[str, None]: # Get mode-specific cache mode_cache = await hashing_kv.get_by_id(mode) @@ -348,6 +353,37 @@ async def get_best_cached_response( best_cache_id = cache_id if best_similarity > similarity_threshold: + # If LLM check is enabled and all required parameters are provided + if use_llm_check and llm_func and original_prompt and best_prompt: + compare_prompt = PROMPTS["similarity_check"].format( + original_prompt=original_prompt, cached_prompt=best_prompt + ) + + try: + llm_result = await llm_func(compare_prompt) + llm_result = llm_result.strip() + llm_similarity = float(llm_result) + + # Replace vector similarity with LLM similarity score + best_similarity = llm_similarity + if best_similarity < similarity_threshold: + log_data = { + "event": "llm_check_cache_rejected", + "original_question": original_prompt[:100] + "..." + if len(original_prompt) > 100 + else original_prompt, + "cached_question": best_prompt[:100] + "..." + if len(best_prompt) > 100 + else best_prompt, + "similarity_score": round(best_similarity, 4), + "threshold": similarity_threshold, + } + logger.info(json.dumps(log_data, ensure_ascii=False)) + return None + except Exception as e: # Catch all possible exceptions + logger.warning(f"LLM similarity check failed: {e}") + return None # Return None directly when LLM check fails + prompt_display = ( best_prompt[:50] + "..." 
if len(best_prompt) > 50 else best_prompt ) @@ -391,21 +427,33 @@ def dequantize_embedding( scale = (max_val - min_val) / (2**bits - 1) return (quantized * scale + min_val).astype(np.float32) + async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): """Generic cache handling function""" if hashing_kv is None: return None, None, None, None + # For naive mode, only use simple cache matching + if mode == "naive": + mode_cache = await hashing_kv.get_by_id(mode) or {} + if args_hash in mode_cache: + return mode_cache[args_hash]["return"], None, None, None + return None, None, None, None + # Get embedding cache configuration embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} + "embedding_cache_config", + {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}, ) is_embedding_cache_enabled = embedding_cache_config["enabled"] + use_llm_check = embedding_cache_config.get("use_llm_check", False) quantized = min_val = max_val = None if is_embedding_cache_enabled: # Use embedding cache embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] + llm_model_func = hashing_kv.global_config.get("llm_model_func") + current_embedding = await embedding_model_func([prompt]) quantized, min_val, max_val = quantize_embedding(current_embedding[0]) best_cached_response = await get_best_cached_response( @@ -413,6 +461,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): current_embedding[0], similarity_threshold=embedding_cache_config["similarity_threshold"], mode=mode, + use_llm_check=use_llm_check, + llm_func=llm_model_func if use_llm_check else None, + original_prompt=prompt if use_llm_check else None, ) if best_cached_response is not None: return best_cached_response, None, None, None @@ -429,7 +480,6 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): class CacheData: args_hash: str content: str - model: str prompt: str quantized: Optional[np.ndarray] = None min_val: Optional[float] = None @@ -445,7 +495,6 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): mode_cache[cache_data.args_hash] = { "return": cache_data.content, - "model": cache_data.model, "embedding": cache_data.quantized.tobytes().hex() if cache_data.quantized is not None else None, From 39c2cb11f305efe1b6e6fd3f9ea07aa268eafa56 Mon Sep 17 00:00:00 2001 From: Magic_yuan <317617749@qq.com> Date: Sun, 8 Dec 2024 17:37:58 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E6=B8=85=E7=90=86=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index 507753f4..9c94fe6c 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -315,7 +315,6 @@ async def inner(): return inner() else: result = response["message"]["content"] - # Save to cache return result From 779ed604d881d262196148893c3765af9c7e19f1 Mon Sep 17 00:00:00 2001 From: Magic_yuan <317617749@qq.com> Date: Sun, 8 Dec 2024 17:38:49 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E6=B8=85=E7=90=86=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/llm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index 9c94fe6c..b2bb99b7 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -314,8 +314,7 @@ async 
def inner(): return inner() else: - result = response["message"]["content"] - return result + return response["message"]["content"] @lru_cache(maxsize=1)
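
---

Note on the combined effect of this series (illustrative, not part of the patches): query-time caching is now consolidated in lightrag/utils.py. handle_cache() runs before the LLM call (an exact args-hash lookup for "naive" mode, otherwise an embedding-similarity lookup that can be re-verified via the new PROMPTS["similarity_check"] when use_llm_check is enabled), and save_to_cache() persists the answer together with the quantized prompt embedding. A minimal sketch of that flow, mirroring the post-patch wiring in kg_query() and naive_query(); cached_llm_call, hashing_kv, and llm_func are placeholder names introduced here for illustration only:

    from lightrag.utils import CacheData, compute_args_hash, handle_cache, save_to_cache

    async def cached_llm_call(hashing_kv, llm_func, query: str, mode: str = "default") -> str:
        """Illustrative wrapper, not part of the patches."""
        args_hash = compute_args_hash(mode, query)

        # 1. Cache lookup: exact args-hash match for "naive" mode, otherwise an
        #    embedding-similarity lookup (optionally re-checked by the LLM when
        #    embedding_cache_config["use_llm_check"] is enabled).
        cached, quantized, min_val, max_val = await handle_cache(hashing_kv, args_hash, query, mode)
        if cached is not None:
            return cached

        # 2. Cache miss: call the model, then persist the answer plus the quantized
        #    prompt embedding so similar future prompts can be served from cache.
        response = await llm_func(query)
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=response,
                prompt=query,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode=mode,
            ),
        )
        return response

Cache lookups are controlled by the embedding_cache_config field documented in the README change above. A hedged configuration sketch, assuming the default LLM/embedding functions are available and using a placeholder working directory:

    from lightrag import LightRAG

    rag = LightRAG(
        working_dir="./rag_storage",  # placeholder path, assumption
        embedding_cache_config={
            "enabled": True,               # look up cached answers before generating
            "similarity_threshold": 0.95,  # minimum cosine similarity for a hit
            "use_llm_check": True,         # LLM re-verifies the match before reuse
        },
    )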