
Example task extraction #127

Merged: 17 commits, Aug 8, 2024
41 changes: 25 additions & 16 deletions cognee/api/v1/cognify/cognify_v2.py
@@ -9,14 +9,14 @@
from cognee.modules.data.processing.document_types.ImageDocument import ImageDocument
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.data.processing.document_types import PdfDocument, TextDocument
from cognee.modules.cognify.vector import save_data_chunks
from cognee.modules.data.processing.process_documents import process_documents
from cognee.modules.classification.classify_text_chunks import classify_text_chunks
from cognee.modules.data.extraction.data_summary.summarize_text_chunks import summarize_text_chunks
from cognee.modules.data.processing.filter_affected_chunks import filter_affected_chunks
from cognee.modules.data.processing.remove_obsolete_chunks import remove_obsolete_chunks
from cognee.modules.data.extraction.knowledge_graph.expand_knowledge_graph import expand_knowledge_graph
from cognee.modules.data.extraction.knowledge_graph.establish_graph_topology import establish_graph_topology
# from cognee.modules.cognify.vector import save_data_chunks
# from cognee.modules.data.processing.process_documents import process_documents
# from cognee.modules.classification.classify_text_chunks import classify_text_chunks
# from cognee.modules.data.extraction.data_summary.summarize_text_chunks import summarize_text_chunks
# from cognee.modules.data.processing.filter_affected_chunks import filter_affected_chunks
# from cognee.modules.data.processing.remove_obsolete_chunks import remove_obsolete_chunks
# from cognee.modules.data.extraction.knowledge_graph.expand_knowledge_graph import expand_knowledge_graph
# from cognee.modules.data.extraction.knowledge_graph.establish_graph_topology import establish_graph_topology
from cognee.modules.data.models import Dataset, Data
from cognee.modules.data.operations.get_dataset_data import get_dataset_data
from cognee.modules.data.operations.retrieve_datasets import retrieve_datasets
@@ -27,6 +27,15 @@
from cognee.modules.users.permissions.methods import check_permissions_on_documents
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.pipelines.operations.log_pipeline_status import log_pipeline_status
from cognee.tasks.chunk_extract_summary.chunk_extract_summary import chunk_extract_summary_task
from cognee.tasks.chunk_naive_llm_classifier.chunk_naive_llm_classifier import chunk_naive_llm_classifier_task
from cognee.tasks.chunk_remove_disconnected.chunk_remove_disconnected import chunk_remove_disconnected_task
from cognee.tasks.chunk_to_graph_decomposition.chunk_to_graph_decomposition import chunk_to_graph_decomposition_task
from cognee.tasks.chunk_to_vector_graphstore.chunk_to_vector_graphstore import chunk_to_vector_graphstore_task
from cognee.tasks.chunk_update_check.chunk_update_check import chunk_update_check_task
from cognee.tasks.graph_decomposition_to_graph_nodes.graph_decomposition_to_graph_nodes import \
graph_decomposition_to_graph_nodes_task
from cognee.tasks.source_documents_to_chunks.source_documents_to_chunks import source_documents_to_chunks

logger = logging.getLogger("cognify.v2")

@@ -100,26 +109,26 @@ async def run_cognify_pipeline(dataset: Dataset):
root_node_id = "ROOT"

tasks = [
Task(process_documents, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
Task(establish_graph_topology, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
Task(expand_knowledge_graph, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
Task(filter_affected_chunks, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
Task(source_documents_to_chunks, parent_node_id = root_node_id), # Classify documents and save them as a nodes in graph db, extract text chunks based on the document type
Task(chunk_to_graph_decomposition_task, topology_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Set the graph topology for the document chunk data
Task(graph_decomposition_to_graph_nodes_task, graph_model = KnowledgeGraph, collection_name = "entities"), # Generate knowledge graphs from the document chunks and attach it to chunk nodes
Task(chunk_update_check_task, collection_name = "chunks"), # Find all affected chunks, so we don't process unchanged chunks
Task(
save_data_chunks,
chunk_to_vector_graphstore_task,
collection_name = "chunks",
), # Save the document chunks in vector db and as nodes in graph db (connected to the document node and between each other)
run_tasks_parallel([
Task(
summarize_text_chunks,
chunk_extract_summary_task,
summarization_model = cognee_config.summarization_model,
collection_name = "chunk_summaries",
), # Summarize the document chunks
Task(
classify_text_chunks,
chunk_naive_llm_classifier_task,
classification_model = cognee_config.classification_model,
),
]),
Task(remove_obsolete_chunks), # Remove the obsolete document chunks.
Task(chunk_remove_disconnected_task), # Remove the obsolete document chunks.
]

pipeline = run_tasks(tasks, documents)
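
For orientation (not part of this PR's diff): the renamed tasks still compose through the same Task / run_tasks API shown above, with each task receiving the previous task's output. A minimal sketch of a reduced pipeline, reusing the names already in scope in cognify_v2.py and assuming run_tasks yields results when iterated, as the full pipeline does:

# Sketch only, assuming the task signatures shown in the diff above.
reduced_tasks = [
    Task(source_documents_to_chunks, parent_node_id = root_node_id),   # documents -> chunks
    Task(chunk_to_vector_graphstore_task, collection_name = "chunks"), # chunks -> vector + graph stores
    Task(chunk_remove_disconnected_task),                              # cleanup pass
]

reduced_pipeline = run_tasks(reduced_tasks, documents)

async for result in reduced_pipeline:  # assumed async-generator usage, mirroring the full pipeline
    logger.debug(result)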
@@ -58,7 +58,7 @@ async def create_table(self, schema_name: str, table_name: str, table_config: li

async def delete_table(self, table_name: str):
async with self.engine.connect() as connection:
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name};"))
await connection.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE;"))

async def insert_data(self, schema_name: str, table_name: str, data: list[dict]):
columns = ", ".join(data[0].keys())
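
Context on the CASCADE change above: in Postgres, DROP TABLE ... CASCADE also drops objects that depend on the table (for example foreign-key constraints referencing it), so deletion no longer fails when tables are dropped in an arbitrary order. A minimal standalone sketch of the call pattern, assuming a SQLAlchemy AsyncEngine and an illustrative connection string:

# Sketch only: drop a table together with its dependent objects.
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

engine = create_async_engine("postgresql+asyncpg://user:password@localhost/cognee_db")  # illustrative DSN

async def delete_table(table_name: str) -> None:
    # engine.begin() opens a transaction and commits it on exit
    async with engine.begin() as connection:
        await connection.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE;"))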
@@ -101,9 +101,10 @@ async def drop_tables(self, connection):
async def delete_database(self):
async with self.engine.connect() as connection:
try:
async with self.engine.begin() as connection:
await connection.run_sync(Base.metadata.drop_all)

async with connection.begin() as trans:
for table in Base.metadata.sorted_tables:
drop_table_query = text(f'DROP TABLE IF EXISTS {table.name} CASCADE')
await connection.execute(drop_table_query)
print("Database deleted successfully.")
except Exception as e:
print(f"Error deleting database: {e}")
print(f"Error deleting database: {e}")
Contributor
LGTM! But remove the unused variable trans.

The changes improve the accuracy and reliability of the database deletion process. However, the variable trans is assigned but never used.

-                async with connection.begin() as trans:
+                async with connection.begin():
Tools
Ruff

104-104: Local variable trans is assigned to but never used

Remove assignment to unused variable trans

(F841)

2 changes: 1 addition & 1 deletion cognee/modules/classification/classify_text_chunks.py
@@ -8,7 +8,7 @@
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from ..data.extraction.extract_categories import extract_categories

async def classify_text_chunks(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
if len(data_chunks) == 0:
return data_chunks
Comment on lines +11 to 13
Contributor
Add error handling for empty data chunks.

Consider logging a warning or raising an exception if data_chunks is empty to ensure proper handling.

-    if len(data_chunks) == 0:
-        return data_chunks
+    if len(data_chunks) == 0:
+        raise ValueError("data_chunks list is empty")
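
An alternative to raising, if downstream pipeline tasks are expected to tolerate an empty batch, is to log a warning and pass the input through unchanged; a hedged sketch (reusing the logger name from cognify_v2.py is an assumption):

# Sketch of a non-raising variant: warn on empty input instead of failing the pipeline.
import logging

logger = logging.getLogger("cognify.v2")  # illustrative logger name

async def chunk_naive_llm_classifier(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
    if len(data_chunks) == 0:
        logger.warning("chunk_naive_llm_classifier received no chunks; skipping classification.")
        return data_chunks
    ...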


@@ -5,7 +5,7 @@
from ...processing.chunk_types.DocumentChunk import DocumentChunk
from .add_model_class_to_graph import add_model_class_to_graph

async def establish_graph_topology(data_chunks: list[DocumentChunk], topology_model: Type[BaseModel]):
async def chunk_to_graph_decomposition(data_chunks: list[DocumentChunk], topology_model: Type[BaseModel]):
if topology_model == KnowledgeGraph:
return data_chunks

38 changes: 38 additions & 0 deletions cognee/tasks/chunk_extract_summary/chunk_extract_summary.py
@@ -0,0 +1,38 @@

import asyncio
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
from cognee.modules.data.extraction.data_summary.models.TextSummary import TextSummary
from cognee.modules.data.extraction.extract_summary import extract_summary
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk


async def chunk_extract_summary_task(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel], collection_name: str = "summaries"):
if len(data_chunks) == 0:
return data_chunks
Contributor
Ensure proper error handling for external calls.

The function makes several asynchronous calls. It's important to handle potential errors that might arise from these calls.

+    try:
+        if len(data_chunks) == 0:
+            return data_chunks
+    except Exception as e:
+        # Handle the exception or log it
+        raise e

Add similar error handling for other asynchronous calls within the function.

Committable suggestion was skipped due to low confidence.
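
One hedged way to apply this advice to the summarization call below, without cancelling the whole batch on a single failure, is asyncio.gather with return_exceptions=True (this assumes extract_summary raises on LLM errors):

# Sketch only: collect per-chunk summaries and surface individual failures with context.
results = await asyncio.gather(
    *[extract_summary(chunk.text, summarization_model) for chunk in data_chunks],
    return_exceptions = True,
)

failed_chunk_ids = [
    str(chunk.chunk_id)
    for chunk, result in zip(data_chunks, results)
    if isinstance(result, Exception)
]

if failed_chunk_ids:
    raise RuntimeError(f"Summarization failed for chunks: {failed_chunk_ids}")

chunk_summaries = results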


chunk_summaries = await asyncio.gather(
*[extract_summary(chunk.text, summarization_model) for chunk in data_chunks]
)

vector_engine = get_vector_engine()

await vector_engine.create_collection(collection_name, payload_schema=TextSummary)

await vector_engine.create_data_points(
collection_name,
[
DataPoint[TextSummary](
id = str(chunk.chunk_id),
payload = dict(
chunk_id = str(chunk.chunk_id),
document_id = str(chunk.document_id),
text = chunk_summaries[chunk_index].summary,
),
embed_field = "text",
) for (chunk_index, chunk) in enumerate(data_chunks)
],
)
Contributor
Ensure proper error handling for data point creation.

The function makes several asynchronous calls to create data points. It's important to handle potential errors that might arise from these calls.

+    try:
+        await vector_engine.create_data_points(
+            collection_name,
+            [
+                DataPoint[TextSummary](
+                    id=str(chunk.chunk_id),
+                    payload=dict(
+                        chunk_id=str(chunk.chunk_id),
+                        document_id=str(chunk.document_id),
+                        text=chunk_summaries[chunk_index].summary,
+                    ),
+                    embed_field="text",
+                ) for (chunk_index, chunk) in enumerate(data_chunks)
+            ],
+        )
+    except Exception as e:
+        # Handle the exception or log it
+        raise e


return data_chunks
152 changes: 152 additions & 0 deletions cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py
@@ -0,0 +1,152 @@
import asyncio
from uuid import uuid5, NAMESPACE_OID
from typing import Type
from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine, DataPoint
from cognee.modules.data.extraction.extract_categories import extract_categories
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk


async def chunk_naive_llm_classifier_task(data_chunks: list[DocumentChunk], classification_model: Type[BaseModel]):
if len(data_chunks) == 0:
return data_chunks

chunk_classifications = await asyncio.gather(
*[extract_categories(chunk.text, classification_model) for chunk in data_chunks],
)

classification_data_points = []

for chunk_index, chunk in enumerate(data_chunks):
chunk_classification = chunk_classifications[chunk_index]
classification_data_points.append(uuid5(NAMESPACE_OID, chunk_classification.label.type))
classification_data_points.append(uuid5(NAMESPACE_OID, chunk_classification.label.type))

for classification_subclass in chunk_classification.label.subclass:
classification_data_points.append(uuid5(NAMESPACE_OID, classification_subclass.value))

vector_engine = get_vector_engine()

class Keyword(BaseModel):
uuid: str
text: str
chunk_id: str
document_id: str

collection_name = "classification"

if await vector_engine.has_collection(collection_name):
existing_data_points = await vector_engine.retrieve(
collection_name,
list(set(classification_data_points)),
) if len(classification_data_points) > 0 else []

existing_points_map = {point.id: True for point in existing_data_points}
else:
existing_points_map = {}
await vector_engine.create_collection(collection_name, payload_schema=Keyword)

data_points = []
nodes = []
edges = []

for (chunk_index, data_chunk) in enumerate(data_chunks):
chunk_classification = chunk_classifications[chunk_index]
classification_type_label = chunk_classification.label.type
classification_type_id = uuid5(NAMESPACE_OID, classification_type_label)

if classification_type_id not in existing_points_map:
data_points.append(
DataPoint[Keyword](
id=str(classification_type_id),
payload=Keyword.parse_obj({
"uuid": str(classification_type_id),
"text": classification_type_label,
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),
}),
embed_field="text",
)
)

nodes.append((
str(classification_type_id),
dict(
id=str(classification_type_id),
name=classification_type_label,
type=classification_type_label,
)
))
existing_points_map[classification_type_id] = True

edges.append((
str(data_chunk.chunk_id),
str(classification_type_id),
"is_media_type",
dict(
relationship_name="is_media_type",
source_node_id=str(data_chunk.chunk_id),
target_node_id=str(classification_type_id),
),
))

for classification_subclass in chunk_classification.label.subclass:
classification_subtype_label = classification_subclass.value
classification_subtype_id = uuid5(NAMESPACE_OID, classification_subtype_label)

if classification_subtype_id not in existing_points_map:
data_points.append(
DataPoint[Keyword](
id=str(classification_subtype_id),
payload=Keyword.parse_obj({
"uuid": str(classification_subtype_id),
"text": classification_subtype_label,
"chunk_id": str(data_chunk.chunk_id),
"document_id": str(data_chunk.document_id),
}),
embed_field="text",
)
)

nodes.append((
str(classification_subtype_id),
dict(
id=str(classification_subtype_id),
name=classification_subtype_label,
type=classification_subtype_label,
)
))
edges.append((
str(classification_subtype_id),
str(classification_type_id),
"is_subtype_of",
dict(
relationship_name="contains",
source_node_id=str(classification_type_id),
target_node_id=str(classification_subtype_id),
),
))

existing_points_map[classification_subtype_id] = True

edges.append((
str(data_chunk.chunk_id),
str(classification_subtype_id),
"is_classified_as",
dict(
relationship_name="is_classified_as",
source_node_id=str(data_chunk.chunk_id),
target_node_id=str(classification_subtype_id),
),
))

if len(nodes) > 0 or len(edges) > 0:
await vector_engine.create_data_points(collection_name, data_points)

graph_engine = await get_graph_engine()

await graph_engine.add_nodes(nodes)
await graph_engine.add_edges(edges)

return data_chunks
@@ -0,0 +1,32 @@

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk


# from cognee.infrastructure.databases.vector import get_vector_engine


async def chunk_remove_disconnected_task(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
graph_engine = await get_graph_engine()

document_ids = set((data_chunk.document_id for data_chunk in data_chunks))

obsolete_chunk_ids = []

for document_id in document_ids:
chunk_ids = await graph_engine.get_successor_ids(document_id, edge_label = "has_chunk")

for chunk_id in chunk_ids:
previous_chunks = await graph_engine.get_predecessor_ids(chunk_id, edge_label = "next_chunk")

if len(previous_chunks) == 0:
obsolete_chunk_ids.append(chunk_id)

Contributor
Consider breaking down nested loops for readability.

Nested loops can be hard to read and maintain. Consider breaking them down into smaller functions.

async def get_obsolete_chunk_ids(graph_engine, document_ids):
    obsolete_chunk_ids = []
    for document_id in document_ids:
        chunk_ids = await graph_engine.get_successor_ids(document_id, edge_label="has_chunk")
        for chunk_id in chunk_ids:
            previous_chunks = await graph_engine.get_predecessor_ids(chunk_id, edge_label="next_chunk")
            if len(previous_chunks) == 0:
                obsolete_chunk_ids.append(chunk_id)
    return obsolete_chunk_ids

async def chunk_remove_disconnected_task(data_chunks: list[DocumentChunk]) -> list[DocumentChunk]:
    graph_engine = await get_graph_engine()
    document_ids = {data_chunk.document_id for data_chunk in data_chunks}
    obsolete_chunk_ids = await get_obsolete_chunk_ids(graph_engine, document_ids)
    if obsolete_chunk_ids:
        await graph_engine.delete_nodes(obsolete_chunk_ids)
    disconnected_nodes = await graph_engine.get_disconnected_nodes()
    if disconnected_nodes:
        await graph_engine.delete_nodes(disconnected_nodes)
    return data_chunks

if len(obsolete_chunk_ids) > 0:
await graph_engine.delete_nodes(obsolete_chunk_ids)

disconnected_nodes = await graph_engine.get_disconnected_nodes()
if len(disconnected_nodes) > 0:
await graph_engine.delete_nodes(disconnected_nodes)
Contributor
Optimize the deletion of disconnected nodes.

The deletion of disconnected nodes can be optimized by combining the checks and deletions.

-    disconnected_nodes = await graph_engine.get_disconnected_nodes()
-    if len(disconnected_nodes) > 0:
-        await graph_engine.delete_nodes(disconnected_nodes)
+    disconnected_nodes = await graph_engine.get_disconnected_nodes()
+    if disconnected_nodes:
+        await graph_engine.delete_nodes(disconnected_nodes)


return data_chunks
Contributor
Ensure proper error handling for external calls.

The function makes several asynchronous calls to graph_engine. It's important to handle potential errors that might arise from these calls.

+    try:
+        graph_engine = await get_graph_engine()
+    except Exception as e:
+        # Handle the exception or log it
+        raise e

Add similar error handling for other asynchronous calls within the function.

Committable suggestion was skipped due to low confidence.
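
A short hedged sketch of what that could look like for the graph calls in this task (the logger name is illustrative):

# Sketch only: log failures from graph engine calls with enough context to debug, then re-raise.
import logging

logger = logging.getLogger("cognify.v2")  # illustrative logger name

try:
    graph_engine = await get_graph_engine()
except Exception:
    logger.exception("Could not acquire graph engine; skipping disconnected-chunk cleanup.")
    raise

try:
    if len(obsolete_chunk_ids) > 0:
        await graph_engine.delete_nodes(obsolete_chunk_ids)
except Exception:
    logger.exception("Failed to delete obsolete chunks: %s", obsolete_chunk_ids)
    raise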

@@ -0,0 +1,22 @@
from typing import Type
from pydantic import BaseModel

from cognee.modules.data.processing.chunk_types.DocumentChunk import DocumentChunk
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.data.extraction.knowledge_graph.add_model_class_to_graph import add_model_class_to_graph


async def chunk_to_graph_decomposition_task(data_chunks: list[DocumentChunk], topology_model: Type[BaseModel]):
if topology_model == KnowledgeGraph:
return data_chunks

graph_engine = await get_graph_engine()

await add_model_class_to_graph(topology_model, graph_engine)

return data_chunks


def generate_node_id(node_id: str) -> str:
return node_id.upper().replace(" ", "_").replace("'", "")
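
For example, generate_node_id("John's Laptop") returns "JOHNS_LAPTOP": the helper upper-cases the id, replaces spaces with underscores, and strips apostrophes so the value can serve as a normalized graph node identifier.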