add function: upload and parse (infiniflow#1889)
### What problem does this PR solve?

infiniflow#1880
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
KevinHuSh authored Aug 9, 2024
1 parent 9492d70 commit ab9a43b
Showing 8 changed files with 255 additions and 89 deletions.
2 changes: 2 additions & 0 deletions api/apps/conversation_app.py
@@ -118,6 +118,8 @@ def completion():
if m["role"] == "assistant" and not msg:
continue
msg.append({"role": m["role"], "content": m["content"]})
if "doc_ids" in m:
msg[-1]["doc_ids"] = m["doc_ids"]
try:
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
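With this hunk, completion() forwards per-message document attachments to the dialog service: any doc_ids on an incoming message is copied onto the message it appends. A hypothetical request body illustrating the shape (field names inferred from the loop above; all ids are placeholders):

    req = {
        "conversation_id": "<conv-id>",
        "messages": [
            {"role": "user",
             "content": "What does the uploaded contract say about termination?",
             "doc_ids": ["<doc-id-1>", "<doc-id-2>"]}   # pins retrieval to these documents
        ]
    }
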
204 changes: 148 additions & 56 deletions api/apps/document_app.py
@@ -13,33 +13,43 @@
# See the License for the specific language governing permissions and
# limitations under the License
#

import datetime
import hashlib
import json
import os
import pathlib
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from io import BytesIO

import flask
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user

from api.db.db_models import Task, File
from api.db.services.dialog_service import DialogService, ConversationService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import TenantService
from graphrag.mind_map_extractor import MindMapExtractor
from rag.app import naive
from rag.nlp import search
from rag.utils.es_conn import ELASTICSEARCH
from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
- from api.db import FileType, TaskStatus, ParserType, FileSource
+ from api.db import FileType, TaskStatus, ParserType, FileSource, LLMType
from api.db.services.document_service import DocumentService
- from api.settings import RetCode
+ from api.settings import RetCode, stat_logger
from api.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
- from api.utils.file_utils import filename_type, thumbnail
- from api.utils.web_utils import html2pdf, is_valid_url
+ from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
+ from api.utils.web_utils import html2pdf, is_valid_url


@@ -65,55 +75,7 @@ def upload():
if not e:
raise LookupError("Can't find this knowledgebase!")

- root_folder = FileService.get_root_folder(current_user.id)
- pf_id = root_folder["id"]
- FileService.init_knowledgebase_docs(pf_id, current_user.id)
- kb_root_folder = FileService.get_kb_folder(current_user.id)
- kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
-
- err = []
- for file in file_objs:
- try:
- MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
- if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
- raise RuntimeError("Exceed the maximum file number of a free user!")
-
- filename = duplicate_name(
- DocumentService.query,
- name=file.filename,
- kb_id=kb.id)
- filetype = filename_type(filename)
- if filetype == FileType.OTHER.value:
- raise RuntimeError("This type of file has not been supported yet!")
-
- location = filename
- while MINIO.obj_exist(kb_id, location):
- location += "_"
- blob = file.read()
- MINIO.put(kb_id, location, blob)
- doc = {
- "id": get_uuid(),
- "kb_id": kb.id,
- "parser_id": kb.parser_id,
- "parser_config": kb.parser_config,
- "created_by": current_user.id,
- "type": filetype,
- "name": filename,
- "location": location,
- "size": len(blob),
- "thumbnail": thumbnail(filename, blob)
- }
- if doc["type"] == FileType.VISUAL:
- doc["parser_id"] = ParserType.PICTURE.value
- if doc["type"] == FileType.AURAL:
- doc["parser_id"] = ParserType.AUDIO.value
- if re.search(r"\.(ppt|pptx|pages)$", filename):
- doc["parser_id"] = ParserType.PRESENTATION.value
- DocumentService.insert(doc)
-
- FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
- except Exception as e:
- err.append(file.filename + ": " + str(e))
+ err, _ = FileService.upload_document(kb, file_objs)
if err:
return get_json_result(
data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
@@ -149,7 +111,7 @@ def web_crawl():
try:
filename = duplicate_name(
DocumentService.query,
name=name+".pdf",
name=name + ".pdf",
kb_id=kb.id)
filetype = filename_type(filename)
if filetype == FileType.OTHER.value:
@@ -414,7 +376,7 @@ def get(doc_id):
if not e:
return get_data_error_result(retmsg="Document not found!")

- b,n = File2DocumentService.get_minio_address(doc_id=doc_id)
+ b, n = File2DocumentService.get_minio_address(doc_id=doc_id)
response = flask.make_response(MINIO.get(b, n))

ext = re.search(r"\.([^.]+)$", doc.name)
@@ -484,3 +446,133 @@ def get_image(image_id):
return response
except Exception as e:
return server_error_response(e)


@manager.route('/upload_and_parse', methods=['POST'])
@login_required
@validate_request("conversation_id")
def upload_and_parse():
req = request.json
if 'file' not in request.files:
return get_json_result(
data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)

file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)

e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(retmsg="Conversation not found!")
e, dia = DialogService.get_by_id(conv.dialog_id)
kb_id = dia.kb_ids[0]
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this knowledgebase!")

idxnm = search.index_name(kb.tenant_id)
if not ELASTICSEARCH.indexExist(idxnm):
ELASTICSEARCH.createIdx(idxnm, json.load(
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))

embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)

err, files = FileService.upload_document(kb, file_objs)
if err:
return get_json_result(
data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)

def dummy(prog=None, msg=""):
pass

parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?。;!?", "layout_recognize": False}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
for d, blob in files:
kwargs = {
"callback": dummy,
"parser_config": parser_config,
"from_page": 0,
"to_page": 100000
}
threads.append(exe.submit(naive.chunk, d["name"], blob, **kwargs))

for (docinfo,_), th in zip(files, threads):
docs = []
doc = {
"doc_id": docinfo["id"],
"kb_id": [kb.id]
}
for ck in th.result():
d = deepcopy(doc)
d.update(ck)
md5 = hashlib.md5()
md5.update((ck["content_with_weight"] +
str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
if not d.get("image"):
docs.append(d)
continue

output_buffer = BytesIO()
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
d["image"].save(output_buffer, format='JPEG')

MINIO.put(kb.id, d["_id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["_id"])
del d["image"]
docs.append(d)

parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
docids = [d["id"] for d, _ in files]
chunk_counts = {id: 0 for id in docids}
token_counts = {id: 0 for id in docids}
es_bulk_size = 64

def embedding(doc_id, cnts, batch_size=16):
nonlocal embd_mdl, chunk_counts, token_counts
vects = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
vects.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
token_counts[doc_id] += c
return vects

_, tenant = TenantService.get_by_id(kb.tenant_id)
llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id)
for doc_id in docids:
cks = [c for c in docs if c["doc_id"] == doc_id]

if parser_ids[doc_id] != ParserType.PICTURE.value:
mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, ensure_ascii=False, indent=2)
if len(mind_map) < 32: raise Exception("Few content: "+mind_map)
cks.append({
"doc_id": doc_id,
"kb_id": [kb.id],
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map"
})
except Exception as e:
stat_logger.error("Mind map generation error:", traceback.format_exc())

vects = embedding(doc_id, cks)
assert len(cks) == len(vects)
for i, d in enumerate(cks):
v = vects[i]
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm)

DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

return get_json_result(data=[d["id"] for d, _ in files])
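
A minimal client-side sketch of the new endpoint. The URL prefix, port, and auth header are assumptions (they depend on how `manager` is mounted and on the deployment); note the handler reads request.json for conversation_id while the files arrive as multipart, so the exact transport of the id may depend on the validate_request helper. Here it is sent as a form field:

    import requests

    resp = requests.post(
        "http://localhost:9380/v1/document/upload_and_parse",  # assumed mount point
        headers={"Authorization": "Bearer <token>"},           # endpoint is @login_required
        data={"conversation_id": "<existing-conversation-id>"},
        files=[("file", ("report.pdf", open("report.pdf", "rb")))],
    )
    print(resp.json())  # on success, data carries the new document ids

Each uploaded file is chunked by naive.chunk in a thread pool, embedded in batches, optionally given a mind-map chunk, and bulk-indexed into Elasticsearch before the ids are returned.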
10 changes: 7 additions & 3 deletions api/db/services/dialog_service.py
@@ -104,7 +104,11 @@ def chat(dialog, messages, stream=True, **kwargs):
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler

questions = [m["content"] for m in messages if m["role"] == "user"]
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
if "doc_ids" in messages[-1]:
attachments = messages[-1]["doc_ids"]

embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
if llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
@@ -144,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs):
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
- doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
+ doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
#self-rag
@@ -153,7 +157,7 @@
kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
- doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
+ doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]

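Distilled from the hunks above: retrieval now sees only the three most recent user turns, and a doc_ids list on the last message overrides the comma-separated kwargs form. A standalone sketch of that precedence (function name is illustrative):

    def resolve_retrieval_inputs(messages, **kwargs):
        # Only the three most recent user turns feed retrieval.
        questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
        # kwargs may carry doc ids as a comma-separated string...
        attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
        # ...but an explicit doc_ids list on the last message takes precedence.
        if "doc_ids" in messages[-1]:
            attachments = messages[-1]["doc_ids"]
        return questions, attachments
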
2 changes: 1 addition & 1 deletion api/db/services/document_service.py
@@ -26,7 +26,7 @@
from rag.utils.minio_conn import MINIO
from rag.nlp import search

- from api.db import FileType, TaskStatus
+ from api.db import FileType, TaskStatus, ParserType
from api.db.db_models import DB, Knowledgebase, Tenant, Task
from api.db.db_models import Document
from api.db.services.common_service import CommonService
65 changes: 63 additions & 2 deletions api/db/services/file_service.py
@@ -13,16 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re

from flask_login import current_user
from peewee import fn

- from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource
+ from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource, ParserType
from api.db.db_models import DB, File2Document, Knowledgebase
from api.db.db_models import File, Document
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils import get_uuid
from api.utils.file_utils import filename_type, thumbnail
from rag.utils.minio_conn import MINIO


class FileService(CommonService):
@@ -318,4 +323,60 @@ def move_file(cls, file_ids, folder_id):
cls.filter_update((cls.model.id << file_ids, ), { 'parent_id': folder_id })
except Exception as e:
print(e)
raise RuntimeError("Database error (File move)!")
raise RuntimeError("Database error (File move)!")

@classmethod
@DB.connection_context()
def upload_document(self, kb, file_objs):
root_folder = self.get_root_folder(current_user.id)
pf_id = root_folder["id"]
self.init_knowledgebase_docs(pf_id, current_user.id)
kb_root_folder = self.get_kb_folder(current_user.id)
kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])

err, files = [], []
for file in file_objs:
try:
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
raise RuntimeError("Exceed the maximum file number of a free user!")

filename = duplicate_name(
DocumentService.query,
name=file.filename,
kb_id=kb.id)
filetype = filename_type(filename)
if filetype == FileType.OTHER.value:
raise RuntimeError("This type of file has not been supported yet!")

location = filename
while MINIO.obj_exist(kb.id, location):
location += "_"
blob = file.read()
MINIO.put(kb.id, location, blob)
doc = {
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": filetype,
"name": filename,
"location": location,
"size": len(blob),
"thumbnail": thumbnail(filename, blob)
}
if doc["type"] == FileType.VISUAL:
doc["parser_id"] = ParserType.PICTURE.value
if doc["type"] == FileType.AURAL:
doc["parser_id"] = ParserType.AUDIO.value
if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value
DocumentService.insert(doc)

FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
files.append((doc, blob))
except Exception as e:
err.append(file.filename + ": " + str(e))

return err, files
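
A hedged sketch of how the refactored upload() above consumes this helper: per-file failures come back as "<filename>: <reason>" strings in err, successes as (document_dict, raw_bytes) tuples in files. Note the helper reads MAX_FILE_NUM_PER_USER via os.environ, so it relies on an import os at the top of the module that is not visible in the hunk shown:

    err, files = FileService.upload_document(kb, file_objs)
    if err:
        print("\n".join(err))                      # aggregated per-file errors
    for doc, blob in files:
        print(doc["id"], doc["name"], len(blob))   # stored metadata and raw bytes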