Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(llm): support multi reranker & enhance the UI #73

Merged
merged 22 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,9 @@ class LLMConfigRequest(BaseModel):
# ollama-only properties
host: str = None
port: str = None

class RerankerConfigRequest(BaseModel):
reranker_model: str
reranker_type: str
api_key: str
cohere_base_url: Optional[str] = None
27 changes: 22 additions & 5 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@
from fastapi import status, APIRouter

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import RAGRequest, GraphConfigRequest, LLMConfigRequest
from hugegraph_llm.api.models.rag_requests import (
RAGRequest,
GraphConfigRequest,
LLMConfigRequest,
RerankerConfigRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import settings


def rag_http_api(router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf):
def rag_http_api(
router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector)
Expand All @@ -44,9 +51,7 @@ def llm_config_api(req: LLMConfigRequest):
settings.llm_type = req.llm_type

if req.llm_type == "openai":
res = apply_llm_conf(
req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http"
)
res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http")
elif req.llm_type == "qianfan_wenxin":
res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
else:
Expand All @@ -64,3 +69,15 @@ def embedding_config_api(req: LLMConfigRequest):
else:
res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))

@router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
def rerank_config_api(req: RerankerConfigRequest):
settings.reranker_type = req.reranker_type

if req.reranker_type == "cohere":
res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http")
elif req.reranker_type == "siliconflow":
res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http")
else:
res = status.HTTP_501_NOT_IMPLEMENTED
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
17 changes: 10 additions & 7 deletions hugegraph-llm/src/hugegraph_llm/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,34 @@ class Config:
# env_path: Optional[str] = ".env"
llm_type: Literal["openai", "ollama", "qianfan_wenxin", "zhipu"] = "openai"
embedding_type: Optional[Literal["openai", "ollama", "qianfan_wenxin", "zhipu"]] = "openai"
reranker_type: Optional[Literal["cohere", "siliconflow"]] = "cohere"
# 1. OpenAI settings
openai_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_language_model: Optional[str] = "gpt-4o-mini"
openai_embedding_model: Optional[str] = "text-embedding-3-small"
openai_max_tokens: int = 4096
# 2. Ollama settings
# 2. Rerank settings
cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank")
reranker_api_key: Optional[str] = None
reranker_model: Optional[str] = "rerank-multilingual-v3.0"
# 3. Ollama settings
ollama_host: Optional[str] = "127.0.0.1"
ollama_port: Optional[int] = 11434
ollama_language_model: Optional[str] = None
ollama_embedding_model: Optional[str] = None
# 3. QianFan/WenXin settings
# 4. QianFan/WenXin settings
qianfan_api_key: Optional[str] = None
qianfan_secret_key: Optional[str] = None
qianfan_access_token: Optional[str] = None
# 3.1 url settings
qianfan_url_prefix: Optional[str] = (
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
)
# 4.1 URL settings
qianfan_url_prefix: Optional[str] = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
qianfan_chat_url: Optional[str] = qianfan_url_prefix + "/chat/"
qianfan_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K"
qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/"
# https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu
qianfan_embedding_model: Optional[str] = "embedding-v1"
# 4. ZhiPu(GLM) settings
# 5. ZhiPu(GLM) settings
zhipu_api_key: Optional[str] = None
zhipu_language_model: Optional[str] = "glm-4"
zhipu_embedding_model: Optional[str] = "embedding-2"
Expand Down
Loading
Loading