diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py
index 8e92d08e0..3d31b4000 100644
--- a/cognee/api/v1/cognify/code_graph_pipeline.py
+++ b/cognee/api/v1/cognify/code_graph_pipeline.py
@@ -3,6 +3,8 @@
 from pathlib import Path
 
 from cognee.base_config import get_base_config
+from cognee.infrastructure.databases.vector.embeddings import \
+    get_embedding_engine
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.modules.pipelines import run_tasks
 from cognee.modules.pipelines.tasks.Task import Task
@@ -15,8 +17,10 @@
 from cognee.tasks.repo_processor import (enrich_dependency_graph,
                                          expand_dependency_graph,
                                          get_data_list_for_user,
-                                         get_non_code_files,
+                                         get_non_py_files,
                                          get_repo_file_dependencies)
+from cognee.tasks.repo_processor.get_source_code_chunks import \
+    get_source_code_chunks
 from cognee.tasks.storage import add_data_points
 
 monitoring = get_base_config().monitoring_tool
@@ -28,6 +32,7 @@ logger = logging.getLogger("code_graph_pipeline")
 
 update_status_lock = asyncio.Lock()
 
+
 @observe
 async def run_code_graph_pipeline(repo_path, include_docs=True):
     import os
@@ -46,20 +51,23 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
     await cognee.prune.prune_system(metadata=True)
     await create_db_and_tables()
 
+    embedding_engine = get_embedding_engine()
+
     cognee_config = get_cognify_config()
     user = await get_default_user()
 
     tasks = [
         Task(get_repo_file_dependencies),
-        Task(enrich_dependency_graph, task_config={"batch_size": 50}),
+        Task(enrich_dependency_graph),
         Task(expand_dependency_graph, task_config={"batch_size": 50}),
+        Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
         Task(summarize_code, task_config={"batch_size": 50}),
         Task(add_data_points, task_config={"batch_size": 50}),
     ]
 
     if include_docs:
         non_code_tasks = [
-            Task(get_non_code_files, task_config={"batch_size": 50}),
+            Task(get_non_py_files, task_config={"batch_size": 50}),
             Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
             Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
             Task(classify_documents),
@@ -71,7 +79,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
                 task_config={"batch_size": 50}
             ),
         ]
-    
+
     if include_docs:
         async for result in run_tasks(non_code_tasks, repo_path):
             yield result
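Note: run_code_graph_pipeline is an async generator, so callers drain it with "async for". A minimal usage sketch, assuming a local checkout at a placeholder path:

    import asyncio

    from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline

    async def main():
        # "/path/to/repo" is a placeholder; point it at a real checkout.
        async for result in run_code_graph_pipeline("/path/to/repo", include_docs=False):
            print(result)

    asyncio.run(main())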
diff --git a/cognee/infrastructure/databases/exceptions/EmbeddingException.py b/cognee/infrastructure/databases/exceptions/EmbeddingException.py
new file mode 100644
index 000000000..ba7c70d80
--- /dev/null
+++ b/cognee/infrastructure/databases/exceptions/EmbeddingException.py
@@ -0,0 +1,3 @@
+class EmbeddingException(Exception):
+    """Custom exception for handling embedding-related errors."""
+    pass
\ No newline at end of file
diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index dce12b318..93f59cc77 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -5,17 +5,19 @@ import litellm
 import os
 
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
+from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
 
 litellm.set_verbose = False
 logger = logging.getLogger("LiteLLMEmbeddingEngine")
 
+
 class LiteLLMEmbeddingEngine(EmbeddingEngine):
     api_key: str
     endpoint: str
     api_version: str
     model: str
     dimensions: int
-    mock:bool
+    mock: bool
 
     def __init__(
         self,
@@ -33,7 +35,7 @@ def __init__(
         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
 
         if isinstance(enable_mocking, bool):
-            enable_mocking= str(enable_mocking).lower()
+            enable_mocking = str(enable_mocking).lower()
         self.mock = enable_mocking in ("true", "1", "yes")
 
     MAX_RETRIES = 5
@@ -43,7 +45,7 @@ async def embed_text(self, text: List[str]) -> List[List[float]]:
         async def exponential_backoff(attempt):
             wait_time = min(10 * (2 ** attempt), 60)  # Max 60 seconds
             await asyncio.sleep(wait_time)
-        
+
         try:
             if self.mock:
                 response = {
@@ -56,10 +58,10 @@ async def exponential_backoff(attempt):
             else:
                 response = await litellm.aembedding(
                     self.model,
-                    input = text,
-                    api_key = self.api_key,
-                    api_base = self.endpoint,
-                    api_version = self.api_version
+                    input=text,
+                    api_key=self.api_key,
+                    api_base=self.endpoint,
+                    api_version=self.api_version
                 )
 
             self.retry_count = 0
@@ -71,7 +73,7 @@ async def exponential_backoff(attempt):
             if len(text) == 1:
                 parts = [text]
             else:
-                parts = [text[0:math.ceil(len(text)/2)], text[math.ceil(len(text)/2):]]
+                parts = [text[0:math.ceil(len(text) / 2)], text[math.ceil(len(text) / 2):]]
 
             parts_futures = [self.embed_text(part) for part in parts]
             embeddings = await asyncio.gather(*parts_futures)
@@ -95,6 +97,9 @@ async def exponential_backoff(attempt):
 
             return await self.embed_text(text)
 
+        except (litellm.exceptions.BadRequestError, litellm.llms.OpenAI.openai.OpenAIError):
+            raise EmbeddingException("Failed to index data points.")
+
         except Exception as error:
             logger.error("Error embedding text: %s", str(error))
             raise error
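Two behaviors are worth calling out here: the retry helper waits min(10 * 2**attempt, 60) seconds, i.e. 10, 20, 40, then a flat 60; and unrecoverable provider errors (bad requests, OpenAI errors) are now re-raised as EmbeddingException so callers can handle them uniformly. A minimal sketch of a caller, assuming an already-configured engine instance:

    from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException

    # The backoff schedule used above: capped exponential growth.
    for attempt in range(5):
        print(min(10 * (2 ** attempt), 60))  # 10, 20, 40, 60, 60

    async def embed_safely(engine, texts: list[str]):
        try:
            return await engine.embed_text(texts)
        except EmbeddingException as error:
            # Unrecoverable provider errors land here instead of crashing the pipeline.
            print(f"Embedding failed: {error}")
            return None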
diff --git a/cognee/shared/CodeGraphEntities.py b/cognee/shared/CodeGraphEntities.py
index 23b8879c2..27289493d 100644
--- a/cognee/shared/CodeGraphEntities.py
+++ b/cognee/shared/CodeGraphEntities.py
@@ -1,5 +1,4 @@
 from typing import List, Optional
-
 from cognee.infrastructure.engine import DataPoint
@@ -7,7 +6,7 @@ class Repository(DataPoint):
     __tablename__ = "Repository"
     path: str
     _metadata: dict = {
-        "index_fields": ["source_code"],
+        "index_fields": [],
         "type": "Repository"
     }
@@ -19,29 +18,31 @@ class CodeFile(DataPoint):
     depends_on: Optional[List["CodeFile"]] = None
     depends_directly_on: Optional[List["CodeFile"]] = None
     contains: Optional[List["CodePart"]] = None
-
     _metadata: dict = {
-        "index_fields": ["source_code"],
+        "index_fields": [],
         "type": "CodeFile"
     }
 
 class CodePart(DataPoint):
     __tablename__ = "codepart"
-    # part_of: Optional[CodeFile]
-    source_code: str
-
+    # part_of: Optional[CodeFile] = None
+    source_code: Optional[str] = None
     _metadata: dict = {
-        "index_fields": ["source_code"],
+        "index_fields": [],
         "type": "CodePart"
     }
 
-class CodeRelationship(DataPoint):
-    source_id: str
-    target_id: str
-    relation: str  # depends on or depends directly
+class SourceCodeChunk(DataPoint):
+    __tablename__ = "sourcecodechunk"
+    code_chunk_of: Optional[CodePart] = None
+    source_code: Optional[str] = None
+    previous_chunk: Optional["SourceCodeChunk"] = None
+
     _metadata: dict = {
-        "type": "CodeRelationship"
+        "index_fields": ["source_code"],
+        "type": "SourceCodeChunk"
     }
 
 CodeFile.model_rebuild()
 CodePart.model_rebuild()
+SourceCodeChunk.model_rebuild()
\ No newline at end of file
diff --git a/cognee/shared/data_models.py b/cognee/shared/data_models.py
index dec53cfcb..2a8bc8c91 100644
--- a/cognee/shared/data_models.py
+++ b/cognee/shared/data_models.py
@@ -210,7 +210,6 @@ class SummarizedClass(BaseModel):
     decorators: Optional[List[str]] = None
 
 class SummarizedCode(BaseModel):
-    file_name: str
     high_level_summary: str
     key_features: List[str]
     imports: List[str] = []
diff --git a/cognee/tasks/repo_processor/get_repo_file_dependencies.py b/cognee/tasks/repo_processor/get_repo_file_dependencies.py
index 221af6cf6..b54c1f152 100644
--- a/cognee/tasks/repo_processor/get_repo_file_dependencies.py
+++ b/cognee/tasks/repo_processor/get_repo_file_dependencies.py
@@ -71,7 +71,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
         path = repo_path,
     )
 
-    yield repo
+    yield [repo]
 
     with ProcessPoolExecutor(max_workers = 12) as executor:
         loop = asyncio.get_event_loop()
@@ -90,10 +90,11 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
 
         results = await asyncio.gather(*tasks)
 
+        code_files = []
         for (file_path, metadata), dependencies in zip(py_files_dict.items(), results):
             source_code = metadata.get("source_code")
 
-            yield CodeFile(
+            code_files.append(CodeFile(
                 id = uuid5(NAMESPACE_OID, file_path),
                 source_code = source_code,
                 extracted_id = file_path,
@@ -106,4 +107,6 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, None]:
                     source_code = py_files_dict.get(dependency, {}).get("source_code"),
                 ) for dependency in dependencies
             ] if dependencies else None,
-            )
+            ))
+
+        yield code_files
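The data-model change above moves embedding from whole files and parts (their index_fields are now empty) down to SourceCodeChunk, whose previous_chunk reference forms a singly linked list of overlapping chunks. A minimal sketch of the resulting shape; the ids and source strings are illustrative only:

    from uuid import NAMESPACE_OID, uuid5

    from cognee.shared.CodeGraphEntities import CodePart, SourceCodeChunk

    part = CodePart(id=uuid5(NAMESPACE_OID, "example.py:0"), source_code="def f():\n    return 1\n")

    first = SourceCodeChunk(
        id=uuid5(NAMESPACE_OID, "def f():"),
        code_chunk_of=part,
        source_code="def f():",
        previous_chunk=None,  # head of the chain
    )
    second = SourceCodeChunk(
        id=uuid5(NAMESPACE_OID, "    return 1"),
        code_chunk_of=part,
        source_code="    return 1",
        previous_chunk=first,  # each chunk points back at its predecessor
    )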
diff --git a/cognee/tasks/repo_processor/get_source_code_chunks.py b/cognee/tasks/repo_processor/get_source_code_chunks.py
new file mode 100644
index 000000000..4d0ce3200
--- /dev/null
+++ b/cognee/tasks/repo_processor/get_source_code_chunks.py
@@ -0,0 +1,164 @@
+import logging
+from typing import AsyncGenerator, Generator
+from uuid import NAMESPACE_OID, uuid5
+
+import parso
+import tiktoken
+
+from cognee.infrastructure.engine import DataPoint
+from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
+
+logger = logging.getLogger("task:get_source_code_chunks")
+
+
+def _count_tokens(tokenizer: tiktoken.Encoding, source_code: str) -> int:
+    return len(tokenizer.encode(source_code))
+
+
+def _get_naive_subchunk_token_counts(
+        tokenizer: tiktoken.Encoding, source_code: str, max_subchunk_tokens: int = 8000
+) -> list[tuple[str, int]]:
+    """Splits source code into subchunks of up to max_subchunk_tokens and counts tokens."""
+
+    token_ids = tokenizer.encode(source_code)
+    subchunk_token_counts = []
+
+    for start_idx in range(0, len(token_ids), max_subchunk_tokens):
+        subchunk_token_ids = token_ids[start_idx: start_idx + max_subchunk_tokens]
+        token_count = len(subchunk_token_ids)
+        subchunk = ''.join(
+            tokenizer.decode_single_token_bytes(token_id).decode('utf-8', errors='replace')
+            for token_id in subchunk_token_ids
+        )
+        subchunk_token_counts.append((subchunk, token_count))
+
+    return subchunk_token_counts
+
+
+def _get_subchunk_token_counts(
+        tokenizer: tiktoken.Encoding,
+        source_code: str,
+        max_subchunk_tokens: int = 8000,
+        depth: int = 0,
+        max_depth: int = 100
+) -> list[tuple[str, int]]:
+    """Splits source code into subchunks and counts tokens for each subchunk."""
+    if depth > max_depth:
+        return _get_naive_subchunk_token_counts(tokenizer, source_code, max_subchunk_tokens)
+
+    try:
+        module = parso.parse(source_code)
+    except Exception as e:
+        logger.error(f"Error parsing source code: {e}")
+        return []
+
+    if not module.children:
+        logger.warning("Parsed module has no children (empty or invalid source code).")
+        return []
+
+    # Handle cases with only one real child and an EndMarker to prevent infinite recursion.
+    if len(module.children) <= 2:
+        module = module.children[0]
+
+    subchunk_token_counts = []
+    for child in module.children:
+        subchunk = child.get_code()
+        token_count = _count_tokens(tokenizer, subchunk)
+
+        if token_count == 0:
+            continue
+
+        if token_count <= max_subchunk_tokens:
+            subchunk_token_counts.append((subchunk, token_count))
+            continue
+
+        if child.type == 'string':
+            subchunk_token_counts.extend(_get_naive_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens))
+            continue
+
+        subchunk_token_counts.extend(
+            _get_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens, depth=depth + 1, max_depth=max_depth)
+        )
+
+    return subchunk_token_counts
+
+
+def _get_chunk_source_code(
+        code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int
+) -> tuple[list[tuple[str, int]], str]:
+    """Generates a chunk of source code from tokenized subchunks with overlap handling."""
+    current_count = 0
+    cumulative_counts = []
+    current_source_code = ''
+
+    for i, (child_code, token_count) in enumerate(code_token_counts):
+        current_count += token_count
+        cumulative_counts.append(current_count)
+        if current_count > max_tokens:
+            break
+        current_source_code += f"\n{child_code}"
+
+    if current_count <= max_tokens:
+        return [], current_source_code.strip()
+
+    cutoff = 1
+    for i, cum_count in enumerate(cumulative_counts):
+        if cum_count > (1 - overlap) * max_tokens:
+            break
+        cutoff = i
+
+    return code_token_counts[cutoff:], current_source_code.strip()
+
+
+def get_source_code_chunks_from_code_part(
+        code_file_part: CodePart,
+        max_tokens: int = 8192,
+        overlap: float = 0.25,
+        granularity: float = 0.1,
+        model_name: str = "text-embedding-3-large"
+) -> Generator[SourceCodeChunk, None, None]:
+    """Yields source code chunks from a CodePart object, with configurable token limits and overlap."""
+    if not code_file_part.source_code:
+        logger.error(f"No source code in CodePart {code_file_part.id}")
+        return
+
+    tokenizer = tiktoken.encoding_for_model(model_name)
+    max_subchunk_tokens = max(1, int(granularity * max_tokens))
+    subchunk_token_counts = _get_subchunk_token_counts(tokenizer, code_file_part.source_code, max_subchunk_tokens)
+
+    previous_chunk = None
+    while subchunk_token_counts:
+        subchunk_token_counts, chunk_source_code = _get_chunk_source_code(subchunk_token_counts, overlap, max_tokens)
+        if not chunk_source_code:
+            continue
+        current_chunk = SourceCodeChunk(
+            id=uuid5(NAMESPACE_OID, chunk_source_code),
+            code_chunk_of=code_file_part,
+            source_code=chunk_source_code,
+            previous_chunk=previous_chunk
+        )
+        yield current_chunk
+        previous_chunk = current_chunk
+
+
+async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="text-embedding-3-large") -> \
+        AsyncGenerator[list[DataPoint], None]:
+    """Processes code graph data points, creating SourceCodeChunk data points."""
+    # TODO: Add support for other embedding models, with max_token mapping
+    for data_point in data_points:
+        try:
+            yield data_point
+            if not isinstance(data_point, CodeFile):
+                continue
+            if not data_point.contains:
+                logger.warning(f"CodeFile {data_point.id} contains no code parts")
+                continue
+            for code_part in data_point.contains:
+                try:
+                    yield code_part
+                    for source_code_chunk in get_source_code_chunks_from_code_part(code_part, model_name=embedding_model):
+                        yield source_code_chunk
+                except Exception as e:
+                    logger.error(f"Error processing code part: {e}")
+        except Exception as e:
+            logger.error(f"Error processing data point: {e}")
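A minimal sketch of driving the chunker directly; the CodePart content and the small max_tokens are illustrative (in the pipeline, parts arrive from get_repo_file_dependencies and max_tokens defaults to 8192):

    from uuid import NAMESPACE_OID, uuid5

    from cognee.shared.CodeGraphEntities import CodePart
    from cognee.tasks.repo_processor.get_source_code_chunks import \
        get_source_code_chunks_from_code_part

    part = CodePart(
        id=uuid5(NAMESPACE_OID, "demo"),
        source_code="import os\n\ndef listdir(path):\n    return os.listdir(path)\n",
    )

    # A small max_tokens forces small, overlapping chunks for demonstration.
    for chunk in get_source_code_chunks_from_code_part(part, max_tokens=50, overlap=0.25):
        print(chunk.id, repr(chunk.source_code[:40]))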
a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py
index 857e4d777..12af2d2ef 100644
--- a/cognee/tasks/storage/index_data_points.py
+++ b/cognee/tasks/storage/index_data_points.py
@@ -1,6 +1,10 @@
+import logging
+
+from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.engine import DataPoint
 
+logger = logging.getLogger("index_data_points")
 
 async def index_data_points(data_points: list[DataPoint]):
     created_indexes = {}
@@ -30,7 +34,10 @@ async def index_data_points(data_points: list[DataPoint]):
 
     for index_name, indexable_points in index_points.items():
         index_name, field_name = index_name.split(".")
-        await vector_engine.index_data_points(index_name, field_name, indexable_points)
+        try:
+            await vector_engine.index_data_points(index_name, field_name, indexable_points)
+        except EmbeddingException as e:
+            logger.warning(f"Failed to index data points for {index_name}.{field_name}: {e}")
 
     return data_points
diff --git a/cognee/tasks/summarization/models.py b/cognee/tasks/summarization/models.py
index add448155..5b0345015 100644
--- a/cognee/tasks/summarization/models.py
+++ b/cognee/tasks/summarization/models.py
@@ -1,6 +1,8 @@
+from typing import Union
+
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.chunking.models import DocumentChunk
-from cognee.shared.CodeGraphEntities import CodeFile
+from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk
 
 
 class TextSummary(DataPoint):
@@ -17,7 +19,7 @@ class TextSummary(DataPoint):
 class CodeSummary(DataPoint):
     __tablename__ = "code_summary"
     text: str
-    made_from: CodeFile
+    summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
 
     _metadata: dict = {
         "index_fields": ["text"],
diff --git a/cognee/tasks/summarization/summarize_code.py b/cognee/tasks/summarization/summarize_code.py
index b116e57a9..9efc5b6ca 100644
--- a/cognee/tasks/summarization/summarize_code.py
+++ b/cognee/tasks/summarization/summarize_code.py
@@ -1,10 +1,10 @@
 import asyncio
 from typing import AsyncGenerator, Union
 from uuid import uuid5
-from typing import Type
 
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.data.extraction.extract_summary import extract_code_summary
+
 from .models import CodeSummary
 
 
@@ -21,7 +21,7 @@ async def summarize_code(
     )
 
     file_summaries_map = {
-        code_data_point.extracted_id: str(file_summary)
+        code_data_point.id: str(file_summary)
         for code_data_point, file_summary in zip(code_data_points, file_summaries)
     }
 
@@ -35,6 +35,6 @@ async def summarize_code(
 
         yield CodeSummary(
             id=uuid5(node.id, "CodeSummary"),
-            made_from=node,
-            text=file_summaries_map[node.extracted_id],
+            summarizes=node,
+            text=file_summaries_map[node.id],
         )
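The switch from extracted_id to id matters because CodePart and SourceCodeChunk carry no extracted_id; every DataPoint has an id, so keying the summary map by id works for all three summarizes targets. Summary ids also stay deterministic, as this small check illustrates:

    from uuid import NAMESPACE_OID, uuid5

    node_id = uuid5(NAMESPACE_OID, "cognee/api/v1/cognify/code_graph_pipeline.py")

    # uuid5 is deterministic: the same node always yields the same summary id.
    assert uuid5(node_id, "CodeSummary") == uuid5(node_id, "CodeSummary")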
diff --git a/examples/python/code_graph_example.py b/examples/python/code_graph_example.py
index c0b91972b..44ab33aad 100644
--- a/examples/python/code_graph_example.py
+++ b/examples/python/code_graph_example.py
@@ -11,6 +11,6 @@ async def main(repo_path, include_docs):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
-    parser.add_argument("--include_docs", type=bool, default=True, help="Whether or not to process non-code files")
+    parser.add_argument("--include_docs", type=lambda x: x.lower() in ("true", "1"), default=True, help="Whether or not to process non-code files")
     args = parser.parse_args()
     asyncio.run(main(args.repo_path, args.include_docs))
\ No newline at end of file
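The argparse change fixes a classic pitfall: type=bool calls bool() on the raw argument string, and any non-empty string is truthy, so "--include_docs False" still parsed as True. A quick demonstration of the difference:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--buggy", type=bool, default=True)
    parser.add_argument("--fixed", type=lambda x: x.lower() in ("true", "1"), default=True)

    args = parser.parse_args(["--buggy", "False", "--fixed", "False"])
    print(args.buggy)  # True -- bool("False") is truthy
    print(args.fixed)  # False -- the lambda actually parses the string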