Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into feat/COG-553-graph-memory-projection
Browse files Browse the repository at this point in the history
  • Loading branch information
hajdul88 committed Nov 11, 2024
2 parents 38d29ee + 52180eb commit 3e7df33
Show file tree
Hide file tree
Showing 29 changed files with 277 additions and 212 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test_neo4j.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
Expand All @@ -22,7 +22,7 @@ jobs:
run_neo4j_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest

defaults:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_notebook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]


concurrency:
Expand All @@ -23,7 +23,7 @@ jobs:
run_notebook_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:
run:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_pgvector.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]


concurrency:
Expand All @@ -23,7 +23,7 @@ jobs:
run_pgvector_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:
run:
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/test_python_3_10.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name: test | python 3.10

on:
workflow_dispatch:
pull_request:
branches:
- main
workflow_dispatch:
types: [labeled, synchronize]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
Expand All @@ -21,7 +22,7 @@ jobs:
run_common:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/test_python_3_11.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name: test | python 3.11

on:
workflow_dispatch:
pull_request:
branches:
- main
workflow_dispatch:
types: [labeled, synchronize]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
Expand All @@ -21,7 +22,7 @@ jobs:
run_common:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/test_python_3_9.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name: test | python 3.9

on:
workflow_dispatch:
pull_request:
branches:
- main
workflow_dispatch:
types: [labeled, synchronize]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
Expand All @@ -21,7 +22,7 @@ jobs:
run_common:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_qdrant.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]


concurrency:
Expand All @@ -23,7 +23,7 @@ jobs:
run_qdrant_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest

defaults:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_weaviate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]


concurrency:
Expand All @@ -23,7 +23,7 @@ jobs:
run_weaviate_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest

defaults:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,6 @@ cognee/cache/
# Default cognee system directory, used in development
.cognee_system/
.data_storage/
.anon_id

node_modules/
24 changes: 17 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,24 +109,34 @@ import asyncio
from cognee.api.v1.search import SearchType

async def main():
await cognee.prune.prune_data() # Reset cognee data
await cognee.prune.prune_system(metadata=True) # Reset cognee system state
# Reset cognee data
await cognee.prune.prune_data()
# Reset cognee system state
await cognee.prune.prune_system(metadata=True)

text = """
Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval.
"""

await cognee.add(text) # Add text to cognee
await cognee.cognify() # Use LLMs and cognee to create knowledge graph
# Add text to cognee
await cognee.add(text)

