Skip to content

Commit

Permalink
Update __version__
Browse files Browse the repository at this point in the history
  • Loading branch information
LarFii committed Dec 13, 2024
1 parent 9cac3b0 commit b7a2d33
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 39 deletions.
10 changes: 2 additions & 8 deletions examples/lightrag_zhipu_demo.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import asyncio
import os
import inspect
import logging

from dotenv import load_dotenv

from lightrag import LightRAG, QueryParam
from lightrag.llm import zhipu_complete, zhipu_embedding
Expand All @@ -21,7 +18,6 @@
raise Exception("Please set ZHIPU_API_KEY in your environment")



rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=zhipu_complete,
Expand All @@ -31,9 +27,7 @@
embedding_func=EmbeddingFunc(
embedding_dim=2048, # Zhipu embedding-3 dimension
max_token_size=8192,
func=lambda texts: zhipu_embedding(
texts
),
func=lambda texts: zhipu_embedding(texts),
),
)

Expand All @@ -58,4 +52,4 @@
# Perform hybrid search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
)
2 changes: 1 addition & 1 deletion lightrag/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

__version__ = "1.0.5"
__version__ = "1.0.6"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
4 changes: 3 additions & 1 deletion lightrag/kg/milvus_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ async def wrapped_task(batch):
return result

embedding_tasks = [wrapped_task(batch) for batch in batches]
pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
pbar = tqdm_async(
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
)
embeddings_list = await asyncio.gather(*embedding_tasks)

embeddings = np.concatenate(embeddings_list)
Expand Down
48 changes: 20 additions & 28 deletions lightrag/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,11 +604,11 @@ async def ollama_model_complete(
)
async def zhipu_complete_if_cache(
prompt: Union[str, List[Dict[str, str]]],
model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
history_messages: List[Dict[str, str]] = [],
**kwargs
**kwargs,
) -> str:
# dynamically load ZhipuAI
try:
Expand Down Expand Up @@ -640,13 +640,11 @@ async def zhipu_complete_if_cache(
logger.debug(f"System prompt: {system_prompt}")

# Remove unsupported kwargs
kwargs = {k: v for k, v in kwargs.items() if k not in ['hashing_kv', 'keyword_extraction']}
kwargs = {
k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
}

response = client.chat.completions.create(
model=model,
messages=messages,
**kwargs
)
response = client.chat.completions.create(model=model, messages=messages, **kwargs)

return response.choices[0].message.content

Expand All @@ -663,13 +661,13 @@ async def zhipu_complete(
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements
Return your response in this exact JSON format:
{
"high_level_keywords": ["keyword1", "keyword2"],
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}
Only return the JSON, no other text."""

# Combine with existing system prompt if any
Expand All @@ -683,15 +681,15 @@ async def zhipu_complete(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs
**kwargs,
)

# Try to parse as JSON
try:
data = json.loads(response)
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", [])
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
# If direct JSON parsing fails, try to extract JSON from text
Expand All @@ -701,13 +699,15 @@ async def zhipu_complete(
data = json.loads(match.group())
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", [])
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
pass

# If all parsing fails, log warning and return empty format
logger.warning(f"Failed to parse keyword extraction response: {response}")
logger.warning(
f"Failed to parse keyword extraction response: {response}"
)
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
Expand All @@ -722,7 +722,7 @@ async def zhipu_complete(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs
**kwargs,
)


Expand All @@ -733,13 +733,9 @@ async def zhipu_complete(
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_embedding(
texts: list[str],
model: str = "embedding-3",
api_key: str = None,
**kwargs
texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
) -> np.ndarray:

# dynamically load ZhipuAI
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
Expand All @@ -758,11 +754,7 @@ async def zhipu_embedding(
embeddings = []
for text in texts:
try:
response = client.embeddings.create(
model=model,
input=[text],
**kwargs
)
response = client.embeddings.create(model=model, input=[text], **kwargs)
embeddings.append(response.data[0].embedding)
except Exception as e:
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
Expand Down
4 changes: 3 additions & 1 deletion lightrag/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ async def wrapped_task(batch):
return result

embedding_tasks = [wrapped_task(batch) for batch in batches]
pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
pbar = tqdm_async(
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
)
embeddings_list = await asyncio.gather(*embedding_tasks)

embeddings = np.concatenate(embeddings_list)
Expand Down

0 comments on commit b7a2d33

Please sign in to comment.