Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rerank model #969

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions api/apps/chunk_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,15 @@ def retrieval_test():

embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
vector_similarity_weight, top, doc_ids)

rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
Expand Down
5 changes: 5 additions & 0 deletions api/apps/dialog_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def set_dialog():
name = req.get("name", "New Dialog")
description = req.get("description", "A helpful Dialog")
top_n = req.get("top_n", 6)
top_k = req.get("top_k", 1024)
rerank_id = req.get("rerank_id", "")
if not rerank_id: req["rerank_id"] = ""
similarity_threshold = req.get("similarity_threshold", 0.1)
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
llm_setting = req.get("llm_setting", {})
Expand Down Expand Up @@ -83,6 +86,8 @@ def set_dialog():
"llm_setting": llm_setting,
"prompt_config": prompt_config,
"top_n": top_n,
"top_k": top_k,
"rerank_id": rerank_id,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight
}
Expand Down
16 changes: 13 additions & 3 deletions api/apps/llm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result
from rag.llm import EmbeddingModel, ChatModel
from rag.llm import EmbeddingModel, ChatModel, RerankModel


@manager.route('/factories', methods=['GET'])
@login_required
def factories():
try:
fac = LLMFactoriesService.get_all()
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
except Exception as e:
return server_error_response(e)

Expand Down Expand Up @@ -64,6 +64,16 @@ def set_api_key():
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
elif llm.model_type == LLMType.RERANK:
mdl = RerankModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try:
m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(arr[0]) == 0 or tc == 0:
raise Exception("Fail")
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)

if msg:
return get_data_error_result(retmsg=msg)
Expand Down Expand Up @@ -199,7 +209,7 @@ def list_app():
llms = [m.to_dict()
for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"]

llm_set = set([m["llm_name"] for m in llms])
for o in objs:
Expand Down
8 changes: 5 additions & 3 deletions api/apps/user_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
from api.utils.api_utils import server_error_response, validate_request
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
from api.db import UserTenantRole, LLMType, FileType
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
LLM_FACTORY, LLM_BASE_URL
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \
API_KEY, \
LLM_FACTORY, LLM_BASE_URL, RERANK_MDL
from api.db.services.user_service import UserService, TenantService, UserTenantService
from api.db.services.file_service import FileService
from api.settings import stat_logger
Expand Down Expand Up @@ -288,7 +289,8 @@ def user_register(user_id, user):
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
"img2txt_id": IMAGE2TEXT_MDL,
"rerank_id": RERANK_MDL
}
usr_tenant = {
"tenant_id": user_id,
Expand Down
1 change: 1 addition & 0 deletions api/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class LLMType(StrEnum):
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
RERANK = 'rerank'


class ChatStyle(StrEnum):
Expand Down
39 changes: 33 additions & 6 deletions api/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ class Tenant(DataBaseModel):
max_length=128,
null=False,
help_text="default image to text model ID")
rerank_id = CharField(
max_length=128,
null=False,
help_text="default rerank model ID")
parser_ids = CharField(
max_length=256,
null=False,
Expand Down Expand Up @@ -771,11 +775,16 @@ class Dialog(DataBaseModel):
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
top_n = IntegerField(default=6)
top_k = IntegerField(default=1024)
do_refer = CharField(
max_length=1,
null=False,
help_text="it needs to insert reference index into answer or not",
default="1")
rerank_id = CharField(
max_length=128,
null=False,
help_text="default rerank model ID")

kb_ids = JSONField(null=False, default=[])
status = CharField(
Expand Down Expand Up @@ -825,11 +834,29 @@ class Meta:


def migrate_db():
try:
with DB.transaction():
migrator = MySQLMigrator(DB)
migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
)
except Exception as e:
pass
try:
migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from"))
)
except Exception as e:
pass
try:
migrate(
migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
)
except Exception as e:
pass
try:
migrate(
migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
)
except Exception as e:
pass
try:
migrate(
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
)
except Exception as e:
pass
98 changes: 97 additions & 1 deletion api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,17 @@ def init_superuser():
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
},
},{
"name": "Jina",
"logo": "",
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
"status": "1",
},{
"name": "BAAI",
"logo": "",
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
"status": "1",
}
# {
# "name": "文心一言",
# "logo": "",
Expand Down Expand Up @@ -367,6 +377,13 @@ def init_llm_factory():
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[7]["name"],
"llm_name": "maidalun1020/bce-reranker-base_v1",
"tags": "RE-RANK, 8K",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
# ------------------------ DeepSeek -----------------------
{
"fid": factory_infos[8]["name"],
Expand Down Expand Up @@ -440,6 +457,85 @@ def init_llm_factory():
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
},
# ------------------------ Jina -----------------------
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-reranker-v1-base-en",
"tags": "RE-RANK,8k",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-reranker-v1-turbo-en",
"tags": "RE-RANK,8k",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-reranker-v1-tiny-en",
"tags": "RE-RANK,8k",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-colbert-v1-en",
"tags": "RE-RANK,8k",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-en",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-de",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-es",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-code",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-zh",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
# ------------------------ BAAI -----------------------
{
"fid": factory_infos[12]["name"],
"llm_name": "BAAI/bge-large-zh-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 1024,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[12]["name"],
"llm_name": "BAAI/bge-reranker-v2-m3",
"tags": "LLM,CHAT,",
"max_tokens": 16385,
"model_type": LLMType.RERANK.value
},
]
for info in factory_infos:
try:
Expand Down
7 changes: 5 additions & 2 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,14 @@ def chat(dialog, messages, stream=True, **kwargs):
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
else:
rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False)
top=1024, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
chat_logger.info(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
Expand All @@ -130,7 +133,7 @@ def chat(dialog, messages, stream=True, **kwargs):

kwargs["knowledge"] = "\n".join(knowledges)
gen_conf = dialog.llm_setting

msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
msg.extend([{"role": m["role"], "content": m["content"]}
for m in messages if m["role"] != "system"])
Expand Down
Loading