feat: Add versioning to the data point model #378

Merged · 19 commits · Jan 16, 2025
12 changes: 8 additions & 4 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
@@ -62,10 +62,12 @@ async def has_node(self, node_id: str) -> bool:
     async def add_node(self, node: DataPoint):
         serialized_properties = self.serialize_properties(node.model_dump())

-        query = dedent("""MERGE (node {id: $node_id})
+        query = dedent(
+            """MERGE (node {id: $node_id})
             ON CREATE SET node += $properties, node.updated_at = timestamp()
             ON MATCH SET node += $properties, node.updated_at = timestamp()
-            RETURN ID(node) AS internal_id, node.id AS nodeId""")
+            RETURN ID(node) AS internal_id, node.id AS nodeId"""
+        )

         params = {
             "node_id": str(node.id),
@@ -182,13 +184,15 @@ async def add_edge(
     ):
         serialized_properties = self.serialize_properties(edge_properties)

-        query = dedent("""MATCH (from_node {id: $from_node}),
+        query = dedent(
+            """MATCH (from_node {id: $from_node}),
             (to_node {id: $to_node})
             MERGE (from_node)-[r]->(to_node)
             ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
             ON MATCH SET r += $properties, r.updated_at = timestamp()
             RETURN r
-            """)
+            """
+        )

         params = {
             "from_node": str(from_node),
24 changes: 16 additions & 8 deletions cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py
@@ -88,23 +88,27 @@ async def create_data_point_query(self, data_point: DataPoint, vectorized_values
             }
         )

-        return dedent(f"""
+        return dedent(
+            f"""
             MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
             ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
             ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
-        """).strip()
+            """
+        ).strip()

     async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
         properties = await self.stringify_properties(edge[3])
         properties = f"{{{properties}}}"

-        return dedent(f"""
+        return dedent(
+            f"""
             MERGE (source {{id:'{edge[0]}'}})
             MERGE (target {{id: '{edge[1]}'}})
             MERGE (source)-[edge:{edge[2]} {properties}]->(target)
             ON MATCH SET edge.updated_at = timestamp()
             ON CREATE SET edge.updated_at = timestamp()
-        """).strip()
+            """
+        ).strip()

     async def create_collection(self, collection_name: str):
         pass
@@ -195,12 +199,14 @@ async def add_edges(self, edges: list[tuple[str, str, str, dict]]):
         self.query(query)

     async def has_edges(self, edges):
-        query = dedent("""
+        query = dedent(
+            """
             UNWIND $edges AS edge
             MATCH (a)-[r]->(b)
             WHERE id(a) = edge.from_node AND id(b) = edge.to_node AND type(r) = edge.relationship_name
             RETURN edge.from_node AS from_node, edge.to_node AS to_node, edge.relationship_name AS relationship_name, count(r) > 0 AS edge_exists
-        """).strip()
+            """
+        ).strip()

         params = {
             "edges": [
@@ -279,14 +285,16 @@ async def search(

         [label, attribute_name] = collection_name.split(".")

-        query = dedent(f"""
+        query = dedent(
+            f"""
             CALL db.idx.vector.queryNodes(
                 '{label}',
                 '{attribute_name}',
                 {limit},
                 vecf32({query_vector})
             ) YIELD node, score
-        """).strip()
+            """
+        ).strip()

         result = self.query(query)

@@ -93,10 +93,12 @@ async def get_schema_list(self) -> List[str]:
         if self.engine.dialect.name == "postgresql":
             async with self.engine.begin() as connection:
                 result = await connection.execute(
-                    text("""
+                    text(
+                        """
                         SELECT schema_name FROM information_schema.schemata
                         WHERE schema_name NOT IN ('pg_catalog', 'pg_toast', 'information_schema');
-                    """)
+                        """
+                    )
                 )
                 return [schema[0] for schema in result.fetchall()]
         return []
58 changes: 52 additions & 6 deletions cognee/infrastructure/engine/models/DataPoint.py
@@ -1,24 +1,34 @@
 from datetime import datetime, timezone
-from typing import Optional
+from typing import Optional, Any, Dict
 from uuid import UUID, uuid4

 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
+import pickle


+# Define metadata type
 class MetaData(TypedDict):
     index_fields: list[str]


+# Updated DataPoint model with versioning and new fields
 class DataPoint(BaseModel):
     __tablename__ = "data_point"
     id: UUID = Field(default_factory=uuid4)
-    updated_at: Optional[datetime] = datetime.now(timezone.utc)
+    created_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
+    )
+    updated_at: int = Field(
+        default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000)
+    )
+    version: int = 1  # Default version
     topological_rank: Optional[int] = 0
     _metadata: Optional[MetaData] = {"index_fields": [], "type": "DataPoint"}

-    # class Config:
-    #     underscore_attrs_are_private = True
+    # Override the Pydantic configuration
+    class Config:
+        underscore_attrs_are_private = True

     @classmethod
     def get_embeddable_data(self, data_point):
@@ -31,11 +41,11 @@ def get_embeddable_data(self, data_point):

             if isinstance(attribute, str):
                 return attribute.strip()
-            else:
-                return attribute
+            return attribute

     @classmethod
     def get_embeddable_properties(self, data_point):
+        """Retrieve all embeddable properties."""
         if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
             return [
                 getattr(data_point, field, None) for field in data_point._metadata["index_fields"]
@@ -45,4 +55,40 @@ def get_embeddable_properties(self, data_point):

     @classmethod
     def get_embeddable_property_names(self, data_point):
+        """Retrieve names of embeddable properties."""
         return data_point._metadata["index_fields"] or []
+
+    def update_version(self):
+        """Update the version and updated_at timestamp."""
+        self.version += 1
+        self.updated_at = int(datetime.now(timezone.utc).timestamp() * 1000)
+
+    # JSON Serialization
+    def to_json(self) -> str:
[Review thread on def to_json]
Contributor: Why do we need this serialization?
Contributor (author): So you can parallelize tasks, since you had issues with that. Pickle or JSON.

+        """Serialize the instance to a JSON string."""
+        return self.json()
+
+    @classmethod
+    def from_json(self, json_str: str):
+        """Deserialize the instance from a JSON string."""
+        return self.model_validate_json(json_str)
+
+    # Pickle Serialization
+    def to_pickle(self) -> bytes:
+        """Serialize the instance to pickle-compatible bytes."""
+        return pickle.dumps(self.dict())
+
+    @classmethod
+    def from_pickle(self, pickled_data: bytes):
+        """Deserialize the instance from pickled bytes."""
+        data = pickle.loads(pickled_data)
+        return self(**data)
[Review comment on lines +76 to +85]
Contributor (🛠️ Refactor suggestion / ⚠️ Potential issue):

Security concern: Replace pickle with a safer serialization method

Using pickle for serialization poses a security risk as it can execute arbitrary code during deserialization. Consider using a safer alternative like JSON or MessagePack.

-    def to_pickle(self) -> bytes:
-        """Serialize the instance to pickle-compatible bytes."""
-        return pickle.dumps(self.dict())
-
-    @classmethod
-    def from_pickle(self, pickled_data: bytes):
-        """Deserialize the instance from pickled bytes."""
-        data = pickle.loads(pickled_data)
-        return self(**data)
+    def to_bytes(self) -> bytes:
+        """Serialize the instance to bytes using JSON."""
+        return self.json().encode('utf-8')
+
+    @classmethod
+    def from_bytes(cls, data: bytes):
+        """Deserialize the instance from JSON bytes."""
+        return cls.parse_raw(data)
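To make the concern concrete, here is a minimal, self-contained sketch (not part of this PR; the EvilPayload class is hypothetical) of how unpickling untrusted bytes runs attacker-chosen code. DataPoint.from_pickle would have the same exposure if it ever received external input.

# Illustrative only: why deserializing untrusted pickle data is unsafe.
import pickle


class EvilPayload:
    def __reduce__(self):
        # During pickle.loads, pickle calls the returned callable with these args;
        # an attacker could return os.system with a shell command instead of print.
        return (print, ("arbitrary code executed during pickle.loads",))


malicious_bytes = pickle.dumps(EvilPayload())
pickle.loads(malicious_bytes)  # prints the message; the caller never opted in to running it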


+    def to_dict(self, **kwargs) -> Dict[str, Any]:
+        """Serialize model to a dictionary."""
+        return self.model_dump(**kwargs)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "DataPoint":
+        """Deserialize model from a dictionary."""
+        return cls.model_validate(data)
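A rough usage sketch of the versioned DataPoint API added in this diff, tying back to the parallelization discussion in the review thread above; the direct module import and the field values are illustrative assumptions rather than documented cognee API:

# Sketch only: assumes the module path shown in this diff and the defaults defined above.
from cognee.infrastructure.engine.models.DataPoint import DataPoint

point = DataPoint()          # starts at version=1; created_at/updated_at are epoch milliseconds
point.update_version()       # bumps version to 2 and refreshes updated_at

payload = point.to_json()    # JSON string that can be handed to another worker or task
restored = DataPoint.from_json(payload)
assert restored.id == point.id and restored.version == 2

as_dict = point.to_dict()             # plain dict via model_dump()
clone = DataPoint.from_dict(as_dict)  # validated back into a DataPoint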
6 changes: 4 additions & 2 deletions cognee/tasks/temporal_awareness/index_graphiti_objects.py
@@ -19,9 +19,11 @@ async def index_and_transform_graphiti_nodes_and_edges():
         raise RuntimeError("Initialization error") from e

     await graph_engine.query("""MATCH (n) SET n.id = n.uuid RETURN n""")
-    await graph_engine.query("""MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
+    await graph_engine.query(
+        """MATCH (source)-[r]->(target) SET r.source_node_id = source.id,
             r.target_node_id = target.id,
-            r.relationship_name = type(r) RETURN r""")
+            r.relationship_name = type(r) RETURN r"""
+    )
     await graph_engine.query("""MATCH (n) SET n.text = COALESCE(n.summary, n.content) RETURN n""")

     nodes_data, edges_data = await graph_engine.get_model_independent_graph_data()
18 changes: 9 additions & 9 deletions cognee/tests/integration/documents/AudioDocument_test.py
@@ -36,12 +36,12 @@ def test_AudioDocument():
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
18 changes: 9 additions & 9 deletions cognee/tests/integration/documents/ImageDocument_test.py
@@ -25,12 +25,12 @@ def test_ImageDocument():
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
18 changes: 9 additions & 9 deletions cognee/tests/integration/documents/PdfDocument_test.py
@@ -27,12 +27,12 @@ def test_PdfDocument():
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
18 changes: 9 additions & 9 deletions cognee/tests/integration/documents/TextDocument_test.py
@@ -39,12 +39,12 @@ def test_TextDocument(input_file, chunk_size):
     for ground_truth, paragraph_data in zip(
         GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker")
     ):
-        assert ground_truth["word_count"] == paragraph_data.word_count, (
-            f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
-        )
-        assert ground_truth["len_text"] == len(paragraph_data.text), (
-            f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
-        )
-        assert ground_truth["cut_type"] == paragraph_data.cut_type, (
-            f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
-        )
+        assert (
+            ground_truth["word_count"] == paragraph_data.word_count
+        ), f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'
+        assert ground_truth["len_text"] == len(
+            paragraph_data.text
+        ), f'{ground_truth["len_text"] = } != {len(paragraph_data.text) = }'
+        assert (
+            ground_truth["cut_type"] == paragraph_data.cut_type
+        ), f'{ground_truth["cut_type"] = } != {paragraph_data.cut_type = }'
30 changes: 15 additions & 15 deletions cognee/tests/integration/documents/UnstructuredDocument_test.py
@@ -71,32 +71,32 @@ def test_UnstructuredDocument():
     for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
         assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
-        assert "sentence_cut" == paragraph_data.cut_type, (
-            f" sentence_cut != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "sentence_cut" == paragraph_data.cut_type
+        ), f" sentence_cut != {paragraph_data.cut_type = }"

     # Test DOCX
     for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
         assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
-        assert "sentence_end" == paragraph_data.cut_type, (
-            f" sentence_end != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "sentence_end" == paragraph_data.cut_type
+        ), f" sentence_end != {paragraph_data.cut_type = }"

     # TEST CSV
     for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
-        assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
-            f"Read text doesn't match expected text: {paragraph_data.text}"
-        )
-        assert "sentence_cut" == paragraph_data.cut_type, (
-            f" sentence_cut != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
+        ), f"Read text doesn't match expected text: {paragraph_data.text}"
+        assert (
+            "sentence_cut" == paragraph_data.cut_type
+        ), f" sentence_cut != {paragraph_data.cut_type = }"

     # Test XLSX
     for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
         assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
-        assert "sentence_cut" == paragraph_data.cut_type, (
-            f" sentence_cut != {paragraph_data.cut_type = }"
-        )
+        assert (
+            "sentence_cut" == paragraph_data.cut_type
+        ), f" sentence_cut != {paragraph_data.cut_type = }"
12 changes: 6 additions & 6 deletions cognee/tests/test_deduplication.py
@@ -30,9 +30,9 @@ async def test_deduplication():

     result = await relational_engine.get_all_data_from_table("data")
     assert len(result) == 1, "More than one data entity was found."
-    assert result[0]["name"] == "Natural_language_processing_copy", (
-        "Result name does not match expected value."
-    )
+    assert (
+        result[0]["name"] == "Natural_language_processing_copy"
+    ), "Result name does not match expected value."

     result = await relational_engine.get_all_data_from_table("datasets")
     assert len(result) == 2, "Unexpected number of datasets found."
@@ -61,9 +61,9 @@ async def test_deduplication():

     result = await relational_engine.get_all_data_from_table("data")
     assert len(result) == 1, "More than one data entity was found."
-    assert hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"], (
-        "Content hash is not a part of file name."
-    )
+    assert (
+        hashlib.md5(text.encode("utf-8")).hexdigest() in result[0]["name"]
+    ), "Content hash is not a part of file name."

     await cognee.prune.prune_data()
     await cognee.prune.prune_system(metadata=True)
6 changes: 3 additions & 3 deletions cognee/tests/test_falkordb.py
@@ -85,9 +85,9 @@ async def main():

     from cognee.infrastructure.databases.relational import get_relational_engine

-    assert not os.path.exists(get_relational_engine().db_path), (
-        "SQLite relational database is not empty"
-    )
+    assert not os.path.exists(
+        get_relational_engine().db_path
+    ), "SQLite relational database is not empty"

     from cognee.infrastructure.databases.graph import get_graph_config
