diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index aa49aa226d0..982f6254271 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -18,6 +18,7 @@ from flask import request, Response from flask_login import login_required, current_user from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService +from api.settings import RetCode from api.utils import get_uuid from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result from agent.canvas import Canvas @@ -43,6 +44,10 @@ def canvas_list(): @login_required def rm(): for i in request.json["canvas_ids"]: + if not UserCanvasService.query(user_id=current_user.id,id=i): + return get_json_result( + data=False, retmsg=f'Only owner of canvas authorized for this operation.', + retcode=RetCode.OPERATING_ERROR) UserCanvasService.delete_by_id(i) return get_json_result(data=True) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index c4529704e57..c3e05e87baa 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -13,16 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import json from copy import deepcopy + +from db.services.user_service import UserTenantService from flask import request, Response -from flask_login import login_required,current_user +from flask_login import login_required, current_user + +from api.db import LLMType from api.db.services.dialog_service import DialogService, ConversationService, chat from api.db.services.llm_service import LLMBundle, TenantService -from api.db import LLMType -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.settings import RetCode from api.utils import get_uuid from api.utils.api_utils import get_json_result -import json +from api.utils.api_utils import server_error_response, get_data_error_result, validate_request @manager.route('/set', methods=['POST']) @@ -72,6 +76,14 @@ def get(): e, conv = ConversationService.get_by_id(conv_id) if not e: return get_data_error_result(retmsg="Conversation not found!") + tenants = UserTenantService.query(user_id=current_user.id) + for tenant in tenants: + if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id): + break + else: + return get_json_result( + data=False, retmsg=f'Only owner of conversation authorized for this operation.', + retcode=RetCode.OPERATING_ERROR) conv = conv.to_dict() return get_json_result(data=conv) except Exception as e: @@ -84,6 +96,17 @@ def rm(): conv_ids = request.json["conversation_ids"] try: for cid in conv_ids: + exist, conv = ConversationService.get_by_id(cid) + if not exist: + return get_data_error_result(retmsg="Conversation not found!") + tenants = UserTenantService.query(user_id=current_user.id) + for tenant in tenants: + if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id): + break + else: + return get_json_result( + data=False, retmsg=f'Only owner of conversation authorized for this operation.', + retcode=RetCode.OPERATING_ERROR) ConversationService.delete_by_id(cid) return get_json_result(data=True) except Exception as e: @@ -95,6 +118,10 @@ def rm(): def list_convsersation(): dialog_id = request.args["dialog_id"] try: + if not DialogService.query(tenant_id=current_user.id, id=dialog_id): + return get_json_result( + data=False, retmsg=f'Only owner of dialog authorized for this operation.', + retcode=RetCode.OPERATING_ERROR) convs = ConversationService.query( dialog_id=dialog_id, order_by=ConversationService.model.create_time, @@ -107,12 +134,12 @@ def list_convsersation(): @manager.route('/completion', methods=['POST']) @login_required -#@validate_request("conversation_id", "messages") +@validate_request("conversation_id", "messages") def completion(): req = request.json - #req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [ + # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [ # {"role": "user", "content": "上海有吗?"} - #]} + # ]} msg = [] for m in req["messages"]: if m["role"] == "system": @@ -141,7 +168,8 @@ def fillin_conv(ans): nonlocal conv, message_id if not conv.reference: conv.reference.append(ans["reference"]) - else: conv.reference[-1] = ans["reference"] + else: + conv.reference[-1] = ans["reference"] conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id, "prompt": ans.get("prompt", "")} ans["id"] = message_id @@ -151,13 +179,13 @@ def stream(): try: for ans in chat(dia, msg, True, **req): fillin_conv(ans) - yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" ConversationService.update_by_id(conv.id, conv.to_dict()) except Exception as e: yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), - "data": {"answer": "**ERROR**: "+str(e), "reference": []}}, + "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" - yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" if req.get("stream", True): resp = Response(stream(), mimetype="text/event-stream") @@ -184,33 +212,34 @@ def stream(): def tts(): req = request.json text = req["text"] - + tenants = TenantService.get_by_user_id(current_user.id) if not tenants: return get_data_error_result(retmsg="Tenant not found!") - + tts_id = tenants[0]["tts_id"] if not tts_id: return get_data_error_result(retmsg="No default TTS model is set") - + tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id) + def stream_audio(): try: - for chunk in tts_mdl.tts(text): - yield chunk + for chunk in tts_mdl.tts(text): + yield chunk except Exception as e: yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e), - "data": {"answer": "**ERROR**: "+str(e)}}, - ensure_ascii=False)).encode('utf-8') + "data": {"answer": "**ERROR**: " + str(e)}}, + ensure_ascii=False)).encode('utf-8') - resp = Response(stream_audio(), mimetype="audio/mpeg") + resp = Response(stream_audio(), mimetype="audio/mpeg") resp.headers.add_header("Cache-Control", "no-cache") resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") - + return resp - + @manager.route('/delete_msg', methods=['POST']) @login_required @validate_request("conversation_id", "message_id") @@ -224,10 +253,10 @@ def delete_msg(): for i, msg in enumerate(conv["message"]): if req["message_id"] != msg.get("id", ""): continue - assert conv["message"][i+1]["id"] == req["message_id"] + assert conv["message"][i + 1]["id"] == req["message_id"] conv["message"].pop(i) conv["message"].pop(i) - conv["reference"].pop(max(0, i//2-1)) + conv["reference"].pop(max(0, i // 2 - 1)) break ConversationService.update_by_id(conv["id"], conv) diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 5c4c2202099..371362c676f 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -19,7 +19,8 @@ from api.db.services.dialog_service import DialogService from api.db import StatusEnum from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.user_service import TenantService +from api.db.services.user_service import TenantService, UserTenantService +from api.settings import RetCode from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid from api.utils.api_utils import get_json_result @@ -164,9 +165,19 @@ def list_dialogs(): @validate_request("dialog_ids") def rm(): req = request.json + dialog_list=[] + tenants = UserTenantService.query(user_id=current_user.id) try: - DialogService.update_many_by_id( - [{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]]) + for id in req["dialog_ids"]: + for tenant in tenants: + if DialogService.query(tenant_id=tenant.tenant_id, id=id): + break + else: + return get_json_result( + data=False, retmsg=f'Only owner of dialog authorized for this operation.', + retcode=RetCode.OPERATING_ERROR) + dialog_list.append({"id": id,"status":StatusEnum.INVALID.value}) + DialogService.update_many_by_id(dialog_list) return get_json_result(data=True) except Exception as e: return server_error_response(e) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 8ca804fa7c2..04d4c4d0a64 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -35,7 +35,7 @@ from api.db.services.file_service import FileService from api.db.services.llm_service import LLMBundle from api.db.services.task_service import TaskService, queue_tasks -from api.db.services.user_service import TenantService +from api.db.services.user_service import TenantService, UserTenantService from graphrag.mind_map_extractor import MindMapExtractor from rag.app import naive from rag.nlp import search @@ -189,6 +189,15 @@ def list_docs(): if not kb_id: return get_json_result( data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) + tenants = UserTenantService.query(user_id=current_user.id) + for tenant in tenants: + if KnowledgebaseService.query( + tenant_id=tenant.tenant_id, id=kb_id): + break + else: + return get_json_result( + data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', + retcode=RetCode.OPERATING_ERROR) keywords = request.args.get("keywords", "") page_number = int(request.args.get("page", 1)) diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 6bbd02ee5b8..7d7f86e2dca 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -100,10 +100,10 @@ def update(): def detail(): kb_id = request.args["kb_id"] try: - tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) - for m in tenants: + tenants = UserTenantService.query(user_id=current_user.id) + for tenant in tenants: if KnowledgebaseService.query( - tenant_id=m["tenant_id"], id=kb_id): + tenant_id=tenant.tenant_id, id=kb_id): break else: return get_json_result(