search_results = await cognee.search( # Search cognee for insights
# Use LLMs and cognee to create knowledge graph
await cognee.cognify()

# Search cognee for insights
search_results = await cognee.search(
SearchType.INSIGHTS,
{'query': 'Tell me about NLP'}
"Tell me about NLP",
)

for result_text in search_results: # Display results
# Display results
for result_text in search_results:
print(result_text)
# natural_language_processing is_a field
# natural_language_processing is_subfield_of computer_science
# natural_language_processing is_subfield_of information_retrieval

asyncio.run(main())
```
Expand Down
56 changes: 32 additions & 24 deletions cognee/infrastructure/databases/graph/neo4j_driver/adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Neo4j Adapter for Graph Database"""
import logging
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict
from contextlib import asynccontextmanager
from uuid import UUID
Expand All @@ -18,14 +19,17 @@ def __init__(
graph_database_url: str,
graph_database_username: str,
graph_database_password: str,
driver: Optional[Any] = None
driver: Optional[Any] = None,
):
self.driver = driver or AsyncGraphDatabase.driver(
graph_database_url,
auth = (graph_database_username, graph_database_password),
max_connection_lifetime = 120
)

async def close(self) -> None:
await self.driver.close()

@asynccontextmanager
async def get_session(self) -> AsyncSession:
async with self.driver.session() as session:
Expand Down Expand Up @@ -59,11 +63,10 @@ async def has_node(self, node_id: str) -> bool:
async def add_node(self, node: DataPoint):
serialized_properties = self.serialize_properties(node.model_dump())

query = """MERGE (node {id: $node_id})
ON CREATE SET node += $properties
ON MATCH SET node += $properties
ON MATCH SET node.updated_at = timestamp()
RETURN ID(node) AS internal_id, node.id AS nodeId"""
query = dedent("""MERGE (node {id: $node_id})
ON CREATE SET node += $properties, node.updated_at = timestamp()
ON MATCH SET node += $properties, node.updated_at = timestamp()
RETURN ID(node) AS internal_id, node.id AS nodeId""")

params = {
"node_id": str(node.id),
Expand All @@ -76,9 +79,8 @@ async def add_nodes(self, nodes: list[DataPoint]) -> None:
query = """
UNWIND $nodes AS node
MERGE (n {id: node.node_id})
ON CREATE SET n += node.properties
ON MATCH SET n += node.properties
ON MATCH SET n.updated_at = timestamp()
ON CREATE SET n += node.properties, n.updated_at = timestamp()
ON MATCH SET n += node.properties, n.updated_at = timestamp()
WITH n, node.node_id AS label
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
Expand Down Expand Up @@ -133,12 +135,19 @@ async def delete_nodes(self, node_ids: list[str]) -> None:
return await self.query(query, params)

async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
query = f"""
MATCH (from_node:`{str(from_node)}`)-[relationship:`{edge_label}`]->(to_node:`{str(to_node)}`)
query = """
MATCH (from_node)-[relationship]->(to_node)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
RETURN COUNT(relationship) > 0 AS edge_exists
"""

edge_exists = await self.query(query)
params = {
"from_node_id": str(from_node),
"to_node_id": str(to_node),
"edge_label": edge_label,
}

edge_exists = await self.query(query, params)
return edge_exists

async def has_edges(self, edges):
Expand All @@ -165,22 +174,21 @@ async def has_edges(self, edges):
raise error


async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
serialized_properties = self.serialize_properties(edge_properties)
from_node = from_node.replace(":", "_")
to_node = to_node.replace(":", "_")

query = f"""MATCH (from_node:`{str(from_node)}`
{{id: $from_node}}),
(to_node:`{str(to_node)}` {{id: $to_node}})
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
ON CREATE SET r += $properties, r.updated_at = timestamp()
ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r"""
query = dedent("""MATCH (from_node {id: $from_node}),
(to_node {id: $to_node})
MERGE (from_node)-[r]->(to_node)
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r
""")

params = {
"from_node": str(from_node),
"to_node": str(to_node),
"relationship_name": relationship_name,
"properties": serialized_properties
}

Expand Down Expand Up @@ -347,8 +355,8 @@ async def get_connections(self, node_id: UUID) -> list:
"""

predecessors, successors = await asyncio.gather(
self.query(predecessors_query, dict(node_id = node_id)),
self.query(successors_query, dict(node_id = node_id)),
self.query(predecessors_query, dict(node_id = str(node_id))),
self.query(successors_query, dict(node_id = str(node_id))),
)

connections = []
Expand Down
2 changes: 1 addition & 1 deletion cognee/infrastructure/databases/graph/networkx/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ async def load_graph_from_file(self, file_path: str = None):
except:
pass

if "updated_at" in node:
if "updated_at" in edge:
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")

self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)
Expand Down
16 changes: 4 additions & 12 deletions cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,10 @@ def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> Lance
for (data_point_index, data_point) in enumerate(data_points)
]

# TODO: This enables us to work with pydantic version but shouldn't
# stay like this, existing rows should be updated

await collection.delete("id IS NOT NULL")

original_size = await collection.count_rows()
await collection.add(lance_data_points)
new_size = await collection.count_rows()

if new_size <= original_size:
raise ValueError(
"LanceDB create_datapoints error: data points did not get added.")
await collection.merge_insert("id") \
.when_matched_update_all() \
.when_not_matched_insert_all() \
.execute(lance_data_points)


async def retrieve(self, collection_name: str, data_point_ids: list[str]):
Expand Down
Loading

0 comments on commit 3e7df33

Please sign in to comment.