Skip to content

Commit

Permalink
Feature/cog 539 implementing additional retriever approaches (#262)
Browse files Browse the repository at this point in the history
* fix: refactor get_graph_from_model to return nodes and edges correctly

* fix: add missing params

* fix: remove complex zip usage

* fix: add edges to data_point properties

* fix: handle rate limit error coming from llm model

* fix: fixes lost edges and nodes in get_graph_from_model

* fix: fixes database pruning issue in pgvector

* fix: fixes database pruning issue in pgvector (#261)

* feat: adds code summary embeddings to vector DB

* fix: cognee_demo notebook pipeline is not saving summaries

* feat: implements first version of codegraph retriever

* chore: implements minor changes mostly to make the code production ready

* fix: turns off the unit test for raising on duplicated edges, as our current codegraph generation produces duplicated edges

* feat: implements unit tests for description to codepart search

* fix: fixes edge property inconsistent access in codepart retriever

* chore: implements more precise typing for get_attribute method for cogneegraph

* chore: adds spacing to tests and changes the cogneegraph getter names

---------

Co-authored-by: Boris Arzentar <[email protected]>
  • Loading branch information
hajdul88 and borisarzentar authored Dec 10, 2024
1 parent 5ffbebd commit 6d85165
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 16 deletions.
2 changes: 1 addition & 1 deletion cognee/modules/graph/cognee_graph/CogneeGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def add_edge(self, edge: Edge) -> None:
edge.node1.add_skeleton_edge(edge)
edge.node2.add_skeleton_edge(edge)
else:
raise EntityAlreadyExistsError(message=f"Edge {edge} already exists in the graph.")
print(f"Edge {edge} already exists in the graph.")

def get_node(self, node_id: str) -> Node:
"""Look up ``node_id`` in ``self.nodes``; the lookup is given None as its fallback for an unknown id."""
return self.nodes.get(node_id, None)
Expand Down
16 changes: 14 additions & 2 deletions cognee/modules/graph/cognee_graph/CogneeGraphElements.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def add_attribute(self, key: str, value: Any) -> None:
def get_attribute(self, key: str) -> Union[str, int, float]:
"""Return the value stored under ``key`` in ``self.attributes``."""
# Direct indexing: a missing key is not defaulted here (presumably a KeyError if attributes is a dict — confirm).
return self.attributes[key]

def get_skeleton_edges(self):
"""Return this object's ``skeleton_edges`` collection itself (no copy is made)."""
return self.skeleton_edges

def get_skeleton_neighbours(self):
"""Return this object's ``skeleton_neighbours`` collection itself (no copy is made)."""
return self.skeleton_neighbours

def __repr__(self) -> str:
"""Return a debug string containing the node's ``id`` and its ``attributes``."""
return f"Node({self.id}, attributes={self.attributes})"

Expand Down Expand Up @@ -109,8 +115,14 @@ def is_edge_alive_in_dimension(self, dimension: int) -> bool:
def add_attribute(self, key: str, value: Any) -> None:
"""Store ``value`` under ``key`` in ``self.attributes``."""
self.attributes[key] = value

def get_attribute(self, key: str, value: Any) -> Union[str, int, float]:
return self.attributes[key]
def get_attribute(self, key: str) -> Optional[Union[str, int, float]]:
"""Return the value stored under ``key`` in ``self.attributes``."""
# ``.get`` is called without a default, so a missing key is not an error
# (presumably yields None, assuming attributes is a dict — confirm).
return self.attributes.get(key)

def get_source_node(self):
"""Return ``node1``, which this accessor exposes as the edge's source node."""
return self.node1

def get_destination_node(self):
"""Return ``node2``, which this accessor exposes as the edge's destination node."""
return self.node2

def __repr__(self) -> str:
direction = "->" if self.directed else "--"
Expand Down
116 changes: 116 additions & 0 deletions cognee/modules/retrieval/description_to_codepart_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import asyncio
import logging

from typing import Set, List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry


async def code_description_to_code_part_search(query: str, user: User = None, top_k = 2) -> list:
    """
    Public entry point for the description-to-code-part search.

    Resolves the acting user (falling back to the default user when none is
    given) and delegates the actual retrieval to code_description_to_code_part.

    Args:
        query (str): The search query describing the code parts.
        user (User): The user performing the search; the default user is used when None.
        top_k: Number of code summaries to match, forwarded unchanged.

    Returns:
        list: Whatever code_description_to_code_part returns for the query.

    Raises:
        PermissionError: If no user was given and no default user exists.
    """
    acting_user = await get_default_user() if user is None else user

    # The default-user lookup itself may come back empty.
    if acting_user is None:
        raise PermissionError("No user found in the system. Please create a user.")

    return await code_description_to_code_part(query, acting_user, top_k)



async def code_description_to_code_part(
    query: str,
    user: User,
    top_k: int
) -> list:
    """
    Maps a code description query to relevant code parts using a CodeGraph pipeline.

    The query is searched in the "code_summary_text" vector collection. Each hit
    is looked up by id in a graph projected from the graph database; for every
    skeleton neighbour of that node, the destination node of each neighbour edge
    whose 'relationship_name' is 'contains' is collected.

    Args:
        query (str): The search query describing the code parts.
        user (User): The user performing the search (used for telemetry and logging).
        top_k (int): Number of code summaries to match (the number of
            corresponding code parts can be higher).

    Returns:
        list: The unique graph nodes reached through 'contains' edges. They are
            de-duplicated through a set, so their order is not defined. The
            list is empty when the vector search returns nothing.

    Raises:
        ValueError: If query is not a non-empty string or top_k is not a
            positive integer.
        RuntimeError: If the engines cannot be initialized or the search fails;
            the original error is chained as the cause.
    """
    # Check the type before the value: comparing e.g. a string top_k with 0
    # would raise TypeError instead of the documented ValueError.
    if not isinstance(query, str) or not query:
        raise ValueError("The query must be a non-empty string.")
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError("top_k must be a positive integer.")

    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as init_error:
        logging.error("Failed to initialize engines: %s", init_error, exc_info=True)
        raise RuntimeError("System initialization error. Please try again later.") from init_error

    send_telemetry("code_description_to_code_part_search EXECUTION STARTED", user.id)
    logging.info("Search initiated by user %s with query: '%s' and top_k: %d", user.id, query, top_k)

    try:
        results = await vector_engine.search(
            "code_summary_text", query_text=query, limit=top_k
        )
        if not results:
            logging.warning("No results found for query: '%s' by user: %s", query, user.id)
            return []

        # Project the stored graph into memory only once there is something to look up.
        memory_fragment = CogneeGraph()
        await memory_fragment.project_graph_from_db(
            graph_engine,
            node_properties_to_project=['id', 'type', 'text', 'source_code'],
            edge_properties_to_project=['relationship_name']
        )

        # A set, because several summaries can lead to the same code part.
        code_pieces_to_return = set()

        for node in results:
            node_id = str(node.id)
            node_to_search_from = memory_fragment.get_node(node_id)

            if not node_to_search_from:
                logging.debug("Node %s not found in memory fragment graph", node_id)
                continue

            # summary node -> neighbouring node -> nodes it 'contains'
            for code_file in node_to_search_from.get_skeleton_neighbours():
                for code_file_edge in code_file.get_skeleton_edges():
                    if code_file_edge.get_attribute('relationship_name') == 'contains':
                        code_pieces_to_return.add(code_file_edge.get_destination_node())

        logging.info("Search completed for user: %s, query: '%s'. Found %d code pieces.",
                     user.id, query, len(code_pieces_to_return))

        return list(code_pieces_to_return)

    except Exception as exec_error:
        logging.error(
            "Error during code description to code part search for user: %s, query: '%s'. Error: %s",
            user.id, query, exec_error, exc_info=True
        )
        send_telemetry("code_description_to_code_part_search EXECUTION FAILED", user.id)
        raise RuntimeError("An error occurred while processing your request.") from exec_error


if __name__ == "__main__":
    # Manual smoke run: one sample search executed as the default user.
    async def main():
        """Run a single sample query and print either the results or the error."""
        sample_query = "I am looking for a class with blue eyes"
        try:
            # Passing None lets the search fall back to the default user.
            code_parts = await code_description_to_code_part_search(sample_query, None)
            print("Retrieved Code Parts:", code_parts)
        except Exception as error:
            print(f"An error occurred: {error}")

    asyncio.run(main())


1 change: 1 addition & 0 deletions cognee/tasks/summarization/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class TextSummary(DataPoint):


class CodeSummary(DataPoint):
__tablename__ = "code_summary"
text: str
made_from: CodeFile

Expand Down
13 changes: 0 additions & 13 deletions cognee/tests/unit/modules/graph/cognee_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,6 @@ def test_add_edge_success(setup_graph):
assert edge in node2.skeleton_edges


def test_add_duplicate_edge(setup_graph):
"""Test adding a duplicate edge raises an exception."""
graph = setup_graph
node1 = Node("node1")
node2 = Node("node2")
graph.add_node(node1)
graph.add_node(node2)
edge = Edge(node1, node2)
graph.add_edge(edge)
with pytest.raises(EntityAlreadyExistsError, match="Edge .* already exists in the graph."):
graph.add_edge(edge)


def test_get_node_success(setup_graph):
"""Test retrieving an existing node."""
graph = setup_graph
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import pytest
from unittest.mock import AsyncMock, patch



@pytest.mark.asyncio
async def test_code_description_to_code_part_no_results():
    """An empty vector-search result must make the retriever return an empty list."""

    # Vector engine whose search finds nothing.
    vector_engine_stub = AsyncMock()
    vector_engine_stub.search.return_value = []

    user_stub = AsyncMock()
    user_stub.id = "user123"

    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=vector_engine_stub), \
         patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()), \
         patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()):

        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
        code_parts = await code_description_to_code_part("search query", user_stub, 2)

    assert code_parts == []


