From aae6f182359b785fe0d57e2087390655b7fedc27 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Sun, 28 Apr 2024 09:56:16 +0800
Subject: [PATCH] fix bug when fetching files from MinIO

---
 api/apps/file_app.py                     |  8 ++++----
 api/db/services/file2document_service.py | 17 +++++++++++++++++
 api/db/services/file_service.py          |  2 +-
 api/db/services/task_service.py          | 10 ++++++----
 rag/svr/task_broker.py                   |  7 +++++--
 rag/svr/task_executor.py                 |  5 ++++-
 6 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/api/apps/file_app.py b/api/apps/file_app.py
index 6cd9742db5..17944a9f98 100644
--- a/api/apps/file_app.py
+++ b/api/apps/file_app.py
@@ -328,12 +328,12 @@ def rename():
 # @login_required
 def get(file_id):
     try:
-        e, doc = FileService.get_by_id(file_id)
+        e, file = FileService.get_by_id(file_id)
         if not e:
             return get_data_error_result(retmsg="Document not found!")
-        response = flask.make_response(MINIO.get(doc.parent_id, doc.location))
-        ext = re.search(r"\.([^.]+)$", doc.name)
+        response = flask.make_response(MINIO.get(file.parent_id, file.location))
+        ext = re.search(r"\.([^.]+)$", file.name)
         if ext:
-            if doc.type == FileType.VISUAL.value:
+            if file.type == FileType.VISUAL.value:
                 response.headers.set('Content-Type', 'image/%s' % ext.group(1))

diff --git a/api/db/services/file2document_service.py b/api/db/services/file2document_service.py
index b53e0adffb..18ec03d316 100644
--- a/api/db/services/file2document_service.py
+++ b/api/db/services/file2document_service.py
@@ -18,6 +18,8 @@
 from api.db.db_models import DB
 from api.db.db_models import File, Document, File2Document
 from api.db.services.common_service import CommonService
+from api.db.services.document_service import DocumentService
+from api.db.services.file_service import FileService
 from api.utils import current_timestamp, datetime_format


@@ -64,3 +66,18 @@ def update_by_file_id(cls, file_id, obj):
         num = cls.model.update(obj).where(cls.model.id == file_id).execute()
         e, obj = cls.get_by_id(cls.model.id)
         return obj
+
+    @classmethod
+    @DB.connection_context()
+    def get_minio_address(cls, doc_id=None, file_id=None):
+        if doc_id:
+            ids = File2DocumentService.get_by_document_id(doc_id)
+        else:
+            ids = File2DocumentService.get_by_file_id(file_id)
+        if ids:
+            e, file = FileService.get_by_id(ids[0].file_id)
+            return file.parent_id, file.location
+        else:
+            assert doc_id, "please specify doc_id"
+            e, doc = DocumentService.get_by_id(doc_id)
+            return doc.kb_id, doc.location

diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py
index abb6e56c39..57948d4211 100644
--- a/api/db/services/file_service.py
+++ b/api/db/services/file_service.py
@@ -21,7 +21,6 @@
 from api.db.db_models import File, Document
 from api.db.services.common_service import CommonService
 from api.utils import get_uuid
-from rag.utils import MINIO


 class FileService(CommonService):
@@ -241,3 +240,4 @@ def dfs(parent_id):

         dfs(folder_id)
         return size
+

diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 8c6bc6e8db..ccc837a038 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -15,8 +15,8 @@
 #
 import random

-from peewee import Expression
-from api.db.db_models import DB
+from peewee import Expression, JOIN
+from api.db.db_models import DB, File2Document, File
 from api.db import StatusEnum, FileType, TaskStatus
 from api.db.db_models import Task, Document, Knowledgebase, Tenant
 from api.db.services.common_service import CommonService
@@ -75,8 +75,10 @@ def get_tasks(cls, tm, mod=0, comm=1, items_per_page=1, takeit=True):
     @DB.connection_context()
     def get_ongoing_doc_name(cls):
         with DB.lock("get_task", -1):
-            docs = cls.model.select(*[Document.kb_id, Document.location]) \
+            docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
                 .join(Document, on=(cls.model.doc_id == Document.id)) \
+                .join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
+                .join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
                 .where(
                     Document.status == StatusEnum.VALID.value,
                     Document.run == TaskStatus.RUNNING.value,
@@ -88,7 +90,7 @@ def get_ongoing_doc_name(cls):
             docs = list(docs.dicts())
             if not docs:
                 return []
-            return list(set([(d["kb_id"], d["location"]) for d in docs]))
+            return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))

     @classmethod
     @DB.connection_context()

diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py
index 3e43fbff29..d7b57d586d 100644
--- a/rag/svr/task_broker.py
+++ b/rag/svr/task_broker.py
@@ -20,6 +20,8 @@
 from datetime import datetime
 from api.db.db_models import Task
 from api.db.db_utils import bulk_insert_into_db
+from api.db.services.file2document_service import File2DocumentService
+from api.db.services.file_service import FileService
 from api.db.services.task_service import TaskService
 from deepdoc.parser import PdfParser
 from deepdoc.parser.excel_parser import HuExcelParser
@@ -87,10 +89,11 @@ def new_task():

     tsks = []
     try:
-        file_bin = MINIO.get(r["kb_id"], r["location"])
+        bucket, name = File2DocumentService.get_minio_address(doc_id=r["id"])
+        file_bin = MINIO.get(bucket, name)
         if REDIS_CONN.is_alive():
             try:
-                REDIS_CONN.set("{}/{}".format(r["kb_id"], r["location"]), file_bin, 12*60)
+                REDIS_CONN.set("{}/{}".format(bucket, name), file_bin, 12*60)
             except Exception as e:
                 cron_logger.warning("Put into redis[EXCEPTION]:" + str(e))

diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index b72b1c556b..032d9ea58b 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -24,6 +24,8 @@
 import time
 import traceback
 from functools import partial
+
+from api.db.services.file2document_service import File2DocumentService
 from rag.utils import MINIO
 from api.db.db_models import close_connection
 from rag.settings import database_logger
@@ -135,7 +137,8 @@ def build(row):
     pool = Pool(processes=1)
     try:
         st = timer()
-        thr = pool.apply_async(get_minio_binary, args=(row["kb_id"], row["location"]))
+        bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
+        thr = pool.apply_async(get_minio_binary, args=(bucket, name))
         binary = thr.get(timeout=90)
         pool.terminate()
         cron_logger.info(
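
Note: with this patch, every consumer resolves a document's MinIO (bucket,
object) pair through File2DocumentService.get_minio_address() instead of
assuming (kb_id, location). File.parent_id serves as the bucket for files
uploaded through the file manager, and the helper falls back to
(Document.kb_id, Document.location) when no File2Document mapping exists.
A minimal usage sketch (fetch_doc_binary is a hypothetical helper, not part
of the patch):

    from api.db.services.file2document_service import File2DocumentService
    from rag.utils import MINIO

    def fetch_doc_binary(doc_id):
        # Prefer the file-manager location (File.parent_id as bucket);
        # get_minio_address() falls back to the knowledge-base location
        # when the document has no File2Document mapping.
        bucket, name = File2DocumentService.get_minio_address(doc_id=doc_id)
        return MINIO.get(bucket, name)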
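The broker caches each fetched binary in Redis under the key
"{bucket}/{name}" with a short TTL, and task_executor's get_minio_binary()
presumably reads through the same key before hitting MinIO. Its body is not
shown in this patch, so the following is only a sketch of the assumed
cache-first read path:

    def get_minio_binary(bucket, name):
        # Assumed read path: check the Redis entry the broker populated
        # under "{bucket}/{name}"; fall back to MinIO on a cache miss.
        if REDIS_CONN.is_alive():
            try:
                cached = REDIS_CONN.get("{}/{}".format(bucket, name))
                if cached:
                    return cached
            except Exception as e:
                cron_logger.warning("Get from redis[EXCEPTION]:" + str(e))
        return MINIO.get(bucket, name)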