Commit
Merge remote-tracking branch 'origin/main'
partoneplay committed Dec 9, 2024
2 parents a8e09ba + d8edc91 commit 3adafa8
Showing 8 changed files with 360 additions and 159 deletions.
6 changes: 1 addition & 5 deletions README.md
@@ -596,11 +596,7 @@ if __name__ == "__main__":
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters:<br>- `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:<br>- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.<br>- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.<br>- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
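
For reference, below is a minimal sketch of how the cache-related options above might be passed when constructing `LightRAG`. The stub model functions are placeholders for illustration only; a real setup would plug in actual LLM and embedding backends (as in the examples under `examples/`).

```python
import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc


# Placeholder backends, for illustration only.
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    return "stub response"


async def embedding_func(texts: list[str]) -> np.ndarray:
    return np.zeros((len(texts), 1024))


rag = LightRAG(
    working_dir="./dickens",
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=1024, max_token_size=8192, func=embedding_func
    ),
    enable_llm_cache=True,
    addon_params={"example_number": 1, "language": "Simplified Chinese"},
    embedding_cache_config={
        "enabled": True,                # check cached responses before generating a new answer
        "similarity_threshold": 0.95,   # reuse a cached answer above this similarity
        "use_llm_check": False,         # optional LLM double-check of the match
    },
)
```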

## API Server Implementation

9 changes: 8 additions & 1 deletion examples/graph_visual_with_html.py
@@ -11,9 +11,16 @@
# Convert NetworkX graph to Pyvis network
net.from_nx(G)

# Add colors to nodes
# Add colors and title to nodes
for node in net.nodes:
    node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
    if "description" in node:
        node["title"] = node["description"]

# Add title to edges
for edge in net.edges:
    if "description" in edge:
        edge["title"] = edge["description"]

# Save and display the network
net.show("knowledge_graph.html")
114 changes: 114 additions & 0 deletions examples/lightrag_jinaai_demo.py
@@ -0,0 +1,114 @@
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from lightrag.llm import jina_embedding, openai_complete_if_cache
import os
import asyncio


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await jina_embedding(texts, api_key="YourJinaAPIKey")


WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "solar-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=os.getenv("UPSTAGE_API_KEY"),
        base_url="https://api.upstage.ai/v1/solar",
        **kwargs,
    )


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=1024, max_token_size=8192, func=embedding_func
    ),
)


async def lightraginsert(file_path, semaphore):
    async with semaphore:
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
        except UnicodeDecodeError:
            # If UTF-8 decoding fails, try other encodings
            with open(file_path, "r", encoding="gbk") as f:
                content = f.read()
        await rag.ainsert(content)


async def process_files(directory, concurrency_limit):
    semaphore = asyncio.Semaphore(concurrency_limit)
    tasks = []
    for root, dirs, files in os.walk(directory):
        for f in files:
            file_path = os.path.join(root, f)
            if f.startswith("."):
                continue
            tasks.append(lightraginsert(file_path, semaphore))
    await asyncio.gather(*tasks)


async def main():
    try:
        rag = LightRAG(
            working_dir=WORKING_DIR,
            llm_model_func=llm_model_func,
            embedding_func=EmbeddingFunc(
                embedding_dim=1024,
                max_token_size=8192,
                func=embedding_func,
            ),
        )

        await process_files(WORKING_DIR, concurrency_limit=4)

        # Perform naive search
        print(
            await rag.aquery(
                "What are the top themes in this story?", param=QueryParam(mode="naive")
            )
        )

        # Perform local search
        print(
            await rag.aquery(
                "What are the top themes in this story?", param=QueryParam(mode="local")
            )
        )

        # Perform global search
        print(
            await rag.aquery(
                "What are the top themes in this story?",
                param=QueryParam(mode="global"),
            )
        )

        # Perform hybrid search
        print(
            await rag.aquery(
                "What are the top themes in this story?",
                param=QueryParam(mode="hybrid"),
            )
        )
    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    asyncio.run(main())
9 changes: 7 additions & 2 deletions lightrag/lightrag.py
@@ -87,7 +87,11 @@ class LightRAG:
    )
    # Default not to use embedding cache
    embedding_cache_config: dict = field(
        default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95}
        default_factory=lambda: {
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        }
    )
    kv_storage: str = field(default="JsonKVStorage")
    vector_storage: str = field(default="NanoVectorDBStorage")
@@ -174,7 +178,6 @@ def __post_init__(self):
            if self.enable_llm_cache
            else None
        )

        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
@@ -481,6 +484,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()):
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache,
            )
        elif param.mode == "naive":
            response = await naive_query(
@@ -489,6 +493,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()):
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache,
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
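
The query paths now receive `hashing_kv=self.llm_response_cache`, which is what the `embedding_cache_config` lookup described in the README operates on. As rough orientation, here is a hypothetical, self-contained sketch of that kind of similarity-based cache lookup; the helper name, cache layout, and cosine-similarity choice are illustrative assumptions, not LightRAG's actual internals.

```python
import numpy as np


async def lookup_cached_answer(
    question,
    cache_entries,          # assumed layout: [{"question": str, "embedding": list[float], "answer": str}, ...]
    embed,                  # async callable: list[str] -> array of embeddings
    similarity_threshold=0.95,
    use_llm_check=False,
    llm_check=None,         # optional async callable: (new_question, cached_question) -> bool
):
    """Hypothetical lookup: reuse an earlier answer if the questions are similar enough."""
    q_vec = np.asarray((await embed([question]))[0], dtype=float)
    best, best_sim = None, -1.0
    for entry in cache_entries:
        c_vec = np.asarray(entry["embedding"], dtype=float)
        sim = float(np.dot(q_vec, c_vec) / (np.linalg.norm(q_vec) * np.linalg.norm(c_vec)))
        if sim > best_sim:
            best, best_sim = entry, sim
    if best is None or best_sim < similarity_threshold:
        return None  # cache miss: caller falls through to a normal LLM call
    if use_llm_check and llm_check is not None:
        # Secondary check: ask an LLM whether the two questions really match.
        if not await llm_check(question, best["question"]):
            return None
    return best["answer"]
```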