Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Authorization checks #2221

Merged
merged 6 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions api/apps/canvas_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flask import request, Response
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas
Expand All @@ -43,6 +44,10 @@ def canvas_list():
@login_required
def rm():
for i in request.json["canvas_ids"]:
if not UserCanvasService.query(user_id=current_user.id,id=i):
return get_json_result(
data=False, retmsg=f'Only owner of canvas authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i)
return get_json_result(data=True)

Expand Down
75 changes: 52 additions & 23 deletions api/apps/conversation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from copy import deepcopy

from db.services.user_service import UserTenantService
from flask import request, Response
from flask_login import login_required,current_user
from flask_login import login_required, current_user

from api.db import LLMType
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.db.services.llm_service import LLMBundle, TenantService
from api.db import LLMType
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
import json
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request


@manager.route('/set', methods=['POST'])
Expand Down Expand Up @@ -72,6 +76,14 @@ def get():
e, conv = ConversationService.get_by_id(conv_id)
if not e:
return get_data_error_result(retmsg="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
conv = conv.to_dict()
return get_json_result(data=conv)
except Exception as e:
Expand All @@ -84,6 +96,17 @@ def rm():
conv_ids = request.json["conversation_ids"]
try:
for cid in conv_ids:
exist, conv = ConversationService.get_by_id(cid)
if not exist:
return get_data_error_result(retmsg="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
ConversationService.delete_by_id(cid)
return get_json_result(data=True)
except Exception as e:
Expand All @@ -95,6 +118,10 @@ def rm():
def list_convsersation():
dialog_id = request.args["dialog_id"]
try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result(
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
convs = ConversationService.query(
dialog_id=dialog_id,
order_by=ConversationService.model.create_time,
Expand All @@ -107,12 +134,12 @@ def list_convsersation():

@manager.route('/completion', methods=['POST'])
@login_required
#@validate_request("conversation_id", "messages")
# @validate_request("conversation_id", "messages")
def completion():
req = request.json
#req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# {"role": "user", "content": "上海有吗?"}
#]}
# ]}
msg = []
for m in req["messages"]:
if m["role"] == "system":
Expand Down Expand Up @@ -141,7 +168,8 @@ def fillin_conv(ans):
nonlocal conv, message_id
if not conv.reference:
conv.reference.append(ans["reference"])
else: conv.reference[-1] = ans["reference"]
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"],
"id": message_id, "prompt": ans.get("prompt", "")}
ans["id"] = message_id
Expand All @@ -151,13 +179,13 @@ def stream():
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: "+str(e), "reference": []}},
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream")
Expand All @@ -184,33 +212,34 @@ def stream():
def tts():
req = request.json
text = req["text"]

tenants = TenantService.get_by_user_id(current_user.id)
if not tenants:
return get_data_error_result(retmsg="Tenant not found!")

tts_id = tenants[0]["tts_id"]
if not tts_id:
return get_data_error_result(retmsg="No default TTS model is set")

tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)

def stream_audio():
try:
for chunk in tts_mdl.tts(text):
yield chunk
for chunk in tts_mdl.tts(text):
yield chunk
except Exception as e:
yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: "+str(e)}},
ensure_ascii=False)).encode('utf-8')
"data": {"answer": "**ERROR**: " + str(e)}},
ensure_ascii=False)).encode('utf-8')

resp = Response(stream_audio(), mimetype="audio/mpeg")
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")

return resp


@manager.route('/delete_msg', methods=['POST'])
@login_required
@validate_request("conversation_id", "message_id")
Expand All @@ -224,10 +253,10 @@ def delete_msg():
for i, msg in enumerate(conv["message"]):
if req["message_id"] != msg.get("id", ""):
continue
assert conv["message"][i+1]["id"] == req["message_id"]
assert conv["message"][i + 1]["id"] == req["message_id"]
conv["message"].pop(i)
conv["message"].pop(i)
conv["reference"].pop(max(0, i//2-1))
conv["reference"].pop(max(0, i // 2 - 1))
break

ConversationService.update_by_id(conv["id"], conv)
Expand Down
17 changes: 14 additions & 3 deletions api/apps/dialog_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from api.db.services.dialog_service import DialogService
from api.db import StatusEnum
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
Expand Down Expand Up @@ -164,9 +165,19 @@ def list_dialogs():
@validate_request("dialog_ids")
def rm():
req = request.json
dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id)
try:
DialogService.update_many_by_id(
[{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
for id in req["dialog_ids"]:
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
DialogService.update_many_by_id(dialog_list)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
11 changes: 10 additions & 1 deletion api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import TenantService
from api.db.services.user_service import TenantService, UserTenantService
from graphrag.mind_map_extractor import MindMapExtractor
from rag.app import naive
from rag.nlp import search
Expand Down Expand Up @@ -189,6 +189,15 @@ def list_docs():
if not kb_id:
return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kb_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
keywords = request.args.get("keywords", "")

page_number = int(request.args.get("page", 1))
Expand Down
6 changes: 3 additions & 3 deletions api/apps/kb_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def update():
def detail():
kb_id = request.args["kb_id"]
try:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
for m in tenants:
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=m["tenant_id"], id=kb_id):
tenant_id=tenant.tenant_id, id=kb_id):
break
else:
return get_json_result(
Expand Down