Skip to content

Commit

Permalink
add owner check for team work (#2892)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

#2834

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
KevinHuSh authored Oct 18, 2024
1 parent 8fdfa0f commit c760f05
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 13 deletions.
3 changes: 3 additions & 0 deletions agent/component/exesql.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def check(self):
self.check_positive_integer(self.port, "IP Port")
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")


class ExeSQL(ComponentBase, ABC):
Expand Down
45 changes: 44 additions & 1 deletion api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,17 @@ def list_docs():


@manager.route('/infos', methods=['POST'])
@login_required
def docinfos():
req = request.json
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts()))

Expand Down Expand Up @@ -242,11 +250,17 @@ def thumbnails():
def change_status():
req = request.json
if str(req["status"]) not in ["0", "1"]:
get_json_result(
return get_json_result(
data=False,
retmsg='"Status" must be either 0 or 1!',
retcode=RetCode.ARGUMENT_ERROR)

if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR)

try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
Expand Down Expand Up @@ -285,6 +299,15 @@ def rm():
req = request.json
doc_ids = req["doc_id"]
if isinstance(doc_ids, str): doc_ids = [doc_ids]

for doc_id in doc_ids:
if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)

root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, current_user.id)
Expand Down Expand Up @@ -323,6 +346,13 @@ def rm():
@validate_request("doc_ids", "run")
def run():
req = request.json
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
for id in req["doc_ids"]:
info = {"run": str(req["run"]), "progress": 0}
Expand Down Expand Up @@ -356,6 +386,12 @@ def run():
@validate_request("doc_id", "name")
def rename():
req = request.json
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
Expand Down Expand Up @@ -416,6 +452,13 @@ def get(doc_id):
@validate_request("doc_id", "parser_id")
def change_parser():
req = request.json

if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
Expand Down
23 changes: 16 additions & 7 deletions api/apps/kb_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user

Expand All @@ -23,14 +22,12 @@
from api.db.services.file_service import FileService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid, get_format_time
from api.db import StatusEnum, UserTenantRole, FileSource
from api.utils import get_uuid
from api.db import StatusEnum, FileSource
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import Knowledgebase, File
from api.settings import stat_logger, RetCode
from api.db.db_models import File
from api.settings import RetCode
from api.utils.api_utils import get_json_result
from rag.nlp import search
from rag.utils.es_conn import ELASTICSEARCH


@manager.route('/create', methods=['post'])
Expand Down Expand Up @@ -65,6 +62,12 @@ def create():
def update():
req = request.json
req["name"] = req["name"].strip()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
if not KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"]):
Expand Down Expand Up @@ -139,6 +142,12 @@ def list_kbs():
@validate_request("kb_id")
def rm():
req = request.json
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
kbs = KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"])
Expand Down
29 changes: 28 additions & 1 deletion api/db/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from rag.nlp import search, rag_tokenizer

from api.db import FileType, TaskStatus, ParserType, LLMType
from api.db.db_models import DB, Knowledgebase, Tenant, Task
from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant
from api.db.db_models import Document
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
Expand Down Expand Up @@ -263,6 +263,33 @@ def get_tenant_id_by_name(cls, name):
return
return docs[0]["tenant_id"]

@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):
Expand Down
24 changes: 23 additions & 1 deletion api/db/services/knowledgebase_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#
from api.db import StatusEnum, TenantPermission
from api.db.db_models import Knowledgebase, DB, Tenant, User
from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant
from api.db.services.common_service import CommonService


Expand Down Expand Up @@ -182,3 +182,25 @@ def get_list(cls, joined_tenant_ids, user_id,
kbs = kbs.paginate(page_number, items_per_page)

return list(kbs.dicts())

@classmethod
@DB.connection_context()
def accessible(cls, kb_id, user_id):
docs = cls.model.select(
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

@classmethod
@DB.connection_context()
def accessible4deletion(cls, kb_id, user_id):
docs = cls.model.select(
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

6 changes: 3 additions & 3 deletions api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#
import os
from datetime import date
from enum import IntEnum, Enum
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger
Expand Down Expand Up @@ -143,9 +144,8 @@

SECRET_KEY = get_base_config(
RAG_FLOW_SERVICE_NAME,
{}).get(
"secret_key",
"infiniflow")
{}).get("secret_key", str(date.today()))

TOKEN_EXPIRE_IN = get_base_config(
RAG_FLOW_SERVICE_NAME, {}).get(
"token_expires_in", 3600)
Expand Down

0 comments on commit c760f05

Please sign in to comment.