Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(llm): support multi reranker & enhance the UI #73

Merged
merged 22 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,15 @@ class LLMConfigRequest(BaseModel):
# ollama-only properties
host: str = None
port: str = None

class EmbeddingConfigRequest(BaseModel):
    """Request payload for updating the embedding-service configuration.

    NOTE(review): despite the class name, the backend discriminator field
    is named ``llm_type`` (mirroring ``LLMConfigRequest``), not
    ``embedding_type`` — confirm this is intentional.
    """

    # Embedding backend selector, e.g. "openai" / "qianfan_wenxin" / "ollama".
    llm_type: str
    # The common parameters shared by OpenAI, Qianfan Wenxin,
    # and OLLAMA platforms.
    api_key: str

class RerankerConfigRequest(BaseModel):
    """Request payload for updating the reranker configuration."""

    # Reranker model identifier, e.g. "rerank-multilingual-v3.0".
    reranker_model: str
    # Reranker backend, e.g. "cohere" or "siliconflow".
    reranker_type: str
    api_key: str
    # NOTE(review): api_base is a required field, but the /config/rerank
    # endpoint only forwards it for the "cohere" backend — confirm whether
    # it should be optional for other reranker types.
    api_base: str
68 changes: 57 additions & 11 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,43 @@
from fastapi import status, APIRouter

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
    RAGRequest,
    GraphConfigRequest,
    LLMConfigRequest,
    RerankerConfigRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import settings


def rag_http_api(
    router: APIRouter,
    rag_answer_func,
    apply_graph_conf,
    apply_llm_conf,
    apply_embedding_conf,
    apply_reranker_conf,
):
    """Register the RAG HTTP endpoints on *router*.

    :param router: FastAPI router the endpoints are attached to.
    :param rag_answer_func: callable producing the RAG answers for a query.
    :param apply_graph_conf: callable applying graph-connection settings;
        returns an HTTP-like status code.
    :param apply_llm_conf: callable applying LLM settings (same contract).
    :param apply_embedding_conf: callable applying embedding settings.
    :param apply_reranker_conf: callable applying reranker settings.
    """

    @router.post("/rag", status_code=status.HTTP_200_OK)
    def rag_answer_api(req: RAGRequest):
        result = rag_answer_func(
            req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector
        )
        # Only include the answer variants the caller actually requested.
        return {
            key: value
            for key, value in zip(
                ["raw_llm", "vector_only", "graph_only", "graph_vector"], result
            )
            if getattr(req, key)
        }

    @router.post("/config/graph", status_code=status.HTTP_201_CREATED)
    def graph_config_api(req: GraphConfigRequest):
        # Accept status code
        res = apply_graph_conf(
            req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http"
        )
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
    def llm_config_api(req: LLMConfigRequest):
        # NOTE(review): a few lines of this handler were collapsed in the
        # source diff view; this assignment mirrors the visible pattern in
        # embedding_config_api / rerank_config_api — confirm against the repo.
        settings.llm_type = req.llm_type

        if req.llm_type == "openai":
            res = apply_llm_conf(
                req.api_key,
                req.api_base,
                req.language_model,
                req.max_tokens,
                origin_call="http",
            )
        elif req.llm_type == "qianfan_wenxin":
            res = apply_llm_conf(
                req.api_key,
                req.secret_key,
                req.language_model,
                None,
                origin_call="http",
            )
        else:
            # ollama: addressed by host/port instead of API credentials.
            res = apply_llm_conf(
                req.host, req.port, req.language_model, None, origin_call="http"
            )
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
    def embedding_config_api(req: LLMConfigRequest):
        settings.embedding_type = req.llm_type

        if req.llm_type == "openai":
            res = apply_embedding_conf(
                req.api_key, req.api_base, req.language_model, origin_call="http"
            )
        elif req.llm_type == "qianfan_wenxin":
            res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
        else:
            # ollama: addressed by host/port instead of API credentials.
            res = apply_embedding_conf(
                req.host, req.port, req.language_model, origin_call="http"
            )
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
    def rerank_config_api(req: RerankerConfigRequest):
        settings.reranker_type = req.reranker_type

        if req.reranker_type == "cohere":
            res = apply_reranker_conf(
                req.api_key, req.reranker_model, req.api_base, origin_call="http"
            )
        else:
            # BUG FIX: previously `res` was only assigned in the "cohere"
            # branch, so any other reranker_type (e.g. "siliconflow", which
            # the config's Literal allows) raised UnboundLocalError at the
            # return below. Report "not implemented" instead of crashing.
            res = status.HTTP_501_NOT_IMPLEMENTED
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))
13 changes: 9 additions & 4 deletions hugegraph-llm/src/hugegraph_llm/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,27 @@ class Config:
# env_path: Optional[str] = ".env"
llm_type: Literal["openai", "ollama", "qianfan_wenxin", "zhipu"] = "openai"
embedding_type: Optional[Literal["openai", "ollama", "qianfan_wenxin", "zhipu"]] = "openai"
reranker_type: Optional[Literal["cohere", "siliconflow"]] = "cohere"
# 1. OpenAI settings
openai_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_language_model: Optional[str] = "gpt-4o-mini"
openai_embedding_model: Optional[str] = "text-embedding-3-small"
openai_max_tokens: int = 4096
# 2. Ollama settings
# 2. Rerank settings
cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank")
reranker_api_key: Optional[str] = None
reranker_model: Optional[str] = "rerank-multilingual-v3.0"
# 3. Ollama settings
ollama_host: Optional[str] = "127.0.0.1"
ollama_port: Optional[int] = 11434
ollama_language_model: Optional[str] = None
ollama_embedding_model: Optional[str] = None
# 3. QianFan/WenXin settings
# 4. QianFan/WenXin settings
qianfan_api_key: Optional[str] = None
qianfan_secret_key: Optional[str] = None
qianfan_access_token: Optional[str] = None
# 3.1 url settings
# 4.1 url settings
qianfan_url_prefix: Optional[str] = (
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
)
Expand All @@ -59,7 +64,7 @@ class Config:
qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/"
# https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu
qianfan_embedding_model: Optional[str] = "embedding-v1"
# 4. ZhiPu(GLM) settings
# 5. ZhiPu(GLM) settings
zhipu_api_key: Optional[str] = None
zhipu_language_model: Optional[str] = "glm-4"
zhipu_embedding_model: Optional[str] = "embedding-2"
Expand Down
Loading
Loading