From da6a4a11cc2142b992b95f4f58cffcd9c4602c14 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Mon, 15 Apr 2024 14:43:44 +0800 Subject: [PATCH] conversation API backend update (#360) ### What problem does this PR solve? Issue link:#345 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/__init__.py | 14 +- api/apps/api_app.py | 196 ++++++++++++++++++ api/apps/chunk_app.py | 2 +- api/apps/conversation_app.py | 254 +---------------------- api/apps/user_app.py | 2 +- api/db/db_models.py | 34 ++-- api/db/services/api_service.py | 66 ++++++ api/db/services/dialog_service.py | 253 +++++++++++++++++++++++ api/db/services/document_service.py | 2 +- docs/conversation_api.md | 303 ++++++++++++++++++++++++++++ rag/llm/chat_model.py | 53 +---- 11 files changed, 853 insertions(+), 326 deletions(-) create mode 100644 api/apps/api_app.py create mode 100644 api/db/services/api_service.py create mode 100644 docs/conversation_api.md diff --git a/api/apps/__init__.py b/api/apps/__init__.py index fcd39592d5..38f54069f7 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -14,11 +14,11 @@ # limitations under the License. # import logging -import sys import os +import sys from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from flask import Blueprint, Flask, request +from flask import Blueprint, Flask from werkzeug.wrappers.request import Request from flask_cors import CORS @@ -29,9 +29,9 @@ from flask_session import Session from flask_login import LoginManager -from api.settings import RetCode, SECRET_KEY, stat_logger -from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger -from api.utils.api_utils import get_json_result, server_error_response +from api.settings import SECRET_KEY, stat_logger +from api.settings import API_VERSION, access_logger +from api.utils.api_utils import server_error_response from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer __all__ = ['app'] @@ -54,8 +54,8 @@ #app.config["LOGIN_DISABLED"] = True app.config["SESSION_PERMANENT"] = False app.config["SESSION_TYPE"] = "filesystem" -#app.config['MAX_CONTENT_LENGTH'] = 128 * 1024 * 1024 app.config['MAX_CONTENT_LENGTH'] = os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024) + Session(app) login_manager = LoginManager() login_manager.init_app(app) @@ -117,4 +117,4 @@ def load_user(web_request): @app.teardown_request def _db_close(exc): - close_connection() + close_connection() \ No newline at end of file diff --git a/api/apps/api_app.py b/api/apps/api_app.py new file mode 100644 index 0000000000..f294272232 --- /dev/null +++ b/api/apps/api_app.py @@ -0,0 +1,196 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from datetime import datetime, timedelta +from flask import request +from flask_login import login_required, current_user +from api.db.db_models import APIToken, API4Conversation +from api.db.services.api_service import APITokenService, API4ConversationService +from api.db.services.dialog_service import DialogService, chat +from api.db.services.user_service import UserTenantService +from api.settings import RetCode +from api.utils import get_uuid, current_timestamp, datetime_format +from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request +from itsdangerous import URLSafeTimedSerializer + + +def generate_confirmation_token(tenent_id): + serializer = URLSafeTimedSerializer(tenent_id) + return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34] + + +@manager.route('/new_token', methods=['POST']) +@validate_request("dialog_id") +@login_required +def new_token(): + req = request.json + try: + tenants = UserTenantService.query(user_id=current_user.id) + if not tenants: + return get_data_error_result(retmsg="Tenant not found!") + + tenant_id = tenants[0].tenant_id + obj = {"tenant_id": tenant_id, "token": generate_confirmation_token(tenant_id), + "dialog_id": req["dialog_id"], + "create_time": current_timestamp(), + "create_date": datetime_format(datetime.now()), + "update_time": None, + "update_date": None + } + if not APITokenService.save(**obj): + return get_data_error_result(retmsg="Fail to new a dialog!") + + return get_json_result(data=obj) + except Exception as e: + return server_error_response(e) + + +@manager.route('/token_list', methods=['GET']) +@login_required +def token_list(): + try: + tenants = UserTenantService.query(user_id=current_user.id) + if not tenants: + return get_data_error_result(retmsg="Tenant not found!") + + objs = APITokenService.query(tenant_id=tenants[0].tenant_id, dialog_id=request.args["dialog_id"]) + return get_json_result(data=[o.to_dict() for o in objs]) + except Exception as e: + return server_error_response(e) + + +@manager.route('/rm', methods=['POST']) +@validate_request("tokens", "tenant_id") +@login_required +def rm(): + req = request.json + try: + for token in req["tokens"]: + APITokenService.filter_delete( + [APIToken.tenant_id == req["tenant_id"], APIToken.token == token]) + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) + + +@manager.route('/stats', methods=['GET']) +@login_required +def stats(): + try: + tenants = UserTenantService.query(user_id=current_user.id) + if not tenants: + return get_data_error_result(retmsg="Tenant not found!") + objs = API4ConversationService.stats( + tenants[0].tenant_id, + request.args.get( + "from_date", + (datetime.now() - + timedelta( + days=7)).strftime("%Y-%m-%d 24:00:00")), + request.args.get( + "to_date", + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))) + res = { + "pv": [(o["dt"], o["pv"]) for o in objs], + "uv": [(o["dt"], o["uv"]) for o in objs], + "speed": [(o["dt"], o["tokens"]/o["duration"]) for o in objs], + "tokens": [(o["dt"], o["tokens"]/1000.) for o in objs], + "round": [(o["dt"], o["round"]) for o in objs], + "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs] + } + return get_json_result(data=res) + except Exception as e: + return server_error_response(e) + + +@manager.route('/new_conversation', methods=['POST']) +@validate_request("user_id") +def set_conversation(): + token = request.headers.get('Authorization').split()[1] + objs = APIToken.query(token=token) + if not objs: + return get_json_result( + data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) + req = request.json + try: + e, dia = DialogService.get_by_id(objs[0].dialog_id) + if not e: + return get_data_error_result(retmsg="Dialog not found") + conv = { + "id": get_uuid(), + "dialog_id": dia.id, + "user_id": req["user_id"], + "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] + } + API4ConversationService.save(**conv) + e, conv = API4ConversationService.get_by_id(conv["id"]) + if not e: + return get_data_error_result(retmsg="Fail to new a conversation!") + conv = conv.to_dict() + return get_json_result(data=conv) + except Exception as e: + return server_error_response(e) + + +@manager.route('/completion', methods=['POST']) +@validate_request("conversation_id", "messages") +def completion(): + token = request.headers.get('Authorization').split()[1] + if not APIToken.query(token=token): + return get_json_result( + data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) + req = request.json + e, conv = API4ConversationService.get_by_id(req["conversation_id"]) + if not e: + return get_data_error_result(retmsg="Conversation not found!") + + msg = [] + for m in req["messages"]: + if m["role"] == "system": + continue + if m["role"] == "assistant" and not msg: + continue + msg.append({"role": m["role"], "content": m["content"]}) + + try: + conv.message.append(msg[-1]) + e, dia = DialogService.get_by_id(conv.dialog_id) + if not e: + return get_data_error_result(retmsg="Dialog not found!") + del req["conversation_id"] + del req["messages"] + ans = chat(dia, msg, **req) + if not conv.reference: + conv.reference = [] + conv.reference.append(ans["reference"]) + conv.message.append({"role": "assistant", "content": ans["answer"]}) + API4ConversationService.append_message(conv.id, conv.to_dict()) + APITokenService.APITokenService(token) + return get_json_result(data=ans) + except Exception as e: + return server_error_response(e) + + +@manager.route('/conversation/', methods=['GET']) +# @login_required +def get(conversation_id): + try: + e, conv = API4ConversationService.get_by_id(conversation_id) + if not e: + return get_data_error_result(retmsg="Conversation not found!") + + return get_json_result(data=conv.to_dict()) + except Exception as e: + return server_error_response(e) \ No newline at end of file diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index b3cf11aa1d..c0315cbabe 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -60,7 +60,7 @@ def list(): for id in sres.ids: d = { "chunk_id": id, - "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get( + "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get( "content_with_weight", ""), "doc_id": sres.field[id]["doc_id"], "docnm_kwd": sres.field[id]["docnm_kwd"], diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 3ee943944e..2bb813cb6f 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -13,21 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import re - from flask import request from flask_login import login_required -from api.db.services.dialog_service import DialogService, ConversationService -from api.db import LLMType -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import LLMService, LLMBundle, TenantLLMService -from api.settings import access_logger, stat_logger, retrievaler, chat_logger +from api.db.services.dialog_service import DialogService, ConversationService, chat from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid from api.utils.api_utils import get_json_result -from rag.app.resume import forbidden_select_fields4resume -from rag.nlp.search import index_name -from rag.utils import num_tokens_from_string, encoder, rmSpace @manager.route('/set', methods=['POST']) @@ -110,43 +101,6 @@ def list_convsersation(): return server_error_response(e) -def message_fit_in(msg, max_length=4000): - def count(): - nonlocal msg - tks_cnts = [] - for m in msg: - tks_cnts.append( - {"role": m["role"], "count": num_tokens_from_string(m["content"])}) - total = 0 - for m in tks_cnts: - total += m["count"] - return total - - c = count() - if c < max_length: - return c, msg - - msg_ = [m for m in msg[:-1] if m["role"] == "system"] - msg_.append(msg[-1]) - msg = msg_ - c = count() - if c < max_length: - return c, msg - - ll = num_tokens_from_string(msg_[0].content) - l = num_tokens_from_string(msg_[-1].content) - if ll / (ll + l) > 0.8: - m = msg_[0].content - m = encoder.decode(encoder.encode(m)[:max_length - l]) - msg[0].content = m - return max_length, msg - - m = msg_[1].content - m = encoder.decode(encoder.encode(m)[:max_length - l]) - msg[1].content = m - return max_length, msg - - @manager.route('/completion', methods=['POST']) @login_required @validate_request("conversation_id", "messages") @@ -179,209 +133,3 @@ def completion(): except Exception as e: return server_error_response(e) - -def chat(dialog, messages, **kwargs): - assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." - llm = LLMService.query(llm_name=dialog.llm_id) - if not llm: - llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id) - if not llm: - raise LookupError("LLM(%s) not found" % dialog.llm_id) - max_tokens = 1024 - else: max_tokens = llm[0].max_tokens - questions = [m["content"] for m in messages if m["role"] == "user"] - embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING) - chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) - - prompt_config = dialog.prompt_config - field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) - # try to use sql if field mapping is good to go - if field_map: - chat_logger.info("Use SQL to retrieval:{}".format(questions[-1])) - ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True)) - if ans: return ans - - for p in prompt_config["parameters"]: - if p["key"] == "knowledge": - continue - if p["key"] not in kwargs and not p["optional"]: - raise KeyError("Miss parameter: " + p["key"]) - if p["key"] not in kwargs: - prompt_config["system"] = prompt_config["system"].replace( - "{%s}" % p["key"], " ") - - for _ in range(len(questions) // 2): - questions.append(questions[-1]) - if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: - kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} - else: - kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, - dialog.similarity_threshold, - dialog.vector_similarity_weight, top=1024, aggs=False) - knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] - chat_logger.info( - "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) - - if not knowledges and prompt_config.get("empty_response"): - return { - "answer": prompt_config["empty_response"], "reference": kbinfos} - - kwargs["knowledge"] = "\n".join(knowledges) - gen_conf = dialog.llm_setting - msg = [{"role": m["role"], "content": m["content"]} - for m in messages if m["role"] != "system"] - used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97)) - if "max_tokens" in gen_conf: - gen_conf["max_tokens"] = min( - gen_conf["max_tokens"], - max_tokens - used_token_count) - answer = chat_mdl.chat( - prompt_config["system"].format( - **kwargs), msg, gen_conf) - chat_logger.info("User: {}|Assistant: {}".format( - msg[-1]["content"], answer)) - - if knowledges and prompt_config.get("quote", True): - answer, idx = retrievaler.insert_citations(answer, - [ck["content_ltks"] - for ck in kbinfos["chunks"]], - [ck["vector"] - for ck in kbinfos["chunks"]], - embd_mdl, - tkweight=1 - dialog.vector_similarity_weight, - vtweight=dialog.vector_similarity_weight) - idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) - recall_docs = [ - d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] - if not recall_docs: recall_docs = kbinfos["doc_aggs"] - kbinfos["doc_aggs"] = recall_docs - - for c in kbinfos["chunks"]: - if c.get("vector"): - del c["vector"] - if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0: - answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" - return {"answer": answer, "reference": kbinfos} - - -def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): - sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。" - user_promt = """ -表名:{}; -数据库表字段说明如下: -{} - -问题如下: -{} -请写出SQL, 且只要SQL,不要有其他说明及文字。 -""".format( - index_name(tenant_id), - "\n".join([f"{k}: {v}" for k, v in field_map.items()]), - question - ) - tried_times = 0 - - def get_table(): - nonlocal sys_prompt, user_promt, question, tried_times - sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], { - "temperature": 0.06}) - print(user_promt, sql) - chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}") - sql = re.sub(r"[\r\n]+", " ", sql.lower()) - sql = re.sub(r".*select ", "select ", sql.lower()) - sql = re.sub(r" +", " ", sql) - sql = re.sub(r"([;;]|```).*", "", sql) - if sql[:len("select ")] != "select ": - return None, None - if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()): - if sql[:len("select *")] != "select *": - sql = "select doc_id,docnm_kwd," + sql[6:] - else: - flds = [] - for k in field_map.keys(): - if k in forbidden_select_fields4resume: - continue - if len(flds) > 11: - break - flds.append(k) - sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:] - - print(f"“{question}” get SQL(refined): {sql}") - - chat_logger.info(f"“{question}” get SQL(refined): {sql}") - tried_times += 1 - return retrievaler.sql_retrieval(sql, format="json"), sql - - tbl, sql = get_table() - if tbl is None: - return None - if tbl.get("error") and tried_times <= 2: - user_promt = """ - 表名:{}; - 数据库表字段说明如下: - {} - - 问题如下: - {} - - 你上一次给出的错误SQL如下: - {} - - 后台报错如下: - {} - - 请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。 - """.format( - index_name(tenant_id), - "\n".join([f"{k}: {v}" for k, v in field_map.items()]), - question, sql, tbl["error"] - ) - tbl, sql = get_table() - chat_logger.info("TRY it again: {}".format(sql)) - - chat_logger.info("GET table: {}".format(tbl)) - print(tbl) - if tbl.get("error") or len(tbl["rows"]) == 0: - return None - - docid_idx = set([ii for ii, c in enumerate( - tbl["columns"]) if c["name"] == "doc_id"]) - docnm_idx = set([ii for ii, c in enumerate( - tbl["columns"]) if c["name"] == "docnm_kwd"]) - clmn_idx = [ii for ii in range( - len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)] - - # compose markdown table - clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], - tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|") - - line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \ - ("|------|" if docid_idx and docid_idx else "") - - rows = ["|" + - "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + - "|" for r in tbl["rows"]] - if quota: - rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) - else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) - rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) - - if not docid_idx or not docnm_idx: - chat_logger.warning("SQL missing field: " + sql) - return { - "answer": "\n".join([clmns, line, rows]), - "reference": {"chunks": [], "doc_aggs": []} - } - - docid_idx = list(docid_idx)[0] - docnm_idx = list(docnm_idx)[0] - doc_aggs = {} - for r in tbl["rows"]: - if r[docid_idx] not in doc_aggs: - doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0} - doc_aggs[r[docid_idx]]["count"] += 1 - return { - "answer": "\n".join([clmns, line, rows]), - "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]], - "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]} - } diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 47ed1c9569..346c63dc8d 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -15,7 +15,7 @@ # import re -from flask import request, session, redirect, url_for +from flask import request, session, redirect from werkzeug.security import generate_password_hash, check_password_hash from flask_login import login_required, current_user, login_user, logout_user diff --git a/api/db/db_models.py b/api/db/db_models.py index 6d7adf0d22..27ad80f99d 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -728,15 +728,6 @@ class Meta: db_table = "dialog" -# class DialogKb(DataBaseModel): -# dialog_id = CharField(max_length=32, null=False, index=True) -# kb_id = CharField(max_length=32, null=False) -# -# class Meta: -# db_table = "dialog_kb" -# primary_key = CompositeKey('dialog_id', 'kb_id') - - class Conversation(DataBaseModel): id = CharField(max_length=32, primary_key=True) dialog_id = CharField(max_length=32, null=False, index=True) @@ -748,13 +739,26 @@ class Meta: db_table = "conversation" -""" +class APIToken(DataBaseModel): + tenant_id = CharField(max_length=32, null=False) + token = CharField(max_length=255, null=False) + dialog_id = CharField(max_length=32, null=False, index=True) class Meta: - db_table = 't_pipeline_component_meta' - indexes = ( - (('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True), - ) + db_table = "api_token" + primary_key = CompositeKey('tenant_id', 'token') -""" +class API4Conversation(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + dialog_id = CharField(max_length=32, null=False, index=True) + user_id = CharField(max_length=255, null=False, help_text="user_id") + message = JSONField(null=True) + reference = JSONField(null=True, default=[]) + tokens = IntegerField(default=0) + duration = FloatField(default=0) + round = IntegerField(default=0) + thumb_up = IntegerField(default=0) + + class Meta: + db_table = "api_4_conversation" diff --git a/api/db/services/api_service.py b/api/db/services/api_service.py new file mode 100644 index 0000000000..772ee97d89 --- /dev/null +++ b/api/db/services/api_service.py @@ -0,0 +1,66 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from datetime import datetime +import peewee +from api.db.db_models import DB, API4Conversation, APIToken, Dialog +from api.db.services.common_service import CommonService +from api.utils import current_timestamp, datetime_format + + +class APITokenService(CommonService): + model = APIToken + + @classmethod + @DB.connection_context() + def used(cls, token): + return cls.model.update({ + "update_time": current_timestamp(), + "update_date": datetime_format(datetime.now()), + }).where( + cls.model.token == token + ) + + +class API4ConversationService(CommonService): + model = API4Conversation + + @classmethod + @DB.connection_context() + def append_message(cls, id, conversation): + cls.model.update_by_id(id, conversation) + return cls.model.update(round=cls.model.round + 1).where(id=id).execute() + + @classmethod + @DB.connection_context() + def stats(cls, tenant_id, from_date, to_date): + return cls.model.select( + cls.model.create_date.truncate("day").alias("dt"), + peewee.fn.COUNT( + cls.model.id).alias("pv"), + peewee.fn.COUNT( + cls.model.user_id.distinct()).alias("uv"), + peewee.fn.SUM( + cls.model.tokens).alias("tokens"), + peewee.fn.SUM( + cls.model.duration).alias("duration"), + peewee.fn.AVG( + cls.model.round).alias("round"), + peewee.fn.SUM( + cls.model.thumb_up).alias("thumb_up") + ).join(Dialog, on=(cls.model.dialog_id == Dialog.id & Dialog.tenant_id == tenant_id)).where( + cls.model.create_date >= from_date, + cls.model.create_date <= to_date + ).group_by(cls.model.create_date.truncate("day")).dicts() diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index bb770eb2ce..6e4855c4ad 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -13,8 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import re + +from api.db import LLMType from api.db.db_models import Dialog, Conversation from api.db.services.common_service import CommonService +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle +from api.settings import chat_logger, retrievaler +from rag.app.resume import forbidden_select_fields4resume +from rag.nlp.search import index_name +from rag.utils import rmSpace, num_tokens_from_string, encoder class DialogService(CommonService): @@ -23,3 +32,247 @@ class DialogService(CommonService): class ConversationService(CommonService): model = Conversation + + +def message_fit_in(msg, max_length=4000): + def count(): + nonlocal msg + tks_cnts = [] + for m in msg: + tks_cnts.append( + {"role": m["role"], "count": num_tokens_from_string(m["content"])}) + total = 0 + for m in tks_cnts: + total += m["count"] + return total + + c = count() + if c < max_length: + return c, msg + + msg_ = [m for m in msg[:-1] if m["role"] == "system"] + msg_.append(msg[-1]) + msg = msg_ + c = count() + if c < max_length: + return c, msg + + ll = num_tokens_from_string(msg_[0].content) + l = num_tokens_from_string(msg_[-1].content) + if ll / (ll + l) > 0.8: + m = msg_[0].content + m = encoder.decode(encoder.encode(m)[:max_length - l]) + msg[0].content = m + return max_length, msg + + m = msg_[1].content + m = encoder.decode(encoder.encode(m)[:max_length - l]) + msg[1].content = m + return max_length, msg + + +def chat(dialog, messages, **kwargs): + assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." + llm = LLMService.query(llm_name=dialog.llm_id) + if not llm: + llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id) + if not llm: + raise LookupError("LLM(%s) not found" % dialog.llm_id) + max_tokens = 1024 + else: max_tokens = llm[0].max_tokens + questions = [m["content"] for m in messages if m["role"] == "user"] + embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING) + chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) + + prompt_config = dialog.prompt_config + field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) + # try to use sql if field mapping is good to go + if field_map: + chat_logger.info("Use SQL to retrieval:{}".format(questions[-1])) + ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True)) + if ans: return ans + + for p in prompt_config["parameters"]: + if p["key"] == "knowledge": + continue + if p["key"] not in kwargs and not p["optional"]: + raise KeyError("Miss parameter: " + p["key"]) + if p["key"] not in kwargs: + prompt_config["system"] = prompt_config["system"].replace( + "{%s}" % p["key"], " ") + + for _ in range(len(questions) // 2): + questions.append(questions[-1]) + if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: + kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} + else: + kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, + dialog.similarity_threshold, + dialog.vector_similarity_weight, top=1024, aggs=False) + knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] + chat_logger.info( + "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) + + if not knowledges and prompt_config.get("empty_response"): + return { + "answer": prompt_config["empty_response"], "reference": kbinfos} + + kwargs["knowledge"] = "\n".join(knowledges) + gen_conf = dialog.llm_setting + msg = [{"role": m["role"], "content": m["content"]} + for m in messages if m["role"] != "system"] + used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97)) + if "max_tokens" in gen_conf: + gen_conf["max_tokens"] = min( + gen_conf["max_tokens"], + max_tokens - used_token_count) + answer = chat_mdl.chat( + prompt_config["system"].format( + **kwargs), msg, gen_conf) + chat_logger.info("User: {}|Assistant: {}".format( + msg[-1]["content"], answer)) + + if knowledges and prompt_config.get("quote", True): + answer, idx = retrievaler.insert_citations(answer, + [ck["content_ltks"] + for ck in kbinfos["chunks"]], + [ck["vector"] + for ck in kbinfos["chunks"]], + embd_mdl, + tkweight=1 - dialog.vector_similarity_weight, + vtweight=dialog.vector_similarity_weight) + idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) + recall_docs = [ + d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] + if not recall_docs: recall_docs = kbinfos["doc_aggs"] + kbinfos["doc_aggs"] = recall_docs + + for c in kbinfos["chunks"]: + if c.get("vector"): + del c["vector"] + if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0: + answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" + return {"answer": answer, "reference": kbinfos} + + +def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): + sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。" + user_promt = """ +表名:{}; +数据库表字段说明如下: +{} + +问题如下: +{} +请写出SQL, 且只要SQL,不要有其他说明及文字。 +""".format( + index_name(tenant_id), + "\n".join([f"{k}: {v}" for k, v in field_map.items()]), + question + ) + tried_times = 0 + + def get_table(): + nonlocal sys_prompt, user_promt, question, tried_times + sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], { + "temperature": 0.06}) + print(user_promt, sql) + chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}") + sql = re.sub(r"[\r\n]+", " ", sql.lower()) + sql = re.sub(r".*select ", "select ", sql.lower()) + sql = re.sub(r" +", " ", sql) + sql = re.sub(r"([;;]|```).*", "", sql) + if sql[:len("select ")] != "select ": + return None, None + if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()): + if sql[:len("select *")] != "select *": + sql = "select doc_id,docnm_kwd," + sql[6:] + else: + flds = [] + for k in field_map.keys(): + if k in forbidden_select_fields4resume: + continue + if len(flds) > 11: + break + flds.append(k) + sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:] + + print(f"“{question}” get SQL(refined): {sql}") + + chat_logger.info(f"“{question}” get SQL(refined): {sql}") + tried_times += 1 + return retrievaler.sql_retrieval(sql, format="json"), sql + + tbl, sql = get_table() + if tbl is None: + return None + if tbl.get("error") and tried_times <= 2: + user_promt = """ + 表名:{}; + 数据库表字段说明如下: + {} + + 问题如下: + {} + + 你上一次给出的错误SQL如下: + {} + + 后台报错如下: + {} + + 请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。 + """.format( + index_name(tenant_id), + "\n".join([f"{k}: {v}" for k, v in field_map.items()]), + question, sql, tbl["error"] + ) + tbl, sql = get_table() + chat_logger.info("TRY it again: {}".format(sql)) + + chat_logger.info("GET table: {}".format(tbl)) + print(tbl) + if tbl.get("error") or len(tbl["rows"]) == 0: + return None + + docid_idx = set([ii for ii, c in enumerate( + tbl["columns"]) if c["name"] == "doc_id"]) + docnm_idx = set([ii for ii, c in enumerate( + tbl["columns"]) if c["name"] == "docnm_kwd"]) + clmn_idx = [ii for ii in range( + len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)] + + # compose markdown table + clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], + tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|") + + line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \ + ("|------|" if docid_idx and docid_idx else "") + + rows = ["|" + + "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + + "|" for r in tbl["rows"]] + if quota: + rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) + else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) + rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) + + if not docid_idx or not docnm_idx: + chat_logger.warning("SQL missing field: " + sql) + return { + "answer": "\n".join([clmns, line, rows]), + "reference": {"chunks": [], "doc_aggs": []} + } + + docid_idx = list(docid_idx)[0] + docnm_idx = list(docnm_idx)[0] + doc_aggs = {} + for r in tbl["rows"]: + if r[docid_idx] not in doc_aggs: + doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0} + doc_aggs[r[docid_idx]]["count"] += 1 + return { + "answer": "\n".join([clmns, line, rows]), + "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]], + "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]} + } diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index bd85a94bf8..00ef9d1a49 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -15,7 +15,7 @@ # from peewee import Expression -from api.db import TenantPermission, FileType, TaskStatus +from api.db import FileType, TaskStatus from api.db.db_models import DB, Knowledgebase, Tenant from api.db.db_models import Document from api.db.services.common_service import CommonService diff --git a/docs/conversation_api.md b/docs/conversation_api.md new file mode 100644 index 0000000000..2002bcd04e --- /dev/null +++ b/docs/conversation_api.md @@ -0,0 +1,303 @@ +# Conversation API Instruction + +## Base URL +```buildoutcfg +https://demo.ragflow.io/v1/ +``` + +## Authorization + +All the APIs are authorized with API-Key. Please keep it save and private. Don't reveal it in any way from the front-end. +The API-Key should put in the header of request: +```buildoutcfg +Authorization: Bearer {API_KEY} +``` + +## Start a conversation + +This should be called whenever there's new user coming to chat. +### Path: /api/new_conversation +### Method: GET +### Parameter: + +| name | type | optional | description| +|------|-------|----|----| +| user_id| string | No | It's for identifying user in order to search and calculate statistics.| + +### Response +```json +{ + "data": { + "create_date": "Fri, 12 Apr 2024 17:26:21 GMT", + "create_time": 1712913981857, + "dialog_id": "4f0a2e4cb9af11ee9ba20aef05f5e94f", + "duration": 0.0, + "id": "b9b2e098f8ae11ee9f45fa163e197198", + "message": [ + { + "content": "Hi, I'm your assistant, can I help you?", + "role": "assistant" + } + ], + "reference": [], + "tokens": 0, + "update_date": "Fri, 12 Apr 2024 17:26:21 GMT", + "update_time": 1712913981857, + "user_id": "kevinhu" + }, + "retcode": 0, + "retmsg": "success" +} +``` +> data['id'] in response should be stored and will be used in every round of following conversation. + +## Get history of a conversation + +### Path: /api/conversation/\ +### Method: GET +### Response +```json +{ + "data": { + "create_date": "Mon, 01 Apr 2024 09:28:42 GMT", + "create_time": 1711934922220, + "dialog_id": "df4a4916d7bd11eeaa650242ac180006", + "id": "2cae30fcefc711ee94140242ac180006", + "message": [ + { + "content": "Hi! I'm your assistant, what can I do for you?", + "role": "assistant" + }, + { + "content": "What's the vit score for GPT-4?", + "role": "user" + }, + { + "content": "The ViT Score for GPT-4 in the zero-shot scenario is 0.5058, and in the few-shot scenario, it is 0.6480. ##0$$", + "role": "assistant" + }, + { + "content": "How is the nvlink topology like?", + "role": "user" + }, + { + "content": "NVLink topology refers to the arrangement of connections between GPUs using NVIDIA's NVLink technology. Correct NVLink topology for NVIDIA A100 cards involves connecting one GPU to another through a series of NVLink bridges ##0$$. Each of the three attached bridges spans two PCIe slots, and for optimal performance and balanced bridge topology, all three NVLink bridges should be used when connecting two adjacent A100 cards.\n\nHere's a summary of the correct and incorrect topologies:\n\n- **Correct**: Both GPUs are connected via all three NVLink bridges, ensuring full bandwidth and proper communication.\n- **Incorrect**: Not using all three bridges or having an uneven connection configuration would result in suboptimal performance.\n\nIt's also important to note that for multi-CPU systems, both A100 cards in a bridged pair should be within the same CPU domain, unless each CPU has a single A100 PCIe card, in which case they can be bridged together.", + "role": "assistant" + } + ], + "user_id": "user name", + "reference": [ + { + "chunks": [ + { + "chunk_id": "d0bc7892c3ec4aeac071544fd56730a8", + "content_ltks": "tabl 1:openagi task-solv perform under differ set for three closed-sourc llm . boldfac denot the highest score under each learn schema . metric gpt-3.5-turbo claude-2 gpt-4 zero few zero few zero few clip score 0.0 0.0 0.0 0.2543 0.0 0.3055 bert score 0.1914 0.3820 0.2111 0.5038 0.2076 0.6307 vit score 0.2437 0.7497 0.4082 0.5416 0.5058 0.6480 overal 0.1450 0.3772 0.2064 0.4332 0.2378 0.5281", + "content_with_weight": "\n\n\n\n\n\n\n
Table 1: OpenAGI task-solving performances under different settings for three closed-source LLMs. Boldface denotes the highest score under each learning schema.
MetricsGPT-3.5-turboClaude-2GPT-4
ZeroFewZero FewZero Few
CLIP Score0.00.00.0 0.25430.0 0.3055
BERT Score0.19140.38200.2111 0.50380.2076 0.6307
ViT Score0.24370.74970.4082 0.54160.5058 0.6480
Overall0.14500.37720.2064 0.43320.2378 0.5281
", + "doc_id": "c790da40ea8911ee928e0242ac180005", + "docnm_kwd": "OpenAGI When LLM Meets Domain Experts.pdf", + "img_id": "afab9fdad6e511eebdb20242ac180006-d0bc7892c3ec4aeac071544fd56730a8", + "important_kwd": [], + "kb_id": "afab9fdad6e511eebdb20242ac180006", + "positions": [ + [ + 9.0, + 159.9383341471354, + 472.1773274739583, + 223.58013916015625, + 307.86692301432294 + ] + ], + "similarity": 0.7310340654129031, + "term_similarity": 0.7671974387781668, + "vector_similarity": 0.40556370512552886 + }, + { + "chunk_id": "7e2345d440383b756670e1b0f43a7007", + "content_ltks": "5.5 experiment analysi the main experiment result are tabul in tab . 1 and 2 , showcas the result for closed-sourc and open-sourc llm , respect . the overal perform is calcul a the averag of cllp 8 bert and vit score . here , onli the task descript of the benchmark task are fed into llm(addit inform , such a the input prompt and llm\u2019output , is provid in fig . a.4 and a.5 in supplementari). broadli speak , closed-sourc llm demonstr superior perform on openagi task , with gpt-4 lead the pack under both zero-and few-shot scenario . in the open-sourc categori , llama-2-13b take the lead , consist post top result across variou learn schema--the perform possibl influenc by it larger model size . notabl , open-sourc llm significantli benefit from the tune method , particularli fine-tun and\u2019rltf . these method mark notic enhanc for flan-t5-larg , vicuna-7b , and llama-2-13b when compar with zero-shot and few-shot learn schema . in fact , each of these open-sourc model hit it pinnacl under the rltf approach . conclus , with rltf tune , the perform of llama-2-13b approach that of gpt-3.5 , illustr it potenti .", + "content_with_weight": "5.5 Experimental Analysis\nThe main experimental results are tabulated in Tab. 1 and 2, showcasing the results for closed-source and open-source LLMs, respectively. The overall performance is calculated as the average of CLlP\n8\nBERT and ViT scores. Here, only the task descriptions of the benchmark tasks are fed into LLMs (additional information, such as the input prompt and LLMs\u2019 outputs, is provided in Fig. A.4 and A.5 in supplementary). Broadly speaking, closed-source LLMs demonstrate superior performance on OpenAGI tasks, with GPT-4 leading the pack under both zero- and few-shot scenarios. In the open-source category, LLaMA-2-13B takes the lead, consistently posting top results across various learning schema--the performance possibly influenced by its larger model size. Notably, open-source LLMs significantly benefit from the tuning methods, particularly Fine-tuning and\u2019 RLTF. These methods mark noticeable enhancements for Flan-T5-Large, Vicuna-7B, and LLaMA-2-13B when compared with zero-shot and few-shot learning schema. In fact, each of these open-source models hits its pinnacle under the RLTF approach. Conclusively, with RLTF tuning, the performance of LLaMA-2-13B approaches that of GPT-3.5, illustrating its potential.", + "doc_id": "c790da40ea8911ee928e0242ac180005", + "docnm_kwd": "OpenAGI When LLM Meets Domain Experts.pdf", + "img_id": "afab9fdad6e511eebdb20242ac180006-7e2345d440383b756670e1b0f43a7007", + "important_kwd": [], + "kb_id": "afab9fdad6e511eebdb20242ac180006", + "positions": [ + [ + 8.0, + 107.3, + 508.90000000000003, + 686.3, + 697.0 + ], + ], + "similarity": 0.6691508616357027, + "term_similarity": 0.6999011754270821, + "vector_similarity": 0.39239803751328806 + }, + ], + "doc_aggs": [ + { + "count": 8, + "doc_id": "c790da40ea8911ee928e0242ac180005", + "doc_name": "OpenAGI When LLM Meets Domain Experts.pdf" + } + ], + "total": 8 + }, + { + "chunks": [ + { + "chunk_id": "8c11a1edddb21ad2ae0c43b4a5dcfa62", + "content_ltks": "nvlink bridg support nvidia\u00aenvlink\u00aei a high-spe point-to-point peer transfer connect , where one gpu can transfer data to and receiv data from one other gpu . the nvidia a100 card support nvlink bridg connect with a singl adjac a100 card . each of the three attach bridg span two pcie slot . to function correctli a well a to provid peak bridg bandwidth , bridg connect with an adjac a100 card must incorpor all three nvlink bridg . wherev an adjac pair of a100 card exist in the server , for best bridg perform and balanc bridg topolog , the a100 pair should be bridg . figur 4 illustr correct and incorrect a100 nvlink connect topolog . nvlink topolog\u2013top view figur 4. correct incorrect correct incorrect for system that featur multipl cpu , both a100 card of a bridg card pair should be within the same cpu domain\u2014that is , under the same cpu\u2019s topolog . ensur thi benefit workload applic perform . the onli except is for dual cpu system wherein each cpu ha a singl a100 pcie card under it;in that case , the two a100 pcie card in the system may be bridg togeth . a100 nvlink speed and bandwidth are given in the follow tabl . tabl 5. a100 nvlink speed and bandwidth paramet valu total nvlink bridg support by nvidia a100 3 total nvlink rx and tx lane support 96 data rate per nvidia a100 nvlink lane(each direct)50 gbp total maximum nvlink bandwidth 600 gbyte per second pb-10137-001_v03|8 nvidia a100 40gb pcie gpu acceler", + "content_with_weight": "NVLink Bridge Support\nNVIDIA\u00aeNVLink\u00aeis a high-speed point-to-point peer transfer connection, where one GPU can transfer data to and receive data from one other GPU. The NVIDIA A100 card supports NVLink bridge connection with a single adjacent A100 card.\nEach of the three attached bridges spans two PCIe slots. To function correctly as well as to provide peak bridge bandwidth, bridge connection with an adjacent A100 card must incorporate all three NVLink bridges. Wherever an adjacent pair of A100 cards exists in the server, for best bridging performance and balanced bridge topology, the A100 pair should be bridged. Figure 4 illustrates correct and incorrect A100 NVLink connection topologies.\nNVLink Topology \u2013Top Views \nFigure 4. \nCORRECT \nINCORRECT \nCORRECT \nINCORRECT \nFor systems that feature multiple CPUs, both A100 cards of a bridged card pair should be within the same CPU domain\u2014that is, under the same CPU\u2019s topology. Ensuring this benefits workload application performance. The only exception is for dual CPU systems wherein each CPU has a single A100 PCIe card under it; in that case, the two A100 PCIe cards in the system may be bridged together.\nA100 NVLink speed and bandwidth are given in the following table.\n\n\n\n\n\n\n
Table 5. A100 NVLink Speed and Bandwidth
Parameter Value
Total NVLink bridges supported by NVIDIA A100 3
Total NVLink Rx and Tx lanes supported 96
Data rate per NVIDIA A100 NVLink lane (each direction)50 Gbps
Total maximum NVLink bandwidth600 Gbytes per second
\nPB-10137-001_v03 |8\nNVIDIA A100 40GB PCIe GPU Accelerator", + "doc_id": "806d1ed0ea9311ee860a0242ac180005", + "docnm_kwd": "A100-PCIE-Prduct-Brief.pdf", + "img_id": "afab9fdad6e511eebdb20242ac180006-8c11a1edddb21ad2ae0c43b4a5dcfa62", + "important_kwd": [], + "kb_id": "afab9fdad6e511eebdb20242ac180006", + "positions": [ + [ + 12.0, + 84.0, + 541.3, + 76.7, + 96.7 + ], + ], + "similarity": 0.3200748779905588, + "term_similarity": 0.3082244010114718, + "vector_similarity": 0.42672917080234146 + }, + ], + "doc_aggs": [ + { + "count": 1, + "doc_id": "806d1ed0ea9311ee860a0242ac180005", + "doc_name": "A100-PCIE-Prduct-Brief.pdf" + } + ], + "total": 3 + } + ], + "update_date": "Tue, 02 Apr 2024 09:07:49 GMT", + "update_time": 1712020069421 + }, + "retcode": 0, + "retmsg": "success" +} +``` + +- **message**: All the chat history in it. + - role: user or assistant + - content: the text content of user or assistant. The citations are in format like: ##0$$. The number in the middle indicate which part in data.reference.chunks it refers to. + +- **user_id**: This is set by the caller. +- **reference**: Every item in it refer to the corresponding message in data.message whose role is assistant. + - chunks + - content_with_weight: The content of chunk. + - docnm_kwd: the document name. + - img_id: the image id of the chunk. It is an optional field only for PDF/pptx/picture. And accessed by 'GET' /document/get/\. + - positions: [page_number, [upleft corner(x, y)], [right bottom(x, y)]], the chunk position, only for PDF. + - similarity: the hybrid similarity. + - term_similarity: keyword simimlarity + - vector_similarity: embedding similarity + - doc_aggs: + - doc_id: the document can be accessed by 'GET' /document/get/\ + - doc_name: the file name + - count: the chunk number hit in this document. + +## Chat + +This will be called to get the answer to users' questions. + +### Path: /api/completion +### Method: POST +### Parameter: + +| name | type | optional | description| +|------|-------|----|----| +| conversation_id| string | No | This is from calling /new_conversation.| +| messages| json | No | All the conversation history stored here including the latest user's question.| + +### Response +```json +{ + "data": { + "answer": "The ViT Score for GPT-4 in the zero-shot scenario is 0.5058, and in the few-shot scenario, it is 0.6480. ##0$$", + "reference": { + "chunks": [ + { + "chunk_id": "d0bc7892c3ec4aeac071544fd56730a8", + "content_ltks": "tabl 1:openagi task-solv perform under differ set for three closed-sourc llm . boldfac denot the highest score under each learn schema . metric gpt-3.5-turbo claude-2 gpt-4 zero few zero few zero few clip score 0.0 0.0 0.0 0.2543 0.0 0.3055 bert score 0.1914 0.3820 0.2111 0.5038 0.2076 0.6307 vit score 0.2437 0.7497 0.4082 0.5416 0.5058 0.6480 overal 0.1450 0.3772 0.2064 0.4332 0.2378 0.5281", + "content_with_weight": "\n\n\n\n\n\n\n
Table 1: OpenAGI task-solving performances under different settings for three closed-source LLMs. Boldface denotes the highest score under each learning schema.
MetricsGPT-3.5-turboClaude-2GPT-4
ZeroFewZero FewZero Few
CLIP Score0.00.00.0 0.25430.0 0.3055
BERT Score0.19140.38200.2111 0.50380.2076 0.6307
ViT Score0.24370.74970.4082 0.54160.5058 0.6480
Overall0.14500.37720.2064 0.43320.2378 0.5281
", + "doc_id": "c790da40ea8911ee928e0242ac180005", + "docnm_kwd": "OpenAGI When LLM Meets Domain Experts.pdf", + "img_id": "afab9fdad6e511eebdb20242ac180006-d0bc7892c3ec4aeac071544fd56730a8", + "important_kwd": [], + "kb_id": "afab9fdad6e511eebdb20242ac180006", + "positions": [ + [ + 9.0, + 159.9383341471354, + 472.1773274739583, + 223.58013916015625, + 307.86692301432294 + ] + ], + "similarity": 0.7310340654129031, + "term_similarity": 0.7671974387781668, + "vector_similarity": 0.40556370512552886 + }, + { + "chunk_id": "7e2345d440383b756670e1b0f43a7007", + "content_ltks": "5.5 experiment analysi the main experiment result are tabul in tab . 1 and 2 , showcas the result for closed-sourc and open-sourc llm , respect . the overal perform is calcul a the averag of cllp 8 bert and vit score . here , onli the task descript of the benchmark task are fed into llm(addit inform , such a the input prompt and llm\u2019output , is provid in fig . a.4 and a.5 in supplementari). broadli speak , closed-sourc llm demonstr superior perform on openagi task , with gpt-4 lead the pack under both zero-and few-shot scenario . in the open-sourc categori , llama-2-13b take the lead , consist post top result across variou learn schema--the perform possibl influenc by it larger model size . notabl , open-sourc llm significantli benefit from the tune method , particularli fine-tun and\u2019rltf . these method mark notic enhanc for flan-t5-larg , vicuna-7b , and llama-2-13b when compar with zero-shot and few-shot learn schema . in fact , each of these open-sourc model hit it pinnacl under the rltf approach . conclus , with rltf tune , the perform of llama-2-13b approach that of gpt-3.5 , illustr it potenti .", + "content_with_weight": "5.5 Experimental Analysis\nThe main experimental results are tabulated in Tab. 1 and 2, showcasing the results for closed-source and open-source LLMs, respectively. The overall performance is calculated as the average of CLlP\n8\nBERT and ViT scores. Here, only the task descriptions of the benchmark tasks are fed into LLMs (additional information, such as the input prompt and LLMs\u2019 outputs, is provided in Fig. A.4 and A.5 in supplementary). Broadly speaking, closed-source LLMs demonstrate superior performance on OpenAGI tasks, with GPT-4 leading the pack under both zero- and few-shot scenarios. In the open-source category, LLaMA-2-13B takes the lead, consistently posting top results across various learning schema--the performance possibly influenced by its larger model size. Notably, open-source LLMs significantly benefit from the tuning methods, particularly Fine-tuning and\u2019 RLTF. These methods mark noticeable enhancements for Flan-T5-Large, Vicuna-7B, and LLaMA-2-13B when compared with zero-shot and few-shot learning schema. In fact, each of these open-source models hits its pinnacle under the RLTF approach. Conclusively, with RLTF tuning, the performance of LLaMA-2-13B approaches that of GPT-3.5, illustrating its potential.", + "doc_id": "c790da40ea8911ee928e0242ac180005", + "docnm_kwd": "OpenAGI When LLM Meets Domain Experts.pdf", + "img_id": "afab9fdad6e511eebdb20242ac180006-7e2345d440383b756670e1b0f43a7007", + "important_kwd": [], + "kb_id": "afab9fdad6e511eebdb20242ac180006", + "positions": [ + [ + 8.0, + 107.3, + 508.90000000000003, + 686.3, + 697.0 + ] + ], + "similarity": 0.6691508616357027, + "term_similarity": 0.6999011754270821, + "vector_similarity": 0.39239803751328806 + } + ], + "doc_aggs": { + "OpenAGI When LLM Meets Domain Experts.pdf": 4 + }, + "total": 8 + } + }, + "retcode": 0, + "retmsg": "success" +} +``` + +- **answer**: The replay of the chat bot. +- **reference**: + - chunks: Every item in it refer to the corresponding message in answer. + - content_with_weight: The content of chunk. + - docnm_kwd: the document name. + - img_id: the image id of the chunk. It is an optional field only for PDF/pptx/picture. And accessed by 'GET' /document/get/\. + - positions: [page_number, [upleft corner(x, y)], [right bottom(x, y)]], the chunk position, only for PDF. + - similarity: the hybrid similarity. + - term_similarity: keyword simimlarity + - vector_similarity: embedding similarity + - doc_aggs: + - doc_id: the document can be accessed by 'GET' /document/get/\ + - doc_name: the file name + - count: the chunk number hit in this document. + +## Get document content or image + +This is usually used when display content of citation. +### Path: /document/get/\ +### Method: GET diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index b9bb36d736..7452931ce0 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -49,7 +49,7 @@ def chat(self, system, history, gen_conf): if response.choices[0].finish_reason == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ans, response.usage.completion_tokens + return ans, response.usage.total_tokens except openai.APIError as e: return "**ERROR**: " + str(e), 0 @@ -73,7 +73,7 @@ def chat(self, system, history, gen_conf): if response.choices[0].finish_reason == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ans, response.usage.completion_tokens + return ans, response.usage.total_tokens except openai.APIError as e: return "**ERROR**: " + str(e), 0 @@ -127,7 +127,7 @@ def chat(self, system, history, gen_conf): if response.choices[0].finish_reason == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ans, response.usage.completion_tokens + return ans, response.usage.total_tokens except Exception as e: return "**ERROR**: " + str(e), 0 @@ -153,7 +153,7 @@ def chat(self, system, history, gen_conf): options=options ) ans = response["message"]["content"].strip() - return ans, response["eval_count"] + return ans, response["eval_count"] + response["prompt_eval_count"] except Exception as e: return "**ERROR**: " + str(e), 0 @@ -175,50 +175,7 @@ def chat(self, system, history, gen_conf): if response.choices[0].finish_reason == "length": ans += "...\nFor the content length reason, it stopped, continue?" if is_english( [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - return ans, response.usage.completion_tokens + return ans, response.usage.total_tokens except openai.APIError as e: return "**ERROR**: " + str(e), 0 - -class LocalLLM(Base): - class RPCProxy: - def __init__(self, host, port): - self.host = host - self.port = int(port) - self.__conn() - - def __conn(self): - from multiprocessing.connection import Client - self._connection = Client( - (self.host, self.port), authkey=b'infiniflow-token4kevinhu') - - def __getattr__(self, name): - import pickle - - def do_rpc(*args, **kwargs): - for _ in range(3): - try: - self._connection.send( - pickle.dumps((name, args, kwargs))) - return pickle.loads(self._connection.recv()) - except Exception as e: - self.__conn() - raise Exception("RPC connection lost!") - - return do_rpc - - def __init__(self, *args, **kwargs): - self.client = LocalLLM.RPCProxy("127.0.0.1", 7860) - - def chat(self, system, history, gen_conf): - if system: - history.insert(0, {"role": "system", "content": system}) - try: - ans = self.client.chat( - history, - gen_conf - ) - return ans, num_tokens_from_string(ans) - except Exception as e: - return "**ERROR**: " + str(e), 0 -