Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add self-rag #1070

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions api/apps/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,18 @@ def fillin_conv(ans):
else: conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}

def rename_field(ans):
for chunk_i in ans['reference'].get('chunks', []):
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')

def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
for chunk_i in ans['reference'].get('chunks', []):
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
rename_field(rename_field)
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
Expand Down Expand Up @@ -554,23 +557,24 @@ def fillin_conv(ans):
"content": ""
}
]
for ans in chat(dia, msg, stream=False, **req):
# answer = ans
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())

chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = MINIO.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
except Exception as e:
return server_error_response(e)
ans = ""
for a in chat(dia, msg, stream=False, **req):
ans = a
break
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())

chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = MINIO.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
except Exception as e:
return server_error_response(e)

response = {"code": 200, "msg": "success", "data": data}
return response
Expand Down
112 changes: 112 additions & 0 deletions api/apps/canvas_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json

from flask import request
from flask_login import login_required, current_user

from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request
from graph.canvas import Canvas


@manager.route('/templates', methods=['GET'])
@login_required
def templates():
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])


@manager.route('/list', methods=['GET'])
@login_required
def canvas_list():

return get_json_result(data=[c.to_dict() for c in UserCanvasService.query(user_id=current_user.id)])


@manager.route('/rm', methods=['POST'])
@validate_request("canvas_ids")
@login_required
def rm():
for i in request.json["canvas_ids"]:
UserCanvasService.delete_by_id(i)
return get_json_result(data=True)


@manager.route('/set', methods=['POST'])
@validate_request("dsl", "title")
@login_required
def save():
req = request.json
req["user_id"] = current_user.id
if not isinstance(req["dsl"], str):req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
try:
Canvas(req["dsl"])
except Exception as e:
return server_error_response(e)

req["dsl"] = json.loads(req["dsl"])
if "id" not in req:
req["id"] = get_uuid()
if not UserCanvasService.save(**req):
return server_error_response("Fail to save canvas.")
else:
UserCanvasService.update_by_id(req["id"], req)

return get_json_result(data=req)


@manager.route('/get/<canvas_id>', methods=['GET'])
@login_required
def get(canvas_id):
e, c = UserCanvasService.get_by_id(canvas_id)
if not e:
return server_error_response("canvas not found.")
return get_json_result(data=c.to_dict())


@manager.route('/run', methods=['POST'])
@validate_request("id", "dsl")
@login_required
def run():
req = request.json
if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
try:
canvas = Canvas(req["dsl"], current_user.id)
ans = canvas.run()
req["dsl"] = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], dsl=req["dsl"])
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)


@manager.route('/reset', methods=['POST'])
@validate_request("canvas_id")
@login_required
def reset():
req = request.json
try:
user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
canvas = Canvas(req["dsl"], current_user.id)
canvas.reset()
req["dsl"] = json.loads(str(canvas))
UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)


5 changes: 3 additions & 2 deletions api/apps/conversation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request, Response, jsonify
from copy import deepcopy
from flask import request, Response
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
Expand Down Expand Up @@ -121,7 +122,7 @@ def completion():
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(retmsg="Conversation not found!")
conv.message.append(msg[-1])
conv.message.append(deepcopy(msg[-1]))
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(retmsg="Dialog not found!")
Expand Down
4 changes: 2 additions & 2 deletions api/apps/dialog_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def set_dialog():
req = request.json
dialog_id = req.get("dialog_id")
name = req.get("name", "New Dialog")
icon = req.get("icon", "")
description = req.get("description", "A helpful Dialog")
icon = req.get("icon", "")
top_n = req.get("top_n", 6)
top_k = req.get("top_k", 1024)
rerank_id = req.get("rerank_id", "")
Expand Down Expand Up @@ -92,7 +92,7 @@ def set_dialog():
"rerank_id": rerank_id,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight,
"icon": icon,
"icon": icon
}
if not DialogService.save(**dia):
return get_data_error_result(retmsg="Fail to new a dialog!")
Expand Down
26 changes: 26 additions & 0 deletions api/db/services/canvas_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
from api.db.services.common_service import CommonService


class CanvasTemplateService(CommonService):
model = CanvasTemplate

class UserCanvasService(CommonService):
model = UserCanvas
69 changes: 59 additions & 10 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import chat_logger, retrievaler
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.rag_tokenizer import is_chinese
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder

Expand Down Expand Up @@ -80,7 +81,8 @@ def chat(dialog, messages, stream=True, **kwargs):
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
max_tokens = 1024
else: max_tokens = llm[0].max_tokens
else:
max_tokens = llm[0].max_tokens
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
Expand Down Expand Up @@ -124,6 +126,16 @@ def chat(dialog, messages, stream=True, **kwargs):
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
#self-rag
if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]

chat_logger.info(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

Expand All @@ -136,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs):

msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
msg.extend([{"role": m["role"], "content": m["content"]}
for m in messages if m["role"] != "system"])
for m in messages if m["role"] != "system"])
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"

Expand All @@ -150,9 +162,9 @@ def decorate_answer(answer):
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer, idx = retrievaler.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
Expand All @@ -166,7 +178,7 @@ def decorate_answer(answer):
for c in refs["chunks"]:
if c.get("vector"):
del c["vector"]
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
return {"answer": answer, "reference": refs}

Expand Down Expand Up @@ -204,7 +216,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
def get_table():
nonlocal sys_prompt, user_promt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
"temperature": 0.06})
"temperature": 0.06})
print(user_promt, sql)
chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
Expand Down Expand Up @@ -273,17 +285,19 @@ def get_table():

# compose markdown table
clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
tbl["columns"][i]["name"])) for i in
clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")

line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
("|------|" if docid_idx and docid_idx else "")
("|------|" if docid_idx and docid_idx else "")

rows = ["|" +
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
"|" for r in tbl["rows"]]
if quota:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
else:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

if not docid_idx or not docnm_idx:
Expand All @@ -303,5 +317,40 @@ def get_table():
return {
"answer": "\n".join([clmns, line, rows]),
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
doc_aggs.items()]}
}


def relevant(tenant_id, llm_id, question, contents: list):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are a grader assessing relevance of a retrieved document to a user question.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
No other words needed except 'yes' or 'no'.
"""
if not contents:return False
contents = "Documents: \n" + " - ".join(contents)
contents = f"Question: {question}\n" + contents
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
if ans.lower().find("yes") >= 0: return True
return False


def rewrite(tenant_id, llm_id, question):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are an expert at query expansion to generate a paraphrasing of a question.
I can't retrieval relevant information from the knowledge base by using user's question directly.
You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase,
writing the abbreviation in its entirety, adding some extra descriptions or explanations,
changing the way of expression, translating the original question into another language (English/Chinese), etc.
And return 5 versions of question and one is from translation.
Just list the question. No other words are needed.
"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
return ans
2 changes: 2 additions & 0 deletions deepdoc/parser/pdf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,8 @@ def dfs(arr, depth):

self.page_cum_height = np.cumsum(self.page_cum_height)
assert len(self.page_cum_height) == len(self.page_images) + 1
if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
page_to, callback)

def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
self.__images__(fnm, zoomin)
Expand Down
1 change: 0 additions & 1 deletion rag/llm/rerank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,3 @@ def similarity(self, query: str, texts: list):
return np.array(res), token_count



Loading