diff --git a/cognee/api/v1/cognify/code_graph_pipeline.py b/cognee/api/v1/cognify/code_graph_pipeline.py
index 405cb0b4..2d077f39 100644
--- a/cognee/api/v1/cognify/code_graph_pipeline.py
+++ b/cognee/api/v1/cognify/code_graph_pipeline.py
@@ -3,7 +3,6 @@
 from pathlib import Path

 from cognee.base_config import get_base_config
-from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.modules.pipelines import run_tasks
 from cognee.modules.pipelines.tasks.Task import Task
@@ -54,8 +53,6 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
     await cognee.prune.prune_system(metadata=True)
     await create_db_and_tables()

-    embedding_engine = get_embedding_engine()
-
     cognee_config = get_cognify_config()
     user = await get_default_user()

@@ -63,11 +60,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
         Task(get_repo_file_dependencies),
         Task(enrich_dependency_graph),
         Task(expand_dependency_graph, task_config={"batch_size": 50}),
-        Task(
-            get_source_code_chunks,
-            embedding_model=embedding_engine.model,
-            task_config={"batch_size": 50},
-        ),
+        Task(get_source_code_chunks, task_config={"batch_size": 50}),
         Task(summarize_code, task_config={"batch_size": 50}),
         Task(add_data_points, task_config={"batch_size": 50}),
     ]
@@ -78,7 +71,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
             Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
             Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
             Task(classify_documents),
-            Task(extract_chunks_from_documents),
+            Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
             Task(
                 extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
             ),
diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py
index 7bb8a1c1..78c02b9c 100644
--- a/cognee/modules/chunking/TextChunker.py
+++ b/cognee/modules/chunking/TextChunker.py
@@ -1,8 +1,10 @@
-from uuid import uuid5, NAMESPACE_OID
+from typing import Optional
+from uuid import NAMESPACE_OID, uuid5

-from .models.DocumentChunk import DocumentChunk
 from cognee.tasks.chunks import chunk_by_paragraph

+from .models.DocumentChunk import DocumentChunk
+

 class TextChunker:
     document = None
@@ -10,23 +12,36 @@ class TextChunker:

     chunk_index = 0
     chunk_size = 0
+    token_count = 0

-    def __init__(self, document, get_text: callable, chunk_size: int = 1024):
+    def __init__(
+        self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
+    ):
         self.document = document
         self.max_chunk_size = chunk_size
         self.get_text = get_text
+        self.max_tokens = max_tokens if max_tokens else float("inf")
+
+    def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
+        word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
+        token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
+        return word_count_fits and token_count_fits

     def read(self):
         paragraph_chunks = []
         for content_text in self.get_text():
             for chunk_data in chunk_by_paragraph(
                 content_text,
+                self.max_tokens,
                 self.max_chunk_size,
                 batch_paragraphs=True,
             ):
-                if self.chunk_size + chunk_data["word_count"] <= self.max_chunk_size:
+                if self.check_word_count_and_token_count(
+                    self.chunk_size, self.token_count, chunk_data
+                ):
                     paragraph_chunks.append(chunk_data)
                     self.chunk_size += chunk_data["word_count"]
+                    self.token_count += chunk_data["token_count"]
                 else:
                     if len(paragraph_chunks) == 0:
                         yield DocumentChunk(
@@ -66,6 +81,7 @@ def read(self):
                             print(e)
                     paragraph_chunks = [chunk_data]
                     self.chunk_size = chunk_data["word_count"]
+                    self.token_count = chunk_data["token_count"]

             self.chunk_index += 1

diff --git a/cognee/modules/cognify/config.py b/cognee/modules/cognify/config.py
index d40410bf..dd94d8b4 100644
--- a/cognee/modules/cognify/config.py
+++ b/cognee/modules/cognify/config.py
@@ -1,12 +1,14 @@
 from functools import lru_cache
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from cognee.shared.data_models import DefaultContentPrediction, SummarizedContent
+from typing import Optional
+import os


 class CognifyConfig(BaseSettings):
     classification_model: object = DefaultContentPrediction
     summarization_model: object = SummarizedContent
-
+    max_tokens: Optional[int] = os.getenv("MAX_TOKENS")
     model_config = SettingsConfigDict(env_file=".env", extra="allow")

     def to_dict(self) -> dict:
diff --git a/cognee/modules/data/processing/document_types/AudioDocument.py b/cognee/modules/data/processing/document_types/AudioDocument.py
index a33d4e7f..b7d2476b 100644
--- a/cognee/modules/data/processing/document_types/AudioDocument.py
+++ b/cognee/modules/data/processing/document_types/AudioDocument.py
@@ -1,6 +1,9 @@
+from typing import Optional
+
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from .Document import Document
+
 from .ChunkerMapping import ChunkerConfig
+from .Document import Document


 class AudioDocument(Document):
@@ -10,12 +13,14 @@ def create_transcript(self):
         result = get_llm_client().create_transcript(self.raw_data_location)
         return result.text

-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
         # Transcribe the audio file
         text = self.create_transcript()

         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
+        chunker = chunker_func(
+            self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
+        )

         yield from chunker.read()

diff --git a/cognee/modules/data/processing/document_types/Document.py b/cognee/modules/data/processing/document_types/Document.py
index 08380e80..7ecdf289 100644
--- a/cognee/modules/data/processing/document_types/Document.py
+++ b/cognee/modules/data/processing/document_types/Document.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from uuid import UUID

 from cognee.infrastructure.engine import DataPoint
@@ -10,5 +11,5 @@ class Document(DataPoint):
     mime_type: str
     _metadata: dict = {"index_fields": ["name"], "type": "Document"}

-    def read(self, chunk_size: int, chunker=str) -> str:
+    def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
         pass
diff --git a/cognee/modules/data/processing/document_types/ImageDocument.py b/cognee/modules/data/processing/document_types/ImageDocument.py
index 424cd059..c055b825 100644
--- a/cognee/modules/data/processing/document_types/ImageDocument.py
+++ b/cognee/modules/data/processing/document_types/ImageDocument.py
@@ -1,6 +1,9 @@
+from typing import Optional
+
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
-from .Document import Document
+
 from .ChunkerMapping import ChunkerConfig
+from .Document import Document


 class ImageDocument(Document):
@@ -10,11 +13,13 @@ def transcribe_image(self):
         result = get_llm_client().transcribe_image(self.raw_data_location)
         return result.choices[0].message.content

-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
         # Transcribe the image file
         text = self.transcribe_image()

         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(self, chunk_size=chunk_size, get_text=lambda: [text])
+        chunker = chunker_func(
+            self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
+        )

         yield from chunker.read()
diff --git a/cognee/modules/data/processing/document_types/PdfDocument.py b/cognee/modules/data/processing/document_types/PdfDocument.py
index 684fb428..768f9126 100644
--- a/cognee/modules/data/processing/document_types/PdfDocument.py
+++ b/cognee/modules/data/processing/document_types/PdfDocument.py
@@ -1,12 +1,15 @@
+from typing import Optional
+
 from pypdf import PdfReader
-from .Document import Document
+
 from .ChunkerMapping import ChunkerConfig
+from .Document import Document


 class PdfDocument(Document):
     type: str = "pdf"

-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
         file = PdfReader(self.raw_data_location)

         def get_text():
@@ -15,7 +18,9 @@ def get_text():
             yield page_text

         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
+        chunker = chunker_func(
+            self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
+        )

         yield from chunker.read()

diff --git a/cognee/modules/data/processing/document_types/TextDocument.py b/cognee/modules/data/processing/document_types/TextDocument.py
index f993ff22..b62ccd56 100644
--- a/cognee/modules/data/processing/document_types/TextDocument.py
+++ b/cognee/modules/data/processing/document_types/TextDocument.py
@@ -1,11 +1,13 @@
-from .Document import Document
+from typing import Optional
+
 from .ChunkerMapping import ChunkerConfig
+from .Document import Document


 class TextDocument(Document):
     type: str = "text"

-    def read(self, chunk_size: int, chunker: str):
+    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
         def get_text():
             with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
                 while True:
@@ -18,6 +20,8 @@ def get_text():

         chunker_func = ChunkerConfig.get_chunker(chunker)
-        chunker = chunker_func(self, chunk_size=chunk_size, get_text=get_text)
+        chunker = chunker_func(
+            self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
+        )

         yield from chunker.read()

diff --git a/cognee/modules/data/processing/document_types/UnstructuredDocument.py b/cognee/modules/data/processing/document_types/UnstructuredDocument.py
index cd5c72e3..1c291d0d 100644
--- a/cognee/modules/data/processing/document_types/UnstructuredDocument.py
+++ b/cognee/modules/data/processing/document_types/UnstructuredDocument.py
@@ -1,14 +1,16 @@
 from io import StringIO
+from typing import Optional

 from cognee.modules.chunking.TextChunker import TextChunker
-from .Document import Document
 from cognee.modules.data.exceptions import UnstructuredLibraryImportError

+from .Document import Document
+

 class UnstructuredDocument(Document):
     type: str = "unstructured"

-    def read(self, chunk_size: int):
+    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
         def get_text():
             try:
                 from unstructured.partition.auto import partition
@@ -27,6 +29,6 @@ def get_text():

             yield text

-        chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text)
+        chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)

         yield from chunker.read()
diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py
index 5c95e97b..24d56607 100644
--- a/cognee/tasks/chunks/chunk_by_paragraph.py
+++ b/cognee/tasks/chunks/chunk_by_paragraph.py
@@ -1,10 +1,18 @@
-from uuid import uuid5, NAMESPACE_OID
-from typing import Dict, Any, Iterator
+from typing import Any, Dict, Iterator, Optional, Union
+from uuid import NAMESPACE_OID, uuid5
+
+import tiktoken
+
+from cognee.infrastructure.databases.vector import get_vector_engine
+
 from .chunk_by_sentence import chunk_by_sentence


 def chunk_by_paragraph(
-    data: str, paragraph_length: int = 1024, batch_paragraphs: bool = True
+    data: str,
+    max_tokens: Optional[Union[int, float]] = None,
+    paragraph_length: int = 1024,
+    batch_paragraphs: bool = True,
 ) -> Iterator[Dict[str, Any]]:
     """
     Chunks text by paragraph while preserving exact text reconstruction capability.
@@ -15,16 +23,31 @@
     chunk_index = 0
     paragraph_ids = []
     last_cut_type = None
+    current_token_count = 0
+    if not max_tokens:
+        max_tokens = float("inf")
+
+    vector_engine = get_vector_engine()
+    embedding_model = vector_engine.embedding_engine.model
+    embedding_model = embedding_model.split("/")[-1]

     for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(
         data, maximum_length=paragraph_length
     ):
         # Check if this sentence would exceed length limit
-        if current_word_count > 0 and current_word_count + word_count > paragraph_length:
+
+        tokenizer = tiktoken.encoding_for_model(embedding_model)
+        token_count = len(tokenizer.encode(sentence))
+
+        if current_word_count > 0 and (
+            current_word_count + word_count > paragraph_length
+            or current_token_count + token_count > max_tokens
+        ):
             # Yield current chunk
             chunk_dict = {
                 "text": current_chunk,
                 "word_count": current_word_count,
+                "token_count": current_token_count,
                 "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
                 "paragraph_ids": paragraph_ids,
                 "chunk_index": chunk_index,
@@ -37,11 +60,13 @@
             paragraph_ids = []
             current_chunk = ""
             current_word_count = 0
+            current_token_count = 0
             chunk_index += 1

         paragraph_ids.append(paragraph_id)
         current_chunk += sentence
         current_word_count += word_count
+        current_token_count += token_count

         # Handle end of paragraph
         if end_type in ("paragraph_end", "sentence_cut") and not batch_paragraphs:
@@ -49,6 +74,7 @@
             chunk_dict = {
                 "text": current_chunk,
                 "word_count": current_word_count,
+                "token_count": current_token_count,
                 "paragraph_ids": paragraph_ids,
                 "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
                 "chunk_index": chunk_index,
@@ -58,6 +84,7 @@
             paragraph_ids = []
             current_chunk = ""
             current_word_count = 0
+            current_token_count = 0
             chunk_index += 1

         last_cut_type = end_type
@@ -67,6 +94,7 @@
         chunk_dict = {
            "text": current_chunk,
            "word_count": current_word_count,
+           "token_count": current_token_count,
            "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
            "paragraph_ids": paragraph_ids,
            "chunk_index": chunk_index,
diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py
index 437d2a3e..5ce22400 100644
--- a/cognee/tasks/documents/extract_chunks_from_documents.py
+++ b/cognee/tasks/documents/extract_chunks_from_documents.py
@@ -1,9 +1,16 @@
+from typing import Optional
+
 from cognee.modules.data.processing.document_types.Document import Document


 async def extract_chunks_from_documents(
-    documents: list[Document], chunk_size: int = 1024, chunker="text_chunker"
+    documents: list[Document],
+    chunk_size: int = 1024,
+    chunker="text_chunker",
+    max_tokens: Optional[int] = None,
 ):
     for document in documents:
-        for document_chunk in document.read(chunk_size=chunk_size, chunker=chunker):
+        for document_chunk in document.read(
+            chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
+        ):
             yield document_chunk
diff --git a/cognee/tasks/repo_processor/get_non_code_files.py b/cognee/tasks/repo_processor/get_non_code_files.py
index 9c69afd0..12f32e84 100644
--- a/cognee/tasks/repo_processor/get_non_code_files.py
+++ b/cognee/tasks/repo_processor/get_non_code_files.py
@@ -29,8 +29,105 @@ async def get_non_py_files(repo_path):
         "*.egg-info",
     }

+    ALLOWED_EXTENSIONS = {
+        ".txt",
+        ".md",
+        ".csv",
+        ".json",
+        ".xml",
+        ".yaml",
+        ".yml",
+        ".html",
+        ".css",
+        ".js",
+        ".ts",
+        ".jsx",
+        ".tsx",
+        ".sql",
+        ".log",
+        ".ini",
+        ".toml",
+        ".properties",
+        ".sh",
+        ".bash",
+        ".dockerfile",
+        ".gitignore",
+        ".gitattributes",
+        ".makefile",
+        ".pyproject",
+        ".requirements",
+        ".env",
+        ".pdf",
+        ".doc",
+        ".docx",
+        ".dot",
+        ".dotx",
+        ".rtf",
+        ".wps",
+        ".wpd",
+        ".odt",
+        ".ott",
+        ".ottx",
+        ".txt",
+        ".wp",
+        ".sdw",
+        ".sdx",
+        ".docm",
+        ".dotm",
+        # Additional extensions for other programming languages
+        ".java",
+        ".c",
+        ".cpp",
+        ".h",
+        ".cs",
+        ".go",
+        ".php",
+        ".rb",
+        ".swift",
+        ".pl",
+        ".lua",
+        ".rs",
+        ".scala",
+        ".kt",
+        ".sh",
+        ".sql",
+        ".v",
+        ".asm",
+        ".pas",
+        ".d",
+        ".ml",
+        ".clj",
+        ".cljs",
+        ".erl",
+        ".ex",
+        ".exs",
+        ".f",
+        ".fs",
+        ".r",
+        ".pyi",
+        ".pdb",
+        ".ipynb",
+        ".rmd",
+        ".cabal",
+        ".hs",
+        ".nim",
+        ".vhdl",
+        ".verilog",
+        ".svelte",
+        ".html",
+        ".css",
+        ".scss",
+        ".less",
+        ".json5",
+        ".yaml",
+        ".yml",
+    }
+
     def should_process(path):
-        return not any(pattern in path for pattern in IGNORED_PATTERNS)
+        _, ext = os.path.splitext(path)
+        return ext in ALLOWED_EXTENSIONS and not any(
+            pattern in path for pattern in IGNORED_PATTERNS
+        )

     non_py_files_paths = [
         os.path.join(root, file)
diff --git a/cognee/tasks/repo_processor/get_source_code_chunks.py b/cognee/tasks/repo_processor/get_source_code_chunks.py
index 980a8653..82fa46cf 100644
--- a/cognee/tasks/repo_processor/get_source_code_chunks.py
+++ b/cognee/tasks/repo_processor/get_source_code_chunks.py
@@ -5,6 +5,7 @@
 import parso
 import tiktoken

+from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.engine import DataPoint
 from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk

@@ -126,6 +127,9 @@ def get_source_code_chunks_from_code_part(
         logger.error(f"No source code in CodeFile {code_file_part.id}")
         return

+    vector_engine = get_vector_engine()
+    embedding_model = vector_engine.embedding_engine.model
+    model_name = embedding_model.split("/")[-1]
     tokenizer = tiktoken.encoding_for_model(model_name)
     max_subchunk_tokens = max(1, int(granularity * max_tokens))
     subchunk_token_counts = _get_subchunk_token_counts(
@@ -150,7 +154,7 @@


 async def get_source_code_chunks(
-    data_points: list[DataPoint], embedding_model="text-embedding-3-large"
+    data_points: list[DataPoint],
 ) -> AsyncGenerator[list[DataPoint], None]:
     """Processes code graph datapoints, create SourceCodeChink datapoints."""
     # TODO: Add support for other embedding models, with max_token mapping
@@ -165,9 +169,7 @@ async def get_source_code_chunks(
             for code_part in data_point.contains:
                 try:
                     yield code_part
-                    for source_code_chunk in get_source_code_chunks_from_code_part(
-                        code_part, model_name=embedding_model
-                    ):
+                    for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
                         yield source_code_chunk
                 except Exception as e:
                     logger.error(f"Error processing code part: {e}")
diff --git a/cognee/tests/integration/documents/UnstructuredDocument_test.py b/cognee/tests/integration/documents/UnstructuredDocument_test.py
index 03b8deb4..e0278de8 100644
--- a/cognee/tests/integration/documents/UnstructuredDocument_test.py
+++ b/cognee/tests/integration/documents/UnstructuredDocument_test.py
@@ -68,7 +68,7 @@ def test_UnstructuredDocument():
     )

     # Test PPTX
-    for paragraph_data in pptx_document.read(chunk_size=1024):
+    for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
         assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
         assert (
@@ -76,7 +76,7 @@
         ), f" sentence_cut != {paragraph_data.cut_type = }"

     # Test DOCX
-    for paragraph_data in docx_document.read(chunk_size=1024):
+    for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
         assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
         assert (
@@ -84,7 +84,7 @@
         ), f" sentence_end != {paragraph_data.cut_type = }"

     # TEST CSV
-    for paragraph_data in csv_document.read(chunk_size=1024):
+    for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
         assert (
             "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text
@@ -94,7 +94,7 @@
         ), f" sentence_cut != {paragraph_data.cut_type = }"

     # Test XLSX
-    for paragraph_data in xlsx_document.read(chunk_size=1024):
+    for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
         assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
         assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
         assert (
diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py
index 728b5cda..53098fc6 100644
--- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py
+++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_2_test.py
@@ -27,7 +27,11 @@ def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para
     list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
 )
 def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
-    chunks = list(chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs))
+    chunks = list(
+        chunk_by_paragraph(
+            data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
+        )
+    )

     chunk_lengths = np.array([len(list(chunk_by_word(chunk["text"]))) for chunk in chunks])

@@ -42,7 +46,9 @@ def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
     list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
 )
 def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs):
-    chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
+    chunks = chunk_by_paragraph(
+        data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
+    )
     chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
     assert np.all(
         chunk_indices == np.arange(len(chunk_indices))
diff --git a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py
index 3ddc6f4f..e7d9a54b 100644
--- a/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py
+++ b/cognee/tests/unit/processing/chunks/chunk_by_paragraph_test.py
@@ -49,7 +49,9 @@ def run_chunking_test(test_text, expected_chunks):
     chunks = []

-    for chunk_data in chunk_by_paragraph(test_text, 12, batch_paragraphs=False):
+    for chunk_data in chunk_by_paragraph(
+        data=test_text, paragraph_length=12, batch_paragraphs=False
+    ):
         chunks.append(chunk_data)

     assert len(chunks) == 3

diff --git a/evals/eval_swe_bench.py b/evals/eval_swe_bench.py
index 789c95ab..20e00575 100644
--- a/evals/eval_swe_bench.py
+++ b/evals/eval_swe_bench.py
@@ -34,9 +34,8 @@ def check_install_package(package_name):

 async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
     repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")
-    pipeline = await run_code_graph_pipeline(repo_path)

-    async for result in pipeline:
+    async for result in run_code_graph_pipeline(repo_path, include_docs=True):
         print(result)

     print("Here we have the repo under the repo_path")
@@ -47,7 +46,9 @@ async def generate_patch_with_cognee(instance, llm_client, search_typ
     instructions = read_query_prompt("patch_gen_kg_instructions.txt")

     retrieved_edges = await brute_force_triplet_search(
-        problem_statement, top_k=3, collections=["data_point_source_code", "data_point_text"]
+        problem_statement,
+        top_k=3,
+        collections=["code_summary_text"],
     )

     retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)
diff --git a/examples/python/code_graph_example.py b/examples/python/code_graph_example.py
index c90a0b60..59229344 100644
--- a/examples/python/code_graph_example.py
+++ b/examples/python/code_graph_example.py
@@ -1,7 +1,9 @@
 import argparse
 import asyncio
+import logging

 from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
+from cognee.shared.utils import setup_logging


 async def main(repo_path, include_docs):
@@ -9,7 +11,7 @@ async def main(repo_path, include_docs):
         print(result)


-if __name__ == "__main__":
+def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
     parser.add_argument(
@@ -18,5 +20,28 @@ async def main(repo_path, include_docs):
         default=True,
         help="Whether or not to process non-code files",
     )
-    args = parser.parse_args()
-    asyncio.run(main(args.repo_path, args.include_docs))
+    parser.add_argument(
+        "--time",
+        type=lambda x: x.lower() in ("true", "1"),
+        default=True,
+        help="Whether or not to time the pipeline run",
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    setup_logging(logging.ERROR)
+
+    args = parse_args()
+
+    if args.time:
+        import time
+
+        start_time = time.time()
+        asyncio.run(main(args.repo_path, args.include_docs))
+        end_time = time.time()
+        print("\n" + "=" * 50)
+        print(f"Pipeline Execution Time: {end_time - start_time:.2f} seconds")
+        print("=" * 50 + "\n")
+    else:
+        asyncio.run(main(args.repo_path, args.include_docs))
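For context, a minimal driver sketch (not part of the diff) showing how the reworked pipeline is invoked; it mirrors the async-for call pattern used above in evals/eval_swe_bench.py and examples/python/code_graph_example.py. The repository path is a placeholder, and MAX_TOKENS is only honored if it is set in the environment or the .env file that CognifyConfig reads.

import asyncio

from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline


async def demo():
    # "/path/to/repo" is a placeholder; include_docs=True routes non-code files
    # through extract_chunks_from_documents, which now receives max_tokens from
    # CognifyConfig (populated from the MAX_TOKENS environment variable).
    async for status in run_code_graph_pipeline("/path/to/repo", include_docs=True):
        print(status)


if __name__ == "__main__":
    asyncio.run(demo())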