-
Notifications
You must be signed in to change notification settings - Fork 85
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/cog 539 implementing additional retriever approaches (#262)
* fix: refactor get_graph_from_model to return nodes and edges correctly * fix: add missing params * fix: remove complex zip usage * fix: add edges to data_point properties * fix: handle rate limit error coming from llm model * fix: fixes lost edges and nodes in get_graph_from_model * fix: fixes database pruning issue in pgvector * fix: fixes database pruning issue in pgvector (#261) * feat: adds code summary embeddings to vector DB * fix: cognee_demo notebook pipeline is not saving summaries * feat: implements first version of codegraph retriever * chore: implements minor changes mostly to make the code production ready * fix: turns off raising duplicated edges unit test as we have these in our current codegraph generation * feat: implements unit tests for description to codepart search * fix: fixes edge property inconsistent access in codepart retriever * chore: implements more precise typing for get_attribute method for cogneegraph * chore: adds spacing to tests and changes the cogneegraph getter names --------- Co-authored-by: Boris Arzentar <[email protected]>
- Loading branch information
1 parent
5ffbebd
commit 6d85165
Showing
6 changed files
with
208 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
116
cognee/modules/retrieval/description_to_codepart_search.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import asyncio | ||
import logging | ||
|
||
from typing import Set, List | ||
from cognee.infrastructure.databases.graph import get_graph_engine | ||
from cognee.infrastructure.databases.vector import get_vector_engine | ||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph | ||
from cognee.modules.users.methods import get_default_user | ||
from cognee.modules.users.models import User | ||
from cognee.shared.utils import send_telemetry | ||
|
||
|
||
async def code_description_to_code_part_search(query: str, user: User = None, top_k = 2) -> list:
    """
    Entry point for the description-to-code-part search.

    Resolves the acting user (falling back to get_default_user() when none is
    passed) and delegates to code_description_to_code_part.

    Args:
        query (str): The search query describing the code parts.
        user (User): The user performing the search; looked up when None.
        top_k: Forwarded unchanged to code_description_to_code_part.
    Returns:
        list: Whatever code_description_to_code_part returns for the query.
    Raises:
        PermissionError: If no user was passed and no default user exists.
    """
    # The default-user lookup is only awaited when the caller gave no user.
    acting_user = user if user is not None else await get_default_user()

    if acting_user is None:
        raise PermissionError("No user found in the system. Please create a user.")

    return await code_description_to_code_part(query, acting_user, top_k)
|
||
|
||
|
||
async def code_description_to_code_part(
    query: str,
    user: User,
    top_k: int
) -> List[str]:
    """
    Maps a code description query to relevant code parts using a CodeGraph pipeline.

    The query is searched in the "code_summary_text" vector collection. For every
    hit, the matching node is looked up in a graph projected from the graph
    database; each neighbour of that node is inspected and the destination node
    of every 'contains' edge of the neighbour is collected.

    Args:
        query (str): The search query describing the code parts.
        user (User): The user performing the search (its id is used for telemetry and logs).
        top_k (int): Number of codegraph descriptions to match (the number of
            corresponding code parts can be higher).
    Returns:
        list: The unique code parts found, in no particular order (they are
            gathered in a set first). Empty when the vector search has no hits.
            NOTE(review): the annotation says List[str], but the elements are
            whatever get_destination_node() returns - confirm.
    Raises:
        ValueError: If query is not a non-empty string or top_k is not a positive integer.
        RuntimeError: If the engines cannot be initialized or the search fails.
    """
    if not isinstance(query, str) or not query:
        raise ValueError("The query must be a non-empty string.")
    # Type check comes first: comparing e.g. "2" <= 0 would raise TypeError
    # instead of the documented ValueError.
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError("top_k must be a positive integer.")

    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as init_error:
        logging.error("Failed to initialize engines: %s", init_error, exc_info=True)
        raise RuntimeError("System initialization error. Please try again later.") from init_error

    send_telemetry("code_description_to_code_part_search EXECUTION STARTED", user.id)
    logging.info("Search initiated by user %s with query: '%s' and top_k: %d", user.id, query, top_k)

    try:
        results = await vector_engine.search(
            "code_summary_text", query_text=query, limit=top_k
        )
        if not results:
            # Nothing matched: skip the graph projection entirely.
            logging.warning("No results found for query: '%s' by user: %s", query, user.id)
            return []

        memory_fragment = CogneeGraph()
        await memory_fragment.project_graph_from_db(
            graph_engine,
            node_properties_to_project=['id', 'type', 'text', 'source_code'],
            edge_properties_to_project=['relationship_name']
        )

        # A set removes duplicates when several hits lead to the same code part.
        code_pieces_to_return = set()

        for node in results:
            node_id = str(node.id)
            node_to_search_from = memory_fragment.get_node(node_id)

            if not node_to_search_from:
                logging.debug("Node %s not found in memory fragment graph", node_id)
                continue

            for code_file in node_to_search_from.get_skeleton_neighbours():
                for code_file_edge in code_file.get_skeleton_edges():
                    if code_file_edge.get_attribute('relationship_name') == 'contains':
                        code_pieces_to_return.add(code_file_edge.get_destination_node())

        logging.info("Search completed for user: %s, query: '%s'. Found %d code pieces.",
                     user.id, query, len(code_pieces_to_return))

        return list(code_pieces_to_return)

    except Exception as exec_error:
        logging.error(
            "Error during code description to code part search for user: %s, query: '%s'. Error: %s",
            user.id, query, exec_error, exc_info=True
        )
        send_telemetry("code_description_to_code_part_search EXECUTION FAILED", user.id)
        raise RuntimeError("An error occurred while processing your request.") from exec_error
|
||
|
||
if __name__ == "__main__":
    async def main():
        # Manual smoke run: one sample query with no explicit user, so the
        # search entry point falls back to the default user.
        sample_query = "I am looking for a class with blue eyes"
        try:
            code_parts = await code_description_to_code_part_search(sample_query, None)
            print("Retrieved Code Parts:", code_parts)
        except Exception as e:
            # Report any failure instead of letting the demo crash with a traceback.
            print(f"An error occurred: {e}")

    asyncio.run(main())
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
cognee/tests/unit/modules/retriever/test_description_to_codepart_search.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import pytest | ||
from unittest.mock import AsyncMock, patch | ||
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_code_description_to_code_part_no_results():
    """Test that code_description_to_code_part returns [] when the vector search has no hits."""

    mock_user = AsyncMock()
    mock_user.id = "user123"
    mock_vector_engine = AsyncMock()
    mock_vector_engine.search.return_value = []

    # send_telemetry is patched as well: the function under test calls it before
    # searching, and a unit test must not invoke the real telemetry function.
    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=mock_vector_engine), \
        patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()), \
        patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()), \
        patch("cognee.modules.retrieval.description_to_codepart_search.send_telemetry"):

        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
        result = await code_description_to_code_part("search query", mock_user, 2)

    assert result == []
    # The query and top_k must be forwarded to the code summary collection.
    mock_vector_engine.search.assert_awaited_once_with(
        "code_summary_text", query_text="search query", limit=2
    )