@pytest.mark.asyncio
async def test_code_description_to_code_part_invalid_query():
    """An empty query string must be rejected with a ValueError."""

    user_stub = AsyncMock()

    with pytest.raises(ValueError, match="The query must be a non-empty string."):
        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
        # Validation runs before any engine is touched, so nothing is patched here.
        await code_description_to_code_part("", user_stub, 2)


@pytest.mark.asyncio
async def test_code_description_to_code_part_invalid_top_k():
    """A top_k of zero must be rejected with a ValueError."""

    user_stub = AsyncMock()

    with pytest.raises(ValueError, match="top_k must be a positive integer."):
        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
        # Validation runs before any engine is touched, so nothing is patched here.
        await code_description_to_code_part("search query", user_stub, 0)


@pytest.mark.asyncio
async def test_code_description_to_code_part_initialization_error():
    """A failing engine factory must surface as the initialization RuntimeError."""

    user_stub = AsyncMock()

    # Only the vector engine factory fails; the graph engine factory is a harmless stub.
    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", side_effect=Exception("Engine init failed")), \
         patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()):

        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
        with pytest.raises(RuntimeError, match="System initialization error. Please try again later."):
            await code_description_to_code_part("search query", user_stub, 2)


@pytest.mark.asyncio
async def test_code_description_to_code_part_execution_error():
    """A failure during the vector search must surface as the execution RuntimeError."""

    # Vector engine whose search blows up.
    vector_engine_stub = AsyncMock()
    vector_engine_stub.search.side_effect = Exception("Execution error")

    user_stub = AsyncMock()
    user_stub.id = "user123"

    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=vector_engine_stub), \
         patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()), \
         patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()):

        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
        with pytest.raises(RuntimeError, match="An error occurred while processing your request."):
            await code_description_to_code_part("search query", user_stub, 2)

0 comments on commit 6d85165

Please sign in to comment.