From 614defec213447b1c5596805b9cd6094935c2c81 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Wed, 29 May 2024 16:50:02 +0800 Subject: [PATCH] add rerank model (#969) ### What problem does this PR solve? feat: add rerank models to the project #724 #162 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/chunk_app.py | 11 ++- api/apps/dialog_app.py | 5 ++ api/apps/llm_app.py | 16 ++++- api/apps/user_app.py | 8 ++- api/db/__init__.py | 1 + api/db/db_models.py | 39 +++++++++-- api/db/init_data.py | 98 +++++++++++++++++++++++++- api/db/services/dialog_service.py | 7 +- api/db/services/llm_service.py | 38 ++++++++-- api/db/services/user_service.py | 1 + api/settings.py | 16 ++++- rag/llm/__init__.py | 13 +++- rag/llm/embedding_model.py | 80 ++++++++++++++------- rag/llm/rerank_model.py | 113 ++++++++++++++++++++++++++++++ rag/nlp/query.py | 11 +-- rag/nlp/rag_tokenizer.py | 9 ++- rag/nlp/search.py | 35 +++++++-- 17 files changed, 437 insertions(+), 64 deletions(-) create mode 100644 rag/llm/rerank_model.py diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 539d5c0eeb7..0ece95933fe 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -257,8 +257,15 @@ def retrieval_test(): embd_mdl = TenantLLMService.model_instance( kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) - ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold, - vector_similarity_weight, top, doc_ids) + + rerank_mdl = None + if req.get("rerank_id"): + rerank_mdl = TenantLLMService.model_instance( + kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) + + ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, + similarity_threshold, vector_similarity_weight, top, + doc_ids, rerank_mdl=rerank_mdl) for c in ranks["chunks"]: if "vector" in c: del c["vector"] diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 6423a4614d8..2969c12465d 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -33,6 +33,9 @@ def set_dialog(): name = req.get("name", "New Dialog") description = req.get("description", "A helpful Dialog") top_n = req.get("top_n", 6) + top_k = req.get("top_k", 1024) + rerank_id = req.get("rerank_id", "") + if not rerank_id: req["rerank_id"] = "" similarity_threshold = req.get("similarity_threshold", 0.1) vector_similarity_weight = req.get("vector_similarity_weight", 0.3) llm_setting = req.get("llm_setting", {}) @@ -83,6 +86,8 @@ def set_dialog(): "llm_setting": llm_setting, "prompt_config": prompt_config, "top_n": top_n, + "top_k": top_k, + "rerank_id": rerank_id, "similarity_threshold": similarity_threshold, "vector_similarity_weight": vector_similarity_weight } diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index b956017060d..36fa5c3ccb9 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -20,7 +20,7 @@ from api.db import StatusEnum, LLMType from api.db.db_models import TenantLLM from api.utils.api_utils import get_json_result -from rag.llm import EmbeddingModel, ChatModel +from rag.llm import EmbeddingModel, ChatModel, RerankModel @manager.route('/factories', methods=['GET']) @@ -28,7 +28,7 @@ def factories(): try: fac = LLMFactoriesService.get_all() - return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]]) + return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]]) except Exception as e: return server_error_response(e) @@ -64,6 +64,16 @@ def set_api_key(): except Exception as e: msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( e) + elif llm.model_type == LLMType.RERANK: + mdl = RerankModel[factory]( + req["api_key"], llm.llm_name, base_url=req.get("base_url")) + try: + m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) + if len(arr[0]) == 0 or tc == 0: + raise Exception("Fail") + except Exception as e: + msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( + e) if msg: return get_data_error_result(retmsg=msg) @@ -199,7 +209,7 @@ def list_app(): llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value] for m in llms: - m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"] + m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"] llm_set = set([m["llm_name"] for m in llms]) for o in objs: diff --git a/api/apps/user_app.py b/api/apps/user_app.py index e5534a10d4e..48e02612d0a 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -26,8 +26,9 @@ from api.utils.api_utils import server_error_response, validate_request from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format from api.db import UserTenantRole, LLMType, FileType -from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \ - LLM_FACTORY, LLM_BASE_URL +from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \ + API_KEY, \ + LLM_FACTORY, LLM_BASE_URL, RERANK_MDL from api.db.services.user_service import UserService, TenantService, UserTenantService from api.db.services.file_service import FileService from api.settings import stat_logger @@ -288,7 +289,8 @@ def user_register(user_id, user): "embd_id": EMBEDDING_MDL, "asr_id": ASR_MDL, "parser_ids": PARSERS, - "img2txt_id": IMAGE2TEXT_MDL + "img2txt_id": IMAGE2TEXT_MDL, + "rerank_id": RERANK_MDL } usr_tenant = { "tenant_id": user_id, diff --git a/api/db/__init__.py b/api/db/__init__.py index 06127547406..f4c96f3a805 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -54,6 +54,7 @@ class LLMType(StrEnum): EMBEDDING = 'embedding' SPEECH2TEXT = 'speech2text' IMAGE2TEXT = 'image2text' + RERANK = 'rerank' class ChatStyle(StrEnum): diff --git a/api/db/db_models.py b/api/db/db_models.py index 7287a7df79f..93e8115ce43 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -437,6 +437,10 @@ class Tenant(DataBaseModel): max_length=128, null=False, help_text="default image to text model ID") + rerank_id = CharField( + max_length=128, + null=False, + help_text="default rerank model ID") parser_ids = CharField( max_length=256, null=False, @@ -771,11 +775,16 @@ class Dialog(DataBaseModel): similarity_threshold = FloatField(default=0.2) vector_similarity_weight = FloatField(default=0.3) top_n = IntegerField(default=6) + top_k = IntegerField(default=1024) do_refer = CharField( max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1") + rerank_id = CharField( + max_length=128, + null=False, + help_text="default rerank model ID") kb_ids = JSONField(null=False, default=[]) status = CharField( @@ -825,11 +834,29 @@ class Meta: def migrate_db(): - try: with DB.transaction(): migrator = MySQLMigrator(DB) - migrate( - migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from")) - ) - except Exception as e: - pass + try: + migrate( + migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from")) + ) + except Exception as e: + pass + try: + migrate( + migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID")) + ) + except Exception as e: + pass + try: + migrate( + migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID")) + ) + except Exception as e: + pass + try: + migrate( + migrator.add_column('dialog', 'top_k', IntegerField(default=1024)) + ) + except Exception as e: + pass diff --git a/api/db/init_data.py b/api/db/init_data.py index 1a4706f255d..01b019ebbbe 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -142,7 +142,17 @@ def init_superuser(): "logo": "", "tags": "LLM,TEXT EMBEDDING", "status": "1", -}, +},{ + "name": "Jina", + "logo": "", + "tags": "TEXT EMBEDDING, TEXT RE-RANK", + "status": "1", +},{ + "name": "BAAI", + "logo": "", + "tags": "TEXT EMBEDDING, TEXT RE-RANK", + "status": "1", +} # { # "name": "文心一言", # "logo": "", @@ -367,6 +377,13 @@ def init_llm_factory(): "max_tokens": 512, "model_type": LLMType.EMBEDDING.value }, + { + "fid": factory_infos[7]["name"], + "llm_name": "maidalun1020/bce-reranker-base_v1", + "tags": "RE-RANK, 8K", + "max_tokens": 8196, + "model_type": LLMType.RERANK.value + }, # ------------------------ DeepSeek ----------------------- { "fid": factory_infos[8]["name"], @@ -440,6 +457,85 @@ def init_llm_factory(): "max_tokens": 512, "model_type": LLMType.EMBEDDING.value }, + # ------------------------ Jina ----------------------- + { + "fid": factory_infos[11]["name"], + "llm_name": "jina-reranker-v1-base-en", + "tags": "RE-RANK,8k", + "max_tokens": 8196, + "model_type": LLMType.RERANK.value + }, + { + "fid": factory_infos[11]["name"], + "llm_name": "jina-reranker-v1-turbo-en", + "tags": "RE-RANK,8k", + "max_tokens": 8196, + "model_type": LLMType.RERANK.value + }, + { + "fid": factory_infos[11]["name"], + "llm_name": "jina-reranker-v1-tiny-en", + "tags": "RE-RANK,8k", + "max_tokens": 8196, + "model_type": LLMType.RERANK.value + }, + { + "fid": factory_infos[11]["name"], + "llm_name": "jina-colbert-v1-en", + "tags": "RE-RANK,8k", + "max_tokens": 8196, + "model_type": LLMType.RERANK.value + }, + { + "fid": factory_infos[11]["name"], + "llm_name": "jina-embeddings-v2-base-en", + "tags": "TEXT EMBEDDING", + "max_tokens": 8196, + "model_type": LLMType.EMBEDDING.value + }, + { + "fid": factory_infos[11]["name"], + "llm_name": "jina-embeddings-v2-base-de", + "tags": "TEXT EMBEDDING", + "max_tokens": 8196, + "model_type": LLMType.EMBEDDING.value + }, + { + "fid": factory_infos[11]["name"], + "llm_name": "jina-embeddings-v2-base-es", + "tags": "TEXT EMBEDDING", + "max_tokens": 8196, + "model_type": LLMType.EMBEDDING.value + }, + { + "fid": factory_infos[11]["name"], + "llm_name": "jina-embeddings-v2-base-code", + "tags": "TEXT EMBEDDING", + "max_tokens": 8196, + "model_type": LLMType.EMBEDDING.value + }, + { + "fid": factory_infos[11]["name"], + "llm_name": "jina-embeddings-v2-base-zh", + "tags": "TEXT EMBEDDING", + "max_tokens": 8196, + "model_type": LLMType.EMBEDDING.value + }, + # ------------------------ BAAI ----------------------- + { + "fid": factory_infos[12]["name"], + "llm_name": "BAAI/bge-large-zh-v1.5", + "tags": "TEXT EMBEDDING,", + "max_tokens": 1024, + "model_type": LLMType.EMBEDDING.value + }, + { + "fid": factory_infos[12]["name"], + "llm_name": "BAAI/bge-reranker-v2-m3", + "tags": "LLM,CHAT,", + "max_tokens": 16385, + "model_type": LLMType.RERANK.value + }, ] for info in factory_infos: try: diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index f5beb9480a5..a4f9601589c 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -115,11 +115,14 @@ def chat(dialog, messages, stream=True, **kwargs): if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} else: + rerank_mdl = None + if dialog.rerank_id: + rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold, dialog.vector_similarity_weight, doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None, - top=1024, aggs=False) + top=1024, aggs=False, rerank_mdl=rerank_mdl) knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] chat_logger.info( "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) @@ -130,7 +133,7 @@ def chat(dialog, messages, stream=True, **kwargs): kwargs["knowledge"] = "\n".join(knowledges) gen_conf = dialog.llm_setting - + msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}] msg.extend([{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 0dd3b8374d8..5a71ea69aba 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -15,7 +15,7 @@ # from api.db.services.user_service import TenantService from api.settings import database_logger -from rag.llm import EmbeddingModel, CvModel, ChatModel +from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel from api.db import LLMType from api.db.db_models import DB, UserTenant from api.db.db_models import LLMFactories, LLM, TenantLLM @@ -73,21 +73,25 @@ def model_instance(cls, tenant_id, llm_type, mdlnm = tenant.img2txt_id elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id if not llm_name else llm_name + elif llm_type == LLMType.RERANK: + mdlnm = tenant.rerank_id if not llm_name else llm_name else: assert False, "LLM type error" model_config = cls.get_api_key(tenant_id, mdlnm) if model_config: model_config = model_config.to_dict() if not model_config: - if llm_type == LLMType.EMBEDDING.value: + if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]: llm = LLMService.query(llm_name=llm_name) - if llm and llm[0].fid in ["Youdao", "FastEmbed", "DeepSeek"]: + if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]: model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""} if not model_config: if llm_name == "flag-embedding": model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""} else: + if not mdlnm: + raise LookupError(f"Type of {llm_type} model is not set.") raise LookupError("Model({}) not authorized".format(mdlnm)) if llm_type == LLMType.EMBEDDING.value: @@ -96,6 +100,12 @@ def model_instance(cls, tenant_id, llm_type, return EmbeddingModel[model_config["llm_factory"]]( model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) + if llm_type == LLMType.RERANK: + if model_config["llm_factory"] not in RerankModel: + return + return RerankModel[model_config["llm_factory"]]( + model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) + if llm_type == LLMType.IMAGE2TEXT.value: if model_config["llm_factory"] not in CvModel: return @@ -125,14 +135,20 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): mdlnm = tenant.img2txt_id elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id if not llm_name else llm_name + elif llm_type == LLMType.RERANK: + mdlnm = tenant.llm_id if not llm_name else llm_name else: assert False, "LLM type error" num = 0 - for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm): - num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\ - .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\ - .execute() + try: + for u in cls.query(tenant_id = tenant_id, llm_name=mdlnm): + num += cls.model.update(used_tokens = u.used_tokens + used_tokens)\ + .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\ + .execute() + except Exception as e: + print(e) + pass return num @classmethod @@ -176,6 +192,14 @@ def encode_queries(self, query: str): "Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) return emd, used_tokens + def similarity(self, query: str, texts: list): + sim, used_tokens = self.mdl.similarity(query, texts) + if not TenantLLMService.increase_usage( + self.tenant_id, self.llm_type, used_tokens): + database_logger.error( + "Can't update token usage for {}/RERANK".format(self.tenant_id)) + return sim, used_tokens + def describe(self, image, max_tokens=300): txt, used_tokens = self.mdl.describe(image, max_tokens) if not TenantLLMService.increase_usage( diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index 4194ff58c41..07468b814bb 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -93,6 +93,7 @@ def get_by_user_id(cls, user_id): cls.model.name, cls.model.llm_id, cls.model.embd_id, + cls.model.rerank_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, diff --git a/api/settings.py b/api/settings.py index 81139480187..769ab6cc9f5 100644 --- a/api/settings.py +++ b/api/settings.py @@ -89,9 +89,22 @@ }, "DeepSeek": { "chat_model": "deepseek-chat", + "embedding_model": "", + "image2text_model": "", + "asr_model": "", + }, + "VolcEngine": { + "chat_model": "", + "embedding_model": "", + "image2text_model": "", + "asr_model": "", + }, + "BAAI": { + "chat_model": "", "embedding_model": "BAAI/bge-large-zh-v1.5", "image2text_model": "", "asr_model": "", + "rerank_model": "BAAI/bge-reranker-v2-m3", } } LLM = get_base_config("user_default_llm", {}) @@ -104,7 +117,8 @@ f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.") LLM_FACTORY = "Tongyi-Qianwen" CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] -EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"] +EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] +RERANK_MDL = default_llm["BAAI"]["rerank_model"] ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 9fc114f0621..25b08921aac 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -16,18 +16,19 @@ from .embedding_model import * from .chat_model import * from .cv_model import * +from .rerank_model import * EmbeddingModel = { "Ollama": OllamaEmbed, "OpenAI": OpenAIEmbed, "Xinference": XinferenceEmbed, - "Tongyi-Qianwen": DefaultEmbedding, #QWenEmbed, + "Tongyi-Qianwen": DefaultEmbedding,#QWenEmbed, "ZHIPU-AI": ZhipuEmbed, "FastEmbed": FastEmbed, "Youdao": YoudaoEmbed, - "DeepSeek": DefaultEmbedding, - "BaiChuan": BaiChuanEmbed + "BaiChuan": BaiChuanEmbed, + "BAAI": DefaultEmbedding } @@ -52,3 +53,9 @@ "BaiChuan": BaiChuanChat } + +RerankModel = { + "BAAI": DefaultRerank, + "Jina": JinaRerank, + "Youdao": YoudaoRerank, +} diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 43485d99ca2..f3d0a872989 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import re from typing import Optional +import requests from huggingface_hub import snapshot_download from zhipuai import ZhipuAI import os @@ -26,21 +28,9 @@ import torch import numpy as np -from api.utils.file_utils import get_project_base_directory, get_home_cache_dir +from api.utils.file_utils import get_home_cache_dir from rag.utils import num_tokens_from_string, truncate -try: - flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"), - query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", - use_fp16=torch.cuda.is_available()) -except Exception as e: - model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", - local_dir=os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"), - local_dir_use_symlinks=False) - flag_model = FlagModel(model_dir, - query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", - use_fp16=torch.cuda.is_available()) - class Base(ABC): def __init__(self, key, model_name): @@ -54,7 +44,9 @@ def encode_queries(self, text: str): class DefaultEmbedding(Base): - def __init__(self, *args, **kwargs): + _model = None + + def __init__(self, key, model_name, **kwargs): """ If you have trouble downloading HuggingFace models, -_^ this might help!! @@ -66,7 +58,18 @@ def __init__(self, *args, **kwargs): ^_- """ - self.model = flag_model + if not DefaultEmbedding._model: + try: + self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", + use_fp16=torch.cuda.is_available()) + except Exception as e: + model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", + local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + local_dir_use_symlinks=False) + self._model = FlagModel(model_dir, + query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", + use_fp16=torch.cuda.is_available()) def encode(self, texts: list, batch_size=32): texts = [truncate(t, 2048) for t in texts] @@ -75,12 +78,12 @@ def encode(self, texts: list, batch_size=32): token_count += num_tokens_from_string(t) res = [] for i in range(0, len(texts), batch_size): - res.extend(self.model.encode(texts[i:i + batch_size]).tolist()) + res.extend(self._model.encode(texts[i:i + batch_size]).tolist()) return np.array(res), token_count def encode_queries(self, text: str): token_count = num_tokens_from_string(text) - return self.model.encode_queries([text]).tolist()[0], token_count + return self._model.encode_queries([text]).tolist()[0], token_count class OpenAIEmbed(Base): @@ -189,16 +192,19 @@ def encode_queries(self, text): class FastEmbed(Base): + _model = None + def __init__( - self, - key: Optional[str] = None, - model_name: str = "BAAI/bge-small-en-v1.5", - cache_dir: Optional[str] = None, - threads: Optional[int] = None, - **kwargs, + self, + key: Optional[str] = None, + model_name: str = "BAAI/bge-small-en-v1.5", + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + **kwargs, ): from fastembed import TextEmbedding - self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) + if not FastEmbed._model: + self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) def encode(self, texts: list, batch_size=32): # Using the internal tokenizer to encode the texts and get the total @@ -265,3 +271,29 @@ def encode(self, texts: list, batch_size=10): def encode_queries(self, text): embds = YoudaoEmbed._client.encode([text]) return np.array(embds[0]), num_tokens_from_string(text) + + +class JinaEmbed(Base): + def __init__(self, key, model_name="jina-embeddings-v2-base-zh", + base_url="https://api.jina.ai/v1/embeddings"): + + self.base_url = "https://api.jina.ai/v1/embeddings" + self.headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {key}" + } + self.model_name = model_name + + def encode(self, texts: list, batch_size=None): + texts = [truncate(t, 8196) for t in texts] + data = { + "model": self.model_name, + "input": texts, + 'encoding_type': 'float' + } + res = requests.post(self.base_url, headers=self.headers, json=data) + return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"] + + def encode_queries(self, text): + embds, cnt = self.encode([text]) + return np.array(embds[0]), cnt \ No newline at end of file diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py new file mode 100644 index 00000000000..0f4440c3fd1 --- /dev/null +++ b/rag/llm/rerank_model.py @@ -0,0 +1,113 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import re +import requests +import torch +from FlagEmbedding import FlagReranker +from huggingface_hub import snapshot_download +import os +from abc import ABC +import numpy as np +from api.utils.file_utils import get_home_cache_dir +from rag.utils import num_tokens_from_string, truncate + + +class Base(ABC): + def __init__(self, key, model_name): + pass + + def similarity(self, query: str, texts: list): + raise NotImplementedError("Please implement encode method!") + + +class DefaultRerank(Base): + _model = None + + def __init__(self, key, model_name, **kwargs): + """ + If you have trouble downloading HuggingFace models, -_^ this might help!! + + For Linux: + export HF_ENDPOINT=https://hf-mirror.com + + For Windows: + Good luck + ^_- + + """ + if not DefaultRerank._model: + try: + self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + use_fp16=torch.cuda.is_available()) + except Exception as e: + self._model = snapshot_download(repo_id=model_name, + local_dir=os.path.join(get_home_cache_dir(), + re.sub(r"^[a-zA-Z]+/", "", model_name)), + local_dir_use_symlinks=False) + self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name), + use_fp16=torch.cuda.is_available()) + + def similarity(self, query: str, texts: list): + pairs = [(query,truncate(t, 2048)) for t in texts] + token_count = 0 + for _, t in pairs: + token_count += num_tokens_from_string(t) + batch_size = 32 + res = [] + for i in range(0, len(pairs), batch_size): + scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048) + res.extend(scores) + return np.array(res), token_count + + +class JinaRerank(Base): + def __init__(self, key, model_name="jina-reranker-v1-base-en", + base_url="https://api.jina.ai/v1/rerank"): + self.base_url = "https://api.jina.ai/v1/rerank" + self.headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {key}" + } + self.model_name = model_name + + def similarity(self, query: str, texts: list): + texts = [truncate(t, 8196) for t in texts] + data = { + "model": self.model_name, + "query": query, + "documents": texts, + "top_n": len(texts) + } + res = requests.post(self.base_url, headers=self.headers, json=data) + return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"] + + +class YoudaoRerank(DefaultRerank): + _model = None + + def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): + from BCEmbedding import RerankerModel + if not YoudaoRerank._model: + try: + print("LOADING BCE...") + YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( + get_home_cache_dir(), + re.sub(r"^[a-zA-Z]+/", "", model_name))) + except Exception as e: + YoudaoRerank._model = RerankerModel( + model_name_or_path=model_name.replace( + "maidalun1020", "InfiniFlow")) + diff --git a/rag/nlp/query.py b/rag/nlp/query.py index 2cd78f1e5b5..07bd96f4edc 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -54,7 +54,8 @@ def question(self, txt, tbl="qa", min_match="60%"): if not self.isChinese(txt): tks = rag_tokenizer.tokenize(txt).split(" ") tks_w = self.tw.weights(tks) - q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w] + tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w] + q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk] for i in range(1, len(tks_w)): q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2)) if not q: @@ -136,7 +137,11 @@ def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity import numpy as np sims = CosineSimilarity([avec], bvecs) + tksim = self.token_similarity(atks, btkss) + return np.array(sims[0]) * vtweight + \ + np.array(tksim) * tkweight, tksim, sims[0] + def token_similarity(self, atks, btkss): def toDict(tks): d = {} if isinstance(tks, str): @@ -149,9 +154,7 @@ def toDict(tks): atks = toDict(atks) btkss = [toDict(tks) for tks in btkss] - tksim = [self.similarity(atks, btks) for btks in btkss] - return np.array(sims[0]) * vtweight + \ - np.array(tksim) * tkweight, tksim, sims[0] + return [self.similarity(atks, btks) for btks in btkss] def similarity(self, qtwt, dtwt): if isinstance(dtwt, type("")): diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py index be5b724b921..c728d72a65c 100644 --- a/rag/nlp/rag_tokenizer.py +++ b/rag/nlp/rag_tokenizer.py @@ -241,11 +241,14 @@ def maxBackward_(self, line): return self.score_(res[::-1]) + def english_normalize_(self, tks): + return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks] + def tokenize(self, line): line = self._strQ2B(line).lower() line = self._tradi2simp(line) zh_num = len([1 for c in line if is_chinese(c)]) - if zh_num < len(line) * 0.2: + if zh_num == 0: return " ".join([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(line)]) arr = re.split(self.SPLIT_CHAR, line) @@ -293,7 +296,7 @@ def tokenize(self, line): i = e + 1 - res = " ".join(res) + res = " ".join(self.english_normalize_(res)) if self.DEBUG: print("[TKS]", self.merge_(res)) return self.merge_(res) @@ -336,7 +339,7 @@ def fine_grained_tokenize(self, tks): res.append(stk) - return " ".join(res) + return " ".join(self.english_normalize_(res)) def is_chinese(s): diff --git a/rag/nlp/search.py b/rag/nlp/search.py index e0c7c855325..9afb3dedb3b 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -71,8 +71,8 @@ def add_filters(bqry): s = Search() pg = int(req.get("page", 1)) - 1 - ps = int(req.get("size", 1000)) topk = int(req.get("topk", 1024)) + ps = int(req.get("size", topk)) src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"]) @@ -311,6 +311,26 @@ def rerank(self, sres, query, tkweight=0.3, ins_tw, tkweight, vtweight) return sim, tksim, vtsim + def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3, + vtweight=0.7, cfield="content_ltks"): + _, keywords = self.qryr.question(query) + + for i in sres.ids: + if isinstance(sres.field[i].get("important_kwd", []), str): + sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]] + ins_tw = [] + for i in sres.ids: + content_ltks = sres.field[i][cfield].split(" ") + title_tks = [t for t in sres.field[i].get("title_tks", "").split(" ") if t] + important_kwd = sres.field[i].get("important_kwd", []) + tks = content_ltks + title_tks + important_kwd + ins_tw.append(tks) + + tksim = self.qryr.token_similarity(keywords, ins_tw) + vtsim,_ = rerank_mdl.similarity(" ".join(keywords), [rmSpace(" ".join(tks)) for tks in ins_tw]) + + return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim + def hybrid_similarity(self, ans_embd, ins_embd, ans, inst): return self.qryr.hybrid_similarity(ans_embd, ins_embd, @@ -318,17 +338,22 @@ def hybrid_similarity(self, ans_embd, ins_embd, ans, inst): rag_tokenizer.tokenize(inst).split(" ")) def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2, - vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True): + vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None): ranks = {"total": 0, "chunks": [], "doc_aggs": {}} if not question: return ranks req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size, "question": question, "vector": True, "topk": top, - "similarity": similarity_threshold} + "similarity": similarity_threshold, + "available_int": 1} sres = self.search(req, index_name(tenant_id), embd_mdl) - sim, tsim, vsim = self.rerank( - sres, question, 1 - vector_similarity_weight, vector_similarity_weight) + if rerank_mdl: + sim, tsim, vsim = self.rerank_by_model(rerank_mdl, + sres, question, 1 - vector_similarity_weight, vector_similarity_weight) + else: + sim, tsim, vsim = self.rerank( + sres, question, 1 - vector_similarity_weight, vector_similarity_weight) idx = np.argsort(sim * -1) dim = len(sres.query_vector)