|
||
|
||
@pytest.mark.asyncio
async def test_code_description_to_code_part_invalid_query():
    """Test that code_description_to_code_part raises ValueError for an empty query."""

    # Imported outside pytest.raises so the context wraps only the call under test.
    from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part

    mock_user = AsyncMock()

    # `match` is a regular expression, so the literal trailing dot is escaped.
    with pytest.raises(ValueError, match=r"The query must be a non-empty string\."):
        await code_description_to_code_part("", mock_user, 2)
|
||
|
||
@pytest.mark.asyncio
async def test_code_description_to_code_part_invalid_top_k():
    """Test that code_description_to_code_part raises ValueError for a non-positive top_k."""

    # Imported outside pytest.raises so the context wraps only the call under test.
    from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part

    mock_user = AsyncMock()

    # `match` is a regular expression, so the literal trailing dot is escaped.
    with pytest.raises(ValueError, match=r"top_k must be a positive integer\."):
        await code_description_to_code_part("search query", mock_user, 0)
|
||
|
||
@pytest.mark.asyncio
async def test_code_description_to_code_part_initialization_error():
    """Test that a failing engine factory surfaces as the initialization RuntimeError."""

    # The vector engine factory raises; the graph engine factory would succeed.
    broken_vector_engine = patch(
        "cognee.modules.retrieval.description_to_codepart_search.get_vector_engine",
        side_effect=Exception("Engine init failed"),
    )
    working_graph_engine = patch(
        "cognee.modules.retrieval.description_to_codepart_search.get_graph_engine",
        return_value=AsyncMock(),
    )
    user_stub = AsyncMock()

    with broken_vector_engine, working_graph_engine:
        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part

        with pytest.raises(RuntimeError, match="System initialization error. Please try again later."):
            await code_description_to_code_part("search query", user_stub, 2)
|
||
|
||
@pytest.mark.asyncio
async def test_code_description_to_code_part_execution_error():
    """Test that code_description_to_code_part raises RuntimeError when the search itself fails."""

    mock_user = AsyncMock()
    mock_user.id = "user123"
    mock_vector_engine = AsyncMock()
    mock_vector_engine.search.side_effect = Exception("Execution error")

    # send_telemetry is patched so the test never invokes the real telemetry
    # function, and so the failure event can be asserted below.
    with patch("cognee.modules.retrieval.description_to_codepart_search.get_vector_engine", return_value=mock_vector_engine), \
        patch("cognee.modules.retrieval.description_to_codepart_search.get_graph_engine", return_value=AsyncMock()), \
        patch("cognee.modules.retrieval.description_to_codepart_search.CogneeGraph", return_value=AsyncMock()), \
        patch("cognee.modules.retrieval.description_to_codepart_search.send_telemetry") as mock_send_telemetry:

        from cognee.modules.retrieval.description_to_codepart_search import code_description_to_code_part
        with pytest.raises(RuntimeError, match="An error occurred while processing your request."):
            await code_description_to_code_part("search query", mock_user, 2)

    # The failure path must report the failed execution for this user.
    mock_send_telemetry.assert_any_call(
        "code_description_to_code_part_search EXECUTION FAILED", "user123"
    )