Fixes to the ACL model
Vasilije1990 committed Jul 27, 2024
1 parent 218d322 commit b4d1a73
Showing 3 changed files with 153 additions and 75 deletions.
13 changes: 11 additions & 2 deletions cognee/api/v1/add/add.py
@@ -3,7 +3,11 @@
import asyncio
import dlt
import duckdb
from fastapi_users import fastapi_users

import cognee.modules.ingestion as ingestion
from cognee.infrastructure.databases.relational.user_authentication.users import give_permission_document, \
    get_async_session_context, current_active_user, create_default_user
from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.ingestion import get_matched_datasets, save_data_to_file
from cognee.shared.utils import send_telemetry
@@ -48,7 +52,7 @@ async def add(data: Union[BinaryIO, List[BinaryIO], str, List[str]], dataset_nam

    return []

async def add_files(file_paths: List[str], dataset_name: str):
async def add_files(file_paths: List[str], dataset_name: str, user_id: str = "default_user"):
    base_config = get_base_config()
    data_directory_path = base_config.data_root_directory

@@ -82,12 +86,17 @@ async def add_files(file_paths: List[str], dataset_name: str):
    )

    @dlt.resource(standalone = True, merge_key = "id")
    def data_resources(file_paths: str):
    def data_resources(file_paths: str, user_id: str = user_id):
        for file_path in file_paths:
            with open(file_path.replace("file://", ""), mode = "rb") as file:
                classified_data = ingestion.classify(file)

                data_id = ingestion.identify(classified_data)
                async with get_async_session_context() as session:
                    if user_id is None:
                        current_active_user = create_default_user()

                    give_permission_document(current_active_user, data_id, "write", session= session)

                file_metadata = classified_data.get_metadata()
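The hunk above ties ingestion into the new ACL model: each identified document is meant to get a "write" ACL entry for the active user, with a fallback to a default user when none is supplied. Note that data_resources is a synchronous generator, so the async-with block inside it cannot actually be awaited as committed. Below is a minimal sketch of the intended grant flow as a standalone coroutine; grant_write_on_ingest is a hypothetical name, and it assumes create_default_user returns the created user (the committed version returns the result of session.commit()).

    # Sketch only: grant_write_on_ingest is hypothetical and not part of this diff.
    from cognee.infrastructure.databases.relational.user_authentication.users import (
        create_default_user,
        get_async_session_context,
        give_permission_document,
    )

    async def grant_write_on_ingest(user, data_id: str) -> None:
        """Give the ingesting user write access to a newly identified document."""
        async with get_async_session_context() as session:
            if user is None:
                # Assumed to return the created User; see the users.py notes below.
                user = await create_default_user()
            # give_permission_document is an async function in this commit.
            await give_permission_document(user, data_id, "write", session=session)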
146 changes: 76 additions & 70 deletions cognee/api/v1/cognify/cognify_v2.py
@@ -4,9 +4,13 @@
import uuid
from typing import Union

from fastapi_users import fastapi_users
from sqlalchemy.ext.asyncio import AsyncSession

from cognee.infrastructure.databases.graph import get_graph_config
from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
from cognee.infrastructure.databases.relational.user_authentication.users import has_permission_document, \
    get_user_permissions, get_async_session_context
# from cognee.infrastructure.databases.relational.user_authentication.authentication_db import async_session_maker
# from cognee.infrastructure.databases.relational.user_authentication.users import get_user_permissions, fastapi_users
from cognee.modules.cognify.config import get_cognify_config
@@ -37,14 +41,6 @@ def __init__(self, message: str):
        super().__init__(self.message)

async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = None, user_id:str="default_user"):
    # session: AsyncSession = async_session_maker()
    # user = await fastapi_users.get_user_manager.get(user_id)
    # user_permissions = await get_user_permissions(user, session)
    # hash_object = hashlib.sha256(user.encode())
    # hashed_user_id = hash_object.hexdigest()
    # required_permission = "write"
    # if required_permission not in user_permissions:
    # raise PermissionDeniedException("Not enough permissions")

    relational_config = get_relationaldb_config()
    db_engine = relational_config.database_engine
@@ -55,68 +51,78 @@ async def cognify(datasets: Union[str, list[str]] = None, root_node_id: str = No


    async def run_cognify_pipeline(dataset_name: str, files: list[dict]):
        async with update_status_lock:
            task_status = get_task_status([dataset_name])

            if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
                logger.info(f"Dataset {dataset_name} is being processed.")
                return

            update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
        try:
            cognee_config = get_cognify_config()
            graph_config = get_graph_config()
            root_node_id = None

            if graph_config.infer_graph_topology and graph_config.graph_topology_task:
                from cognee.modules.topology.topology import TopologyEngine
                topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
                root_node_id = await topology_engine.add_graph_topology(files = files)
            elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology:
                from cognee.modules.topology.topology import TopologyEngine
                topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
                await topology_engine.add_graph_topology(graph_config.topology_file_path)
            elif not graph_config.graph_topology_task:
                root_node_id = "ROOT"

            tasks = [
                Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
                Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
                Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
                Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
                Task(
                    save_data_chunks,
                    collection_name = "chunks",
                ), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
                run_tasks_parallel([
                    Task(
                        summarize_text_chunks,
                        summarization_model = cognee_config.summarization_model,
                        collection_name = "chunk_summaries",
                    ), # Summarize the document chunks
                    Task(
                        classify_text_chunks,
                        classification_model = cognee_config.classification_model,
                    ),
                ]),
                Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
            ]

            pipeline = run_tasks(tasks, [
                PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else
                AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else
                ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else
                TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"])
                for file in files
            ])

            async for result in pipeline:
                print(result)

            update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
        except Exception as error:
            update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
            raise error

        for file in files:
            file["id"] = str(uuid.uuid4())
            file["name"] = file["name"].replace(" ", "_")

            async with get_async_session_context() as session:

                out = await has_permission_document(user_id, file["id"], "write", session)


        async with update_status_lock:
            task_status = get_task_status([dataset_name])

            if dataset_name in task_status and task_status[dataset_name] == "DATASET_PROCESSING_STARTED":
                logger.info(f"Dataset {dataset_name} is being processed.")
                return

            update_task_status(dataset_name, "DATASET_PROCESSING_STARTED")
        try:
            cognee_config = get_cognify_config()
            graph_config = get_graph_config()
            root_node_id = None

            if graph_config.infer_graph_topology and graph_config.graph_topology_task:
                from cognee.modules.topology.topology import TopologyEngine
                topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
                root_node_id = await topology_engine.add_graph_topology(files = files)
            elif graph_config.infer_graph_topology and not graph_config.infer_graph_topology:
                from cognee.modules.topology.topology import TopologyEngine
                topology_engine = TopologyEngine(infer=graph_config.infer_graph_topology)
                await topology_engine.add_graph_topology(graph_config.topology_file_path)
            elif not graph_config.graph_topology_task:
                root_node_id = "ROOT"

            tasks = [
                Task(process_documents, parent_node_id = root_node_id, task_config = { "batch_size": 10 }, user_id = hashed_user_id, user_permissions=user_permissions), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
                Task(establish_graph_topology, topology_model = KnowledgeGraph), # Set the graph topology for the document chunk data
                Task(expand_knowledge_graph, graph_model = KnowledgeGraph), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
                Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
                Task(
                    save_data_chunks,
                    collection_name = "chunks",
                ), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
                run_tasks_parallel([
                    Task(
                        summarize_text_chunks,
                        summarization_model = cognee_config.summarization_model,
                        collection_name = "chunk_summaries",
                    ), # Summarize the document chunks
                    Task(
                        classify_text_chunks,
                        classification_model = cognee_config.classification_model,
                    ),
                ]),
                Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
            ]

            pipeline = run_tasks(tasks, [
                PdfDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "pdf" else
                AudioDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "audio" else
                ImageDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"]) if file["extension"] == "image" else
                TextDocument(title=f"{file['name']}.{file['extension']}", file_path=file["file_path"])
                for file in files
            ])

            async for result in pipeline:
                print(result)

            update_task_status(dataset_name, "DATASET_PROCESSING_FINISHED")
        except Exception as error:
            update_task_status(dataset_name, "DATASET_PROCESSING_ERROR")
            raise error


    existing_datasets = db_engine.get_datasets()
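The rewritten run_cognify_pipeline stamps every file with a fresh uuid4 and calls has_permission_document before processing, but the committed loop never reads the result (out is unused) and passes the raw user_id string where the helper expects a User object; a freshly generated id also cannot match any ACL row written at ingestion. A minimal sketch of the gate the loop appears to aim for, assuming a User object and ids assigned at ingestion time; ensure_write_access is a hypothetical name:

    # Sketch only: ensure_write_access is illustrative and not part of this diff.
    from cognee.infrastructure.databases.relational.user_authentication.users import (
        get_async_session_context,
        has_permission_document,
    )

    class PermissionDeniedException(Exception):
        """Stand-in for the exception class defined earlier in cognify_v2.py."""

    async def ensure_write_access(user, files: list[dict]) -> None:
        """Raise unless the user holds write permission on every file."""
        async with get_async_session_context() as session:
            for file in files:
                file["name"] = file["name"].replace(" ", "_")
                if not await has_permission_document(user, file["id"], "write", session):
                    raise PermissionDeniedException(f"Not enough permissions on {file['name']}")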
69 changes: 66 additions & 3 deletions cognee/infrastructure/databases/relational/user_authentication/users.py
@@ -12,9 +12,10 @@
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.db import SQLAlchemyUserDatabase
from fastapi import Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from cognee.infrastructure.databases.relational.user_authentication.authentication_db import User, get_user_db, \
    get_async_session
    get_async_session, ACL
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users.authentication import JWTStrategy
from cognee.infrastructure.databases.relational.user_authentication.schemas import UserRead, UserCreate
@@ -86,13 +87,13 @@ async def hash_password(password: str) -> str:
get_user_db_context = asynccontextmanager(get_user_db)
get_user_manager_context = asynccontextmanager(get_user_manager)

async def create_user_method(email: str, password: str, is_superuser: bool = False):
async def create_user_method(email: str, password: str, is_superuser: bool = False, is_active: bool = True):
    try:
        async with get_async_session_context() as session:
            async with get_user_db_context(session) as user_db:
                async with get_user_manager_context(user_db) as user_manager:
                    user = await user_manager.create(
                        UserCreate(email=email, password=password, is_superuser=is_superuser)
                        UserCreate(email=email, password=password, is_superuser=is_superuser, is_active=is_active)
                    )
                    print(f"User created: {user.email}")
    except UserAlreadyExists:
@@ -175,3 +176,65 @@ async def user_check_token(token: str) -> bool:
    except:
        return False

async def has_permission_document(user: User, document_id: str, permission: str, session: AsyncSession) -> bool:
    # Check if the user has the specified permission for the document
    acl_entry = await session.execute(
        """
        SELECT 1 FROM acls
        WHERE user_id = :user_id AND document_id = :document_id AND permission = :permission
        """,
        {'user_id': str(user.id), 'document_id': str(document_id), 'permission': permission}
    )
    if acl_entry.scalar_one_or_none():
        return True

    # Check if any of the user's groups have the specified permission for the document
    group_acl_entry = await session.execute(
        """
        SELECT 1 FROM acls
        JOIN user_group ON acls.group_id = user_group.group_id
        WHERE user_group.user_id = :user_id AND acls.document_id = :document_id AND acls.permission = :permission
        """,
        {'user_id': str(user.id), 'document_id': str(document_id), 'permission': permission}
    )
    if group_acl_entry.scalar_one_or_none():
        return True

    return False

async def create_default_user():
    async with get_async_session_context() as session:
        default_user_email = "[email protected]"
        default_user_password = "default_password"

        user = await create_user_method(
            email=default_user_email,
            password=await hash_password(default_user_password),
            is_superuser=True,
            is_active=True)
        session.add(user)
        out = await session.commit()
        await session.refresh(user)
        return out.id

async def give_permission_document(user: Optional[User], document_id: str, permission: str,
                                   session: AsyncSession):

    acl_entry = ACL(
        document_id=document_id,
        user_id=user.id,
        permission=permission
    )
    session.add(acl_entry)
    await session.commit()


    if user.is_superuser:
        permission = 'all_permissions' # Example permission, change as needed
        acl_entry = ACL(
            document_id=document_id,
            user_id=user.id,
            permission=permission
        )
        session.add(acl_entry)
        await session.commit()
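
Taken together, the helpers above are the ACL model this commit fixes: acls rows keyed by user (or group), document, and permission; has_permission_document checks a direct user match and then a group match; give_permission_document writes the grant and, for superusers, an extra "all_permissions" row. A short usage sketch follows; demo and document_id are illustrative, and the user is assumed to be an already-persisted User row, since create_user_method prints the created user rather than returning it.

    # Sketch only: demo() and document_id are illustrative, not part of this diff.
    from cognee.infrastructure.databases.relational.user_authentication.users import (
        get_async_session_context,
        give_permission_document,
        has_permission_document,
    )

    async def demo(user, document_id: str) -> None:
        async with get_async_session_context() as session:
            # Writes one acls row; superusers also get an "all_permissions" row.
            await give_permission_document(user, document_id, "write", session)
            # True via the direct user_id match; the second query inside
            # has_permission_document covers group-based grants.
            assert await has_permission_document(user, document_id, "write", session)

Storing a literal "all_permissions" row keeps each check a single equality match on (user_id, document_id, permission), though has_permission_document would also need to probe for "all_permissions" before the superuser row grants anything in practice.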
