From 584258078f542fc2c4c4ff1ade8cb5342d80989e Mon Sep 17 00:00:00 2001
From: yuanxiaobin <xiaobin.yuan@genscript.com>
Date: Fri, 6 Dec 2024 14:29:16 +0800
Subject: [PATCH 1/5] Refactor cache handling logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Extract the shared cache handling logic into new handle_cache and save_to_cache functions (a usage sketch follows below)
- Unify the cached record structure with a CacheData dataclass
- Streamline the processing flow for the embedding cache and the regular cache
- Add a mode parameter to support per-query-mode cache strategies
- Refactor get_best_cached_response to improve cache lookup efficiency
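
A minimal sketch of the call pattern this refactor introduces, using the names
from the diff below (the wrapper itself and its call_llm argument are
illustrative, and compute_args_hash is assumed to be importable from
lightrag.utils):

    from lightrag.llm import CacheData, handle_cache, save_to_cache
    from lightrag.utils import compute_args_hash

    async def complete_with_cache(model, prompt, messages, hashing_kv,
                                  mode="default", call_llm=None):
        # Probe the cache first: handle_cache returns either a cached response
        # or the quantized prompt embedding needed to store a new entry later.
        args_hash = compute_args_hash(model, messages)
        cached, quantized, min_val, max_val = await handle_cache(
            hashing_kv, args_hash, prompt, mode
        )
        if cached is not None:
            return cached

        content = await call_llm(model, messages)  # the real model call, stubbed here

        # Store the fresh response under the mode-specific cache entry.
        await save_to_cache(
            hashing_kv,
            CacheData(args_hash=args_hash, content=content, model=model,
                      prompt=prompt, quantized=quantized, min_val=min_val,
                      max_val=max_val, mode=mode),
        )
        return content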
---
 lightrag/llm.py     | 496 ++++++++++++++++++++------------------------
 lightrag/operate.py |   6 +-
 lightrag/utils.py   |  84 ++++----
 3 files changed, 277 insertions(+), 309 deletions(-)

diff --git a/lightrag/llm.py b/lightrag/llm.py
index fef8c9a31..89d74a5b7 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -4,7 +4,8 @@
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any
+from typing import List, Dict, Callable, Any, Optional
+from dataclasses import dataclass
 
 import aioboto3
 import aiohttp
@@ -59,39 +60,21 @@ async def openai_complete_if_cache(
     openai_async_client = (
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
 
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
 
     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
@@ -105,24 +88,21 @@ async def openai_complete_if_cache(
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")
 
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        kwargs.get("hashing_kv"),
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
     return content
 
 
@@ -155,6 +135,8 @@ async def azure_openai_complete_if_cache(
     )
 
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    mode = kwargs.pop("mode", "default")
+
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -162,56 +144,35 @@ async def azure_openai_complete_if_cache(
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
 
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
 
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
+    content = response.choices[0].message.content
+
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
 
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response.choices[0].message.content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
-    return response.choices[0].message.content
+    return content
 
 
 class BedrockError(Exception):
@@ -253,6 +214,15 @@ async def bedrock_complete_if_cache(
     # Add user prompt
     messages.append({"role": "user", "content": [{"text": prompt}]})
 
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
+
     # Initialize Converse API arguments
     args = {"modelId": model, "messages": messages}
 
@@ -275,33 +245,14 @@ async def bedrock_complete_if_cache(
                 kwargs.pop(param)
             )
 
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
 
     # Call model via Converse API
     session = aioboto3.Session()
@@ -311,30 +262,22 @@ async def bedrock_complete_if_cache(
         except Exception as e:
             raise BedrockError(e)
 
-        if hashing_kv is not None:
-            await hashing_kv.upsert(
-                {
-                    args_hash: {
-                        "return": response["output"]["message"]["content"][0]["text"],
-                        "model": model,
-                        "embedding": quantized.tobytes().hex()
-                        if is_embedding_cache_enabled
-                        else None,
-                        "embedding_shape": quantized.shape
-                        if is_embedding_cache_enabled
-                        else None,
-                        "embedding_min": min_val
-                        if is_embedding_cache_enabled
-                        else None,
-                        "embedding_max": max_val
-                        if is_embedding_cache_enabled
-                        else None,
-                        "original_prompt": prompt,
-                    }
-                }
-            )
+    # Save to cache
+    await save_to_cache(
+        kwargs.get("hashing_kv"),
+        CacheData(
+            args_hash=args_hash,
+            content=response["output"]["message"]["content"][0]["text"],
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
 
-        return response["output"]["message"]["content"][0]["text"]
+    return response["output"]["message"]["content"][0]["text"]
 
 
 @lru_cache(maxsize=1)
@@ -372,32 +315,14 @@ async def hf_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
 
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
 
     input_prompt = ""
     try:
@@ -442,24 +367,22 @@ async def hf_model_if_cache(
     response_text = hf_tokenizer.decode(
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response_text,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response_text,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
     return response_text
 
 
@@ -489,55 +412,34 @@ async def ollama_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
 
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
 
     result = response["message"]["content"]
 
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": result,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=result,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
     return result
 
 
@@ -649,32 +551,14 @@ async def lmdeploy_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
 
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
 
     gen_config = GenerationConfig(
         skip_special_tokens=skip_special_tokens,
@@ -692,24 +576,21 @@ async def lmdeploy_model_if_cache(
     ):
         response += res.response
 
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
     return response
 
 
@@ -1139,6 +1020,75 @@ async def llm_model_func(
         return await next_model.gen_func(**args)
 
 
+async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
+    """Generic cache handling function"""
+    if hashing_kv is None:
+        return None, None, None, None
+
+    # Get embedding cache configuration
+    embedding_cache_config = hashing_kv.global_config.get(
+        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+    )
+    is_embedding_cache_enabled = embedding_cache_config["enabled"]
+
+    quantized = min_val = max_val = None
+    if is_embedding_cache_enabled:
+        # Use embedding cache
+        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+        current_embedding = await embedding_model_func([prompt])
+        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+        best_cached_response = await get_best_cached_response(
+            hashing_kv,
+            current_embedding[0],
+            similarity_threshold=embedding_cache_config["similarity_threshold"],
+            mode=mode,
+        )
+        if best_cached_response is not None:
+            return best_cached_response, None, None, None
+    else:
+        # Use regular cache
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if args_hash in mode_cache:
+            return mode_cache[args_hash]["return"], None, None, None
+
+    return None, quantized, min_val, max_val
+
+
+@dataclass
+class CacheData:
+    args_hash: str
+    content: str
+    model: str
+    prompt: str
+    quantized: Optional[np.ndarray] = None
+    min_val: Optional[float] = None
+    max_val: Optional[float] = None
+    mode: str = "default"
+
+
+async def save_to_cache(hashing_kv, cache_data: CacheData):
+    if hashing_kv is None:
+        return
+
+    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+
+    mode_cache[cache_data.args_hash] = {
+        "return": cache_data.content,
+        "model": cache_data.model,
+        "embedding": cache_data.quantized.tobytes().hex()
+        if cache_data.quantized is not None
+        else None,
+        "embedding_shape": cache_data.quantized.shape
+        if cache_data.quantized is not None
+        else None,
+        "embedding_min": cache_data.min_val,
+        "embedding_max": cache_data.max_val,
+        "original_prompt": cache_data.prompt,
+    }
+
+    await hashing_kv.upsert({cache_data.mode: mode_cache})
+
+
 if __name__ == "__main__":
     import asyncio
 
diff --git a/lightrag/operate.py b/lightrag/operate.py
index a846cfc58..5b911d342 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -474,7 +474,9 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt, keyword_extraction=True)
+    result = await use_model_func(
+        kw_prompt, keyword_extraction=True, mode=query_param.mode
+    )
     logger.info("kw_prompt result:")
     print(result)
     try:
@@ -534,6 +536,7 @@ async def kg_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        mode=query_param.mode,
     )
     if len(response) > len(sys_prompt):
         response = (
@@ -1035,6 +1038,7 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        mode=query_param.mode,
     )
 
     if len(response) > len(sys_prompt):
diff --git a/lightrag/utils.py b/lightrag/utils.py
index d080ee03d..70ec4341c 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -310,43 +310,57 @@ def process_combine_contexts(hl, ll):
 
 
 async def get_best_cached_response(
-    hashing_kv, current_embedding, similarity_threshold=0.95
-):
-    """Get the cached response with the highest similarity"""
-    try:
-        # Get all keys
-        all_keys = await hashing_kv.all_keys()
-        max_similarity = 0
-        best_cached_response = None
-
-        # Get cached data one by one
-        for key in all_keys:
-            cache_data = await hashing_kv.get_by_id(key)
-            if cache_data is None or "embedding" not in cache_data:
-                continue
-
-            # Convert cached embedding list to ndarray
-            cached_quantized = np.frombuffer(
-                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
-            ).reshape(cache_data["embedding_shape"])
-            cached_embedding = dequantize_embedding(
-                cached_quantized,
-                cache_data["embedding_min"],
-                cache_data["embedding_max"],
-            )
-
-            similarity = cosine_similarity(current_embedding, cached_embedding)
-            if similarity > max_similarity:
-                max_similarity = similarity
-                best_cached_response = cache_data["return"]
-
-        if max_similarity > similarity_threshold:
-            return best_cached_response
+    hashing_kv,
+    current_embedding,
+    similarity_threshold=0.95,
+    mode="default",
+) -> Union[str, None]:
+    # Get mode-specific cache
+    mode_cache = await hashing_kv.get_by_id(mode)
+    if not mode_cache:
         return None
 
-    except Exception as e:
-        logger.warning(f"Error in get_best_cached_response: {e}")
-        return None
+    best_similarity = -1
+    best_response = None
+    best_prompt = None
+    best_cache_id = None
+
+    # Only iterate through cache entries for this mode
+    for cache_id, cache_data in mode_cache.items():
+        if cache_data["embedding"] is None:
+            continue
+
+        # Convert cached embedding list to ndarray
+        cached_quantized = np.frombuffer(
+            bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
+        ).reshape(cache_data["embedding_shape"])
+        cached_embedding = dequantize_embedding(
+            cached_quantized,
+            cache_data["embedding_min"],
+            cache_data["embedding_max"],
+        )
+
+        similarity = cosine_similarity(current_embedding, cached_embedding)
+        if similarity > best_similarity:
+            best_similarity = similarity
+            best_response = cache_data["return"]
+            best_prompt = cache_data["original_prompt"]
+            best_cache_id = cache_id
+
+    if best_similarity > similarity_threshold:
+        prompt_display = (
+            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
+        )
+        log_data = {
+            "event": "cache_hit",
+            "mode": mode,
+            "similarity": round(best_similarity, 4),
+            "cache_id": best_cache_id,
+            "original_prompt": prompt_display,
+        }
+        logger.info(json.dumps(log_data))
+        return best_response
+    return None
 
 
 def cosine_similarity(v1, v2):

From 558068f61171c4e691e1bb12808c7d1231ddc628 Mon Sep 17 00:00:00 2001
From: yuanxiaobin <xiaobin.yuan@genscript.com>
Date: Fri, 6 Dec 2024 14:32:41 +0800
Subject: [PATCH 2/5] fix(utils): fix JSON log encoding issue
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add ensure_ascii=False to json.dumps so non-ASCII characters are encoded as-is (see the example below)
- This ensures that log messages containing non-ASCII text such as Chinese are handled and displayed correctly
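
A minimal before/after illustration (the dict mirrors the cache-hit log entry;
the prompt string is just a made-up example):

    import json

    log_data = {"event": "cache_hit", "original_prompt": "缓存命中示例"}

    print(json.dumps(log_data))
    # {"event": "cache_hit", "original_prompt": "\u7f13\u5b58\u547d\u4e2d\u793a\u4f8b"}

    print(json.dumps(log_data, ensure_ascii=False))
    # {"event": "cache_hit", "original_prompt": "缓存命中示例"}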
---
 lightrag/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/utils.py b/lightrag/utils.py
index 70ec4341c..4c8d7996a 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -358,7 +358,7 @@ async def get_best_cached_response(
             "cache_id": best_cache_id,
             "original_prompt": prompt_display,
         }
-        logger.info(json.dumps(log_data))
+        logger.info(json.dumps(log_data, ensure_ascii=False))
         return best_response
     return None
 

From 633fb55b5b888aaf38fee8d756622a3f1d00a370 Mon Sep 17 00:00:00 2001
From: yuanxiaobin <xiaobin.yuan@genscript.com>
Date: Fri, 6 Dec 2024 15:09:50 +0800
Subject: [PATCH 3/5] Resolve merge conflicts
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lightrag/llm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lightrag/llm.py b/lightrag/llm.py
index dda656304..d147e416d 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -4,8 +4,8 @@
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any, Union
-
+from typing import List, Dict, Callable, Any, Union, Optional
+from dataclasses import dataclass
 import aioboto3
 import aiohttp
 import numpy as np

From a1c4a036fd2187c5b463ba2c24815c39a918a835 Mon Sep 17 00:00:00 2001
From: yuanxiaobin <xiaobin.yuan@genscript.com>
Date: Fri, 6 Dec 2024 15:23:18 +0800
Subject: [PATCH 4/5] Pop the hashing_kv parameter from kwargs into a local variable
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

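Using kwargs.get("hashing_kv") left the key in kwargs, where it could leak into
the **kwargs forwarded to the underlying client call; popping it into a local
variable keeps kwargs clean. A minimal illustration of the difference (storage
stands in for the real BaseKVStorage instance):

    storage = object()  # placeholder for the real BaseKVStorage instance
    kwargs = {"temperature": 0.2, "hashing_kv": storage}

    kwargs.get("hashing_kv")        # returns storage, but the key stays in kwargs
    kwargs.pop("hashing_kv", None)  # returns storage and removes the key,
                                    # so **kwargs can be forwarded safely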
---
 lightrag/llm.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/lightrag/llm.py b/lightrag/llm.py
index d147e416d..09e9fd741 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -73,11 +73,12 @@ async def openai_complete_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
 
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     # Handle cache
     mode = kwargs.pop("mode", "default")
     args_hash = compute_args_hash(model, messages)
     cached_response, quantized, min_val, max_val = await handle_cache(
-        kwargs.get("hashing_kv"), args_hash, prompt, mode
+        hashing_kv, args_hash, prompt, mode
     )
     if cached_response is not None:
         return cached_response
@@ -219,12 +220,12 @@ async def bedrock_complete_if_cache(
 
     # Add user prompt
     messages.append({"role": "user", "content": [{"text": prompt}]})
-
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     # Handle cache
     mode = kwargs.pop("mode", "default")
     args_hash = compute_args_hash(model, messages)
     cached_response, quantized, min_val, max_val = await handle_cache(
-        kwargs.get("hashing_kv"), args_hash, prompt, mode
+        hashing_kv, args_hash, prompt, mode
     )
     if cached_response is not None:
         return cached_response
@@ -250,12 +251,12 @@ async def bedrock_complete_if_cache(
             args["inferenceConfig"][inference_params_map.get(param, param)] = (
                 kwargs.pop(param)
             )
-
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     # Handle cache
     mode = kwargs.pop("mode", "default")
     args_hash = compute_args_hash(model, messages)
     cached_response, quantized, min_val, max_val = await handle_cache(
-        kwargs.get("hashing_kv"), args_hash, prompt, mode
+        hashing_kv, args_hash, prompt, mode
     )
     if cached_response is not None:
         return cached_response

From 6a010abb625d82af21cba2374a717f86f56b5c09 Mon Sep 17 00:00:00 2001
From: yuanxiaobin <xiaobin.yuan@genscript.com>
Date: Fri, 6 Dec 2024 15:35:09 +0800
Subject: [PATCH 5/5] Pop the hashing_kv parameter from kwargs into a local variable
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lightrag/llm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lightrag/llm.py b/lightrag/llm.py
index 09e9fd741..63913c902 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -97,7 +97,7 @@ async def openai_complete_if_cache(
 
     # Save to cache
     await save_to_cache(
-        kwargs.get("hashing_kv"),
+        hashing_kv,
         CacheData(
             args_hash=args_hash,
             content=content,
@@ -271,7 +271,7 @@ async def bedrock_complete_if_cache(
 
     # Save to cache
     await save_to_cache(
-        kwargs.get("hashing_kv"),
+        hashing_kv,
         CacheData(
             args_hash=args_hash,
             content=response["output"]["message"]["content"][0]["text"],