From 5aeef104e2765641be960395c5e958d375e92410 Mon Sep 17 00:00:00 2001 From: Zhichang Yu Date: Wed, 11 Dec 2024 11:53:10 +0800 Subject: [PATCH] Try to reuse existing chunks. Close #3793 --- api/apps/document_app.py | 10 ++--- api/db/db_models.py | 2 + api/db/services/task_service.py | 52 ++++++++++++++++++++--- rag/svr/task_executor.py | 73 +++++++++++++++++++++++++-------- 4 files changed, 109 insertions(+), 28 deletions(-) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 92534a501d0..5c8541f7890 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -21,10 +21,10 @@ from flask import request from flask_login import login_required, current_user -from api.db.db_models import Task, File +from api.db.db_models import File from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService -from api.db.services.task_service import TaskService, queue_tasks +from api.db.services.task_service import queue_tasks from api.db.services.user_service import UserTenantService from deepdoc.parser.html_parser import RAGFlowHtmlParser from rag.nlp import search @@ -361,11 +361,11 @@ def run(): e, doc = DocumentService.get_by_id(id) if not e: return get_data_error_result(message="Document not found!") - if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): - settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) + if req.get("delete", False): + if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) if str(req["run"]) == TaskStatus.RUNNING.value: - TaskService.filter_delete([Task.doc_id == id]) e, doc = DocumentService.get_by_id(id) doc = doc.to_dict() doc["tenant_id"] = tenant_id diff --git a/api/db/db_models.py b/api/db/db_models.py index 0c4d12c034c..1b69a870239 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ 
-855,6 +855,8 @@ class Task(DataBaseModel): help_text="process message", default="") retry_count = IntegerField(default=0) + digest = TextField(null=True, help_text="task digest", default="") + chunk_ids = TextField(null=True, help_text="chunk ids", default="") class Dialog(DataBaseModel): diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 424a571ee57..ce9cca28a22 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -30,6 +30,18 @@ from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.redis_conn import REDIS_CONN +def trim_header_by_lines(text: str, max_length) -> str: + if len(text) <= max_length: + return text + lines = text.split("\n") + total = 0 + idx = len(lines) - 1 + for i in range(len(lines)-1, -1, -1): + if (total := total + len(lines[i])) > max_length: + break + idx = i + text2 = "\n".join(lines[idx:]) + return text2 class TaskService(CommonService): model = Task @@ -87,6 +99,34 @@ def get_task(cls, task_id): return docs[0] + @classmethod + @DB.connection_context() + def get_task2(cls, doc_id: str, from_page: int, to_page: int): + fields = [ + cls.model.id, + cls.model.progress, + cls.model.digest, + cls.model.chunk_ids, + ] + tasks = ( + cls.model.select(*fields) + .where(cls.model.doc_id == doc_id, cls.model.from_page == from_page, cls.model.to_page == to_page) + ) + tasks = list(tasks.dicts()) + if not tasks: + return None + return tasks[0] + + @classmethod + @DB.connection_context() + def update_digest(cls, id: str, digest: str): + cls.model.update(digest=digest).where(cls.model.id == id).execute() + + @classmethod + @DB.connection_context() + def update_chunk_ids(cls, id: str, chunk_ids: str): + cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute() + @classmethod @DB.connection_context() def get_ongoing_doc_name(cls): @@ -146,9 +186,9 @@ def do_cancel(cls, id): def update_progress(cls, id, info): if os.environ.get("MACOS"): if info["progress_msg"]: -
cls.model.update( - progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"] - ).where(cls.model.id == id).execute() + task = cls.model.get_by_id(id) + progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000) + cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: cls.model.update(progress=info["progress"]).where( cls.model.id == id @@ -157,9 +197,9 @@ def update_progress(cls, id, info): with DB.lock("update_progress", -1): if info["progress_msg"]: - cls.model.update( - progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"] - ).where(cls.model.id == id).execute() + task = cls.model.get_by_id(id) + progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000) + cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: cls.model.update(progress=info["progress"]).where( cls.model.id == id diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 902c1e31aef..79f6de57112 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -57,6 +57,7 @@ from rag.utils import rmSpace, num_tokens_from_string from rag.utils.redis_conn import REDIS_CONN, Payload from rag.utils.storage_factory import STORAGE_IMPL +from rag.utils.doc_store_conn import OrderByExpr BATCH_SIZE = 64 @@ -89,6 +90,9 @@ FAILED_TASKS = 0 CURRENT_TASK = None +class TaskCanceledException(Exception): + def __init__(self, msg): + self.msg = msg def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): global PAYLOAD @@ -112,11 +116,10 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing... 
logging.exception(f"set_progress({task_id}) got exception") close_connection() - if cancel: - if PAYLOAD: - PAYLOAD.ack() - PAYLOAD = None - os._exit(0) + if cancel: + if PAYLOAD: + PAYLOAD.ack(); PAYLOAD = None + raise TaskCanceledException(msg) def collect(): @@ -358,6 +361,37 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None): return res, tk_count, vector_size +def reuse_prev_task_chunks(task: dict) -> bool: + md5 = hashlib.md5() + for field in ["task_type", "tenant_id", "kb_id", "doc_id", "name", "from_page", "to_page", "parser_config", "embd_id", + "language", "llm_id"]: + md5.update(str(task.get(field, "")).encode("utf-8")) + task_digest = md5.hexdigest() + TaskService.update_digest(task["id"], task_digest) + + prev_task = TaskService.get_task2(task["doc_id"], task["from_page"], task["to_page"]) + if prev_task is None: + return False + chunk_ids = prev_task["chunk_ids"] + chunk_ids = [x for x in chunk_ids.split() if x] + reusable = False + if prev_task["progress"] == 1.0 and prev_task["digest"] == task_digest and chunk_ids: + tenant_id = task["tenant_id"] + kb_ids = [task["kb_id"]] + res = settings.docStoreConn.search(["id"], [], {"id": chunk_ids}, [], OrderByExpr(), 0, len(chunk_ids), search.index_name(tenant_id), kb_ids) + dict_chunks = settings.docStoreConn.getFields(res, ["id"]) + if len(chunk_ids) == len(dict_chunks): + reusable = True + if reusable: + TaskService.update_chunk_ids(task["id"], " ".join(chunk_ids)) + return True + + if chunk_ids: + settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task["tenant_id"]), [task["kb_id"]]) + return False + + + def do_handle_task(task): task_id = task["id"] task_from_page = task["from_page"] @@ -373,6 +407,16 @@ def do_handle_task(task): # prepare the progress callback function progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) + + task_canceled = TaskService.do_cancel(task_id) + if task_canceled: + progress_callback(1.0, msg="Task has been canceled.") + return
+ reused = reuse_prev_task_chunks(task) + if reused: + progress_callback(1.0, msg="Chunks of task already exist, skip.") + return + try: # bind embedding model embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language) @@ -420,6 +464,7 @@ def do_handle_task(task): progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts) logging.info(progress_message) progress_callback(msg=progress_message) + # logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}") init_kb(task, vector_size) chunk_count = len(set([chunk["id"] for chunk in chunks])) @@ -430,23 +475,17 @@ def do_handle_task(task): doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id) if b % 128 == 0: progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") - logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts)) - if doc_store_result: - error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" - progress_callback(-1, msg=error_message) - settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id) - logging.error(error_message) - raise Exception(error_message) - - if TaskService.do_cancel(task_id): - settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id) - return + if doc_store_result: + error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" 
+ progress_callback(-1, msg=error_message) + raise Exception(error_message) + logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts)) DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) time_cost = timer() - start_ts progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost)) - logging.info("Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(task_id, token_count, len(chunks), time_cost)) + logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost)) def handle_task():