From 4cb0b4401f9dc7698e7be1aaf79f60a8a42060f7 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 6 Jun 2024 11:09:06 +0800 Subject: [PATCH] add self-rag --- api/apps/api_app.py | 44 ++++++------ api/apps/canvas_app.py | 112 ++++++++++++++++++++++++++++++ api/apps/conversation_app.py | 5 +- api/apps/dialog_app.py | 4 +- api/db/services/canvas_service.py | 26 +++++++ api/db/services/dialog_service.py | 69 +++++++++++++++--- deepdoc/parser/pdf_parser.py | 2 + rag/llm/rerank_model.py | 1 - rag/nlp/query.py | 9 ++- 9 files changed, 234 insertions(+), 38 deletions(-) create mode 100644 api/apps/canvas_app.py create mode 100644 api/db/services/canvas_service.py diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 683e62adc74..e7ffcdeac34 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -198,15 +198,18 @@ def fillin_conv(ans): else: conv.reference[-1] = ans["reference"] conv.message[-1] = {"role": "assistant", "content": ans["answer"]} + def rename_field(ans): + for chunk_i in ans['reference'].get('chunks', []): + chunk_i['doc_name'] = chunk_i['docnm_kwd'] + chunk_i.pop('docnm_kwd') + def stream(): nonlocal dia, msg, req, conv try: for ans in chat(dia, msg, True, **req): fillin_conv(ans) - for chunk_i in ans['reference'].get('chunks', []): - chunk_i['doc_name'] = chunk_i['docnm_kwd'] - chunk_i.pop('docnm_kwd') - yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" + rename_field(rename_field) + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" API4ConversationService.append_message(conv.id, conv.to_dict()) except Exception as e: yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), @@ -554,23 +557,24 @@ def fillin_conv(ans): "content": "" } ] - for ans in chat(dia, msg, stream=False, **req): - # answer = ans - data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"]) - fillin_conv(ans) - API4ConversationService.append_message(conv.id, conv.to_dict()) - - chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])] - for chunk_idx in chunk_idxs[:1]: - if ans["reference"]["chunks"][chunk_idx]["img_id"]: - try: - bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-") - response = MINIO.get(bkt, nm) - data_type_picture["url"] = base64.b64encode(response).decode('utf-8') - data.append(data_type_picture) - except Exception as e: - return server_error_response(e) + ans = "" + for a in chat(dia, msg, stream=False, **req): + ans = a break + data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"]) + fillin_conv(ans) + API4ConversationService.append_message(conv.id, conv.to_dict()) + + chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])] + for chunk_idx in chunk_idxs[:1]: + if ans["reference"]["chunks"][chunk_idx]["img_id"]: + try: + bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-") + response = MINIO.get(bkt, nm) + data_type_picture["url"] = base64.b64encode(response).decode('utf-8') + data.append(data_type_picture) + except Exception as e: + return server_error_response(e) response = {"code": 200, "msg": "success", "data": data} return response diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py new file mode 100644 index 00000000000..612a05fda97 --- /dev/null +++ b/api/apps/canvas_app.py @@ -0,0 +1,112 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json + +from flask import request +from flask_login import login_required, current_user + +from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService +from api.utils import get_uuid +from api.utils.api_utils import get_json_result, server_error_response, validate_request +from graph.canvas import Canvas + + +@manager.route('/templates', methods=['GET']) +@login_required +def templates(): + return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()]) + + +@manager.route('/list', methods=['GET']) +@login_required +def canvas_list(): + + return get_json_result(data=[c.to_dict() for c in UserCanvasService.query(user_id=current_user.id)]) + + +@manager.route('/rm', methods=['POST']) +@validate_request("canvas_ids") +@login_required +def rm(): + for i in request.json["canvas_ids"]: + UserCanvasService.delete_by_id(i) + return get_json_result(data=True) + + +@manager.route('/set', methods=['POST']) +@validate_request("dsl", "title") +@login_required +def save(): + req = request.json + req["user_id"] = current_user.id + if not isinstance(req["dsl"], str):req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) + try: + Canvas(req["dsl"]) + except Exception as e: + return server_error_response(e) + + req["dsl"] = json.loads(req["dsl"]) + if "id" not in req: + req["id"] = get_uuid() + if not UserCanvasService.save(**req): + return server_error_response("Fail to save canvas.") + else: + UserCanvasService.update_by_id(req["id"], req) + + return get_json_result(data=req) + + +@manager.route('/get/', methods=['GET']) +@login_required +def get(canvas_id): + e, c = UserCanvasService.get_by_id(canvas_id) + if not e: + return server_error_response("canvas not found.") + return get_json_result(data=c.to_dict()) + + +@manager.route('/run', methods=['POST']) +@validate_request("id", "dsl") +@login_required +def run(): + req = request.json + if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) + try: + canvas = Canvas(req["dsl"], current_user.id) + ans = canvas.run() + req["dsl"] = json.loads(str(canvas)) + UserCanvasService.update_by_id(req["id"], dsl=req["dsl"]) + return get_json_result(data=req["dsl"]) + except Exception as e: + return server_error_response(e) + + +@manager.route('/reset', methods=['POST']) +@validate_request("canvas_id") +@login_required +def reset(): + req = request.json + try: + user_canvas = UserCanvasService.get_by_id(req["canvas_id"]) + canvas = Canvas(req["dsl"], current_user.id) + canvas.reset() + req["dsl"] = json.loads(str(canvas)) + UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"]) + return get_json_result(data=req["dsl"]) + except Exception as e: + return server_error_response(e) + + diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 1400d6623ac..c3e07c6bb3e 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from flask import request, Response, jsonify +from copy import deepcopy +from flask import request, Response from flask_login import login_required from api.db.services.dialog_service import DialogService, ConversationService, chat from api.utils.api_utils import server_error_response, get_data_error_result, validate_request @@ -121,7 +122,7 @@ def completion(): e, conv = ConversationService.get_by_id(req["conversation_id"]) if not e: return get_data_error_result(retmsg="Conversation not found!") - conv.message.append(msg[-1]) + conv.message.append(deepcopy(msg[-1])) e, dia = DialogService.get_by_id(conv.dialog_id) if not e: return get_data_error_result(retmsg="Dialog not found!") diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 0ec90f0165b..ce428947ea0 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -31,8 +31,8 @@ def set_dialog(): req = request.json dialog_id = req.get("dialog_id") name = req.get("name", "New Dialog") - icon = req.get("icon", "") description = req.get("description", "A helpful Dialog") + icon = req.get("icon", "") top_n = req.get("top_n", 6) top_k = req.get("top_k", 1024) rerank_id = req.get("rerank_id", "") @@ -92,7 +92,7 @@ def set_dialog(): "rerank_id": rerank_id, "similarity_threshold": similarity_threshold, "vector_similarity_weight": vector_similarity_weight, - "icon": icon, + "icon": icon } if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!") diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py new file mode 100644 index 00000000000..ed2cdf63a0b --- /dev/null +++ b/api/db/services/canvas_service.py @@ -0,0 +1,26 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from datetime import datetime +import peewee +from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas +from api.db.services.common_service import CommonService + + +class CanvasTemplateService(CommonService): + model = CanvasTemplate + +class UserCanvasService(CommonService): + model = UserCanvas diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 92121a0dc1e..3573df967ce 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -23,6 +23,7 @@ from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle from api.settings import chat_logger, retrievaler from rag.app.resume import forbidden_select_fields4resume +from rag.nlp.rag_tokenizer import is_chinese from rag.nlp.search import index_name from rag.utils import rmSpace, num_tokens_from_string, encoder @@ -80,7 +81,8 @@ def chat(dialog, messages, stream=True, **kwargs): if not llm: raise LookupError("LLM(%s) not found" % dialog.llm_id) max_tokens = 1024 - else: max_tokens = llm[0].max_tokens + else: + max_tokens = llm[0].max_tokens kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) embd_nms = list(set([kb.embd_id for kb in kbs])) if len(embd_nms) != 1: @@ -124,6 +126,16 @@ def chat(dialog, messages, stream=True, **kwargs): doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None, top=1024, aggs=False, rerank_mdl=rerank_mdl) knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] + #self-rag + if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges): + questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1]) + kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, + dialog.similarity_threshold, + dialog.vector_similarity_weight, + doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None, + top=1024, aggs=False, rerank_mdl=rerank_mdl) + knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] + chat_logger.info( "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) @@ -136,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs): msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}] msg.extend([{"role": m["role"], "content": m["content"]} - for m in messages if m["role"] != "system"]) + for m in messages if m["role"] != "system"]) used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97)) assert len(msg) >= 2, f"message_fit_in has bug: {msg}" @@ -150,9 +162,9 @@ def decorate_answer(answer): if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] - for ck in kbinfos["chunks"]], + for ck in kbinfos["chunks"]], [ck["vector"] - for ck in kbinfos["chunks"]], + for ck in kbinfos["chunks"]], embd_mdl, tkweight=1 - dialog.vector_similarity_weight, vtweight=dialog.vector_similarity_weight) @@ -166,7 +178,7 @@ def decorate_answer(answer): for c in refs["chunks"]: if c.get("vector"): del c["vector"] - if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0: + if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" return {"answer": answer, "reference": refs} @@ -204,7 +216,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): def get_table(): nonlocal sys_prompt, user_promt, question, tried_times sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], { - "temperature": 0.06}) + "temperature": 0.06}) print(user_promt, sql) chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}") sql = re.sub(r"[\r\n]+", " ", sql.lower()) @@ -273,17 +285,19 @@ def get_table(): # compose markdown table clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], - tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|") + tbl["columns"][i]["name"])) for i in + clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|") line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \ - ("|------|" if docid_idx and docid_idx else "") + ("|------|" if docid_idx and docid_idx else "") rows = ["|" + "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]] if quota: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) - else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) + else: + rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) if not docid_idx or not docnm_idx: @@ -303,5 +317,40 @@ def get_table(): return { "answer": "\n".join([clmns, line, rows]), "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]], - "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]} + "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in + doc_aggs.items()]} } + + +def relevant(tenant_id, llm_id, question, contents: list): + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) + prompt = """ + You are a grader assessing relevance of a retrieved document to a user question. + It does not need to be a stringent test. The goal is to filter out erroneous retrievals. + If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. + Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. + No other words needed except 'yes' or 'no'. + """ + if not contents:return False + contents = "Documents: \n" + " - ".join(contents) + contents = f"Question: {question}\n" + contents + if num_tokens_from_string(contents) >= chat_mdl.max_length - 4: + contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4]) + ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01}) + if ans.lower().find("yes") >= 0: return True + return False + + +def rewrite(tenant_id, llm_id, question): + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) + prompt = """ + You are an expert at query expansion to generate a paraphrasing of a question. + I can't retrieval relevant information from the knowledge base by using user's question directly. + You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase, + writing the abbreviation in its entirety, adding some extra descriptions or explanations, + changing the way of expression, translating the original question into another language (English/Chinese), etc. + And return 5 versions of question and one is from translation. + Just list the question. No other words are needed. + """ + ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8}) + return ans diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index a33c71662c9..ecfef81072a 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -1021,6 +1021,8 @@ def dfs(arr, depth): self.page_cum_height = np.cumsum(self.page_cum_height) assert len(self.page_cum_height) == len(self.page_images) + 1 + if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from, + page_to, callback) def __call__(self, fnm, need_image=True, zoomin=3, return_html=False): self.__images__(fnm, zoomin) diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 783b629685e..33ce26fc7df 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -129,4 +129,3 @@ def similarity(self, query: str, texts: list): return np.array(res), token_count - diff --git a/rag/nlp/query.py b/rag/nlp/query.py index 3ab3106f0fb..7485f19529d 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -48,7 +48,7 @@ def isChinese(line): @staticmethod def rmWWW(txt): patts = [ - (r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""), + (r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""), (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "), (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ") ] @@ -68,7 +68,9 @@ def question(self, txt, tbl="qa", min_match="60%"): if not self.isChinese(txt): tks = rag_tokenizer.tokenize(txt).split(" ") tks_w = self.tw.weights(tks) - tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w] + tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w] + tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk] + tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk] q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk] for i in range(1, len(tks_w)): q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2)) @@ -118,7 +120,8 @@ def need_fine_grained_tokenize(tk): if sm: tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % ( " ".join(sm), " ".join(sm)) - tms.append((tk, w)) + if tk.strip(): + tms.append((tk, w)) tms = " ".join([f"({t})^{w}" for t, w in tms])