Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cog 813 source code chunks #383

Merged
merged 19 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions cognee/api/v1/cognify/code_graph_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from pathlib import Path

from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.embeddings import \
get_embedding_engine
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
Expand All @@ -15,8 +17,10 @@
from cognee.tasks.repo_processor import (enrich_dependency_graph,
expand_dependency_graph,
get_data_list_for_user,
get_non_code_files,
get_non_py_files,
get_repo_file_dependencies)
from cognee.tasks.repo_processor.get_source_code_chunks import \
get_source_code_chunks
from cognee.tasks.storage import add_data_points

monitoring = get_base_config().monitoring_tool
Expand All @@ -28,6 +32,7 @@
logger = logging.getLogger("code_graph_pipeline")
update_status_lock = asyncio.Lock()


@observe
async def run_code_graph_pipeline(repo_path, include_docs=True):
import os
Expand All @@ -46,20 +51,23 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
await cognee.prune.prune_system(metadata=True)
await create_db_and_tables()

embedding_engine = get_embedding_engine()

lxobr marked this conversation as resolved.
Show resolved Hide resolved
cognee_config = get_cognify_config()
user = await get_default_user()

tasks = [
Task(get_repo_file_dependencies),
Task(enrich_dependency_graph, task_config={"batch_size": 50}),
Task(enrich_dependency_graph),
Task(expand_dependency_graph, task_config={"batch_size": 50}),
Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}),
]

if include_docs:
non_code_tasks = [
Task(get_non_code_files, task_config={"batch_size": 50}),
Task(get_non_py_files, task_config={"batch_size": 50}),
Task(ingest_data_with_metadata, dataset_name="repo_docs", user=user),
Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
Task(classify_documents),
Expand All @@ -71,7 +79,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
task_config={"batch_size": 50}
),
]

if include_docs:
async for result in run_tasks(non_code_tasks, repo_path):
yield result
Expand Down
27 changes: 14 additions & 13 deletions cognee/shared/CodeGraphEntities.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import List, Optional

from cognee.infrastructure.engine import DataPoint


class Repository(DataPoint):
__tablename__ = "Repository"
path: str
_metadata: dict = {
"index_fields": ["source_code"],
lxobr marked this conversation as resolved.
Show resolved Hide resolved
"index_fields": [],
"type": "Repository"
}

Expand All @@ -19,29 +18,31 @@ class CodeFile(DataPoint):
depends_on: Optional[List["CodeFile"]] = None
depends_directly_on: Optional[List["CodeFile"]] = None
contains: Optional[List["CodePart"]] = None

_metadata: dict = {
"index_fields": ["source_code"],
"index_fields": [],
"type": "CodeFile"
}

class CodePart(DataPoint):
__tablename__ = "codepart"
# part_of: Optional[CodeFile]
source_code: str

# part_of: Optional[CodeFile] = None
source_code: Optional[str] = None
lxobr marked this conversation as resolved.
Show resolved Hide resolved
_metadata: dict = {
"index_fields": ["source_code"],
"index_fields": [],
"type": "CodePart"
}

class CodeRelationship(DataPoint):
source_id: str
target_id: str
relation: str # depends on or depends directly
class SourceCodeChunk(DataPoint):
__tablename__ = "sourcecodechunk"
code_chunk_of: Optional[CodePart] = None
source_code: Optional[str] = None
previous_chunk: Optional["SourceCodeChunk"] = None

_metadata: dict = {
"type": "CodeRelationship"
"index_fields": ["source_code"],
"type": "SourceCodeChunk"
}

CodeFile.model_rebuild()
CodePart.model_rebuild()
SourceCodeChunk.model_rebuild()
1 change: 0 additions & 1 deletion cognee/shared/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ class SummarizedClass(BaseModel):
decorators: Optional[List[str]] = None

class SummarizedCode(BaseModel):
file_name: str
high_level_summary: str
key_features: List[str]
imports: List[str] = []
Expand Down
9 changes: 6 additions & 3 deletions cognee/tasks/repo_processor/get_repo_file_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
path = repo_path,
)

yield repo
yield [repo]

with ProcessPoolExecutor(max_workers = 12) as executor:
loop = asyncio.get_event_loop()
Expand All @@ -90,10 +90,11 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non

results = await asyncio.gather(*tasks)

code_files = []
lxobr marked this conversation as resolved.
Show resolved Hide resolved
for (file_path, metadata), dependencies in zip(py_files_dict.items(), results):
source_code = metadata.get("source_code")

yield CodeFile(
code_files.append(CodeFile(
id = uuid5(NAMESPACE_OID, file_path),
source_code = source_code,
extracted_id = file_path,
Expand All @@ -106,4 +107,6 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
source_code = py_files_dict.get(dependency, {}).get("source_code"),
) for dependency in dependencies
] if dependencies else None,
)
))

yield code_files
157 changes: 157 additions & 0 deletions cognee/tasks/repo_processor/get_source_code_chunks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import logging
from typing import AsyncGenerator, Generator
from uuid import NAMESPACE_OID, uuid5

import parso
import tiktoken

from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk

logger = logging.getLogger("task:get_source_code_chunks")


def _count_tokens(tokenizer: tiktoken.Encoding, source_code: str) -> int:
    """Return how many tokens *source_code* encodes to under *tokenizer*."""
    token_ids = tokenizer.encode(source_code)
    return len(token_ids)


def _get_naive_subchunk_token_counts(
tokenizer: tiktoken.Encoding, source_code: str, max_subchunk_tokens: int = 8000
) -> list[tuple[str, int]]:
"""Splits source code into subchunks of up to max_subchunk_tokens and counts tokens."""

token_ids = tokenizer.encode(source_code)
subchunk_token_counts = []

for start_idx in range(0, len(token_ids), max_subchunk_tokens):
subchunk_token_ids = token_ids[start_idx: start_idx + max_subchunk_tokens]
token_count = len(subchunk_token_ids)
subchunk = ''.join(
tokenizer.decode_single_token_bytes(token_id).decode('utf-8', errors='replace')
for token_id in subchunk_token_ids
)
subchunk_token_counts.append((subchunk, token_count))

return subchunk_token_counts


def _get_subchunk_token_counts(
    tokenizer: tiktoken.Encoding,
    source_code: str,
    max_subchunk_tokens: int = 8000,
    depth: int = 0,
    max_depth: int = 100
) -> list[tuple[str, int]]:
    """Recursively split source code into subchunks of at most ``max_subchunk_tokens`` tokens.

    Splits along syntactic boundaries using parso where possible; falls back to
    the naive fixed-window splitter when the recursion limit is hit or a node
    cannot be split further.

    Returns:
        A list of ``(subchunk_text, token_count)`` pairs, or ``[]`` when the
        source cannot be parsed or is empty.
    """
    if depth > max_depth:
        return _get_naive_subchunk_token_counts(tokenizer, source_code, max_subchunk_tokens)

    try:
        module = parso.parse(source_code)
    except Exception as e:
        logger.error(f"Error parsing source code: {e}")
        return []

    if not module.children:
        logger.warning("Parsed module has no children (empty or invalid source code).")
        return []

    # Handle cases with only one real child and an EndMarker to prevent infinite recursion.
    if len(module.children) <= 2:
        module = module.children[0]

    # Bug fix: leaf nodes have no `children` attribute and cannot be split
    # syntactically — fall back to the naive splitter instead of raising
    # AttributeError in the loop below.
    if not hasattr(module, "children") or not module.children:
        return _get_naive_subchunk_token_counts(tokenizer, source_code, max_subchunk_tokens)

    subchunk_token_counts = []
    for child in module.children:
        subchunk = child.get_code()
        token_count = _count_tokens(tokenizer, subchunk)

        if token_count == 0:
            continue

        if token_count <= max_subchunk_tokens:
            subchunk_token_counts.append((subchunk, token_count))
            continue

        if child.type == 'string':
            # A single oversized string literal has no parsable substructure;
            # split it by raw token windows instead.
            subchunk_token_counts.extend(_get_naive_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens))
            continue

        subchunk_token_counts.extend(
            _get_subchunk_token_counts(tokenizer, subchunk, max_subchunk_tokens, depth=depth + 1, max_depth=max_depth)
        )

    return subchunk_token_counts


def _get_chunk_source_code(
code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int
) -> tuple[list[tuple[str, int]], str]:
"""Generates a chunk of source code from tokenized subchunks with overlap handling."""
lxobr marked this conversation as resolved.
Show resolved Hide resolved
current_count = 0
cumulative_counts = []
current_source_code = ''

for i, (child_code, token_count) in enumerate(code_token_counts):
current_count += token_count
cumulative_counts.append(current_count)
if current_count > max_tokens:
break
current_source_code += f"\n{child_code}"

if current_count <= max_tokens:
return [], current_source_code.strip()

cutoff = 1
for i, cum_count in enumerate(cumulative_counts):
if cum_count > (1 - overlap) * max_tokens:
break
cutoff = i

return code_token_counts[cutoff:], current_source_code.strip()


def get_source_code_chunks_from_code_part(
    code_file_part: CodePart,
    max_tokens: int = 8192,
    overlap: float = 0.25,
    granularity: float = 0.1,
    model_name: str = "text-embedding-3-large"
) -> Generator[SourceCodeChunk, None, None]:
    """Yield SourceCodeChunk objects for a CodePart's source code.

    Args:
        code_file_part: the CodePart whose source code is chunked.
        max_tokens: token budget per chunk (should match the embedding model's limit).
        overlap: fraction of max_tokens repeated between consecutive chunks.
        granularity: subchunk size as a fraction of max_tokens.
        model_name: embedding model used to pick the tokenizer.

    Yields:
        SourceCodeChunk datapoints linked to the CodePart, each chunk pointing
        at its predecessor via ``previous_chunk``.
    """
    if not code_file_part.source_code:
        # Bug fix: the original message said "CodeFile" although the argument is a CodePart.
        logger.error(f"No source code in CodePart {code_file_part.id}")
        return

    tokenizer = tiktoken.encoding_for_model(model_name)
    max_subchunk_tokens = max(1, int(granularity * max_tokens))
    subchunk_token_counts = _get_subchunk_token_counts(tokenizer, code_file_part.source_code, max_subchunk_tokens)

    previous_chunk = None
    while subchunk_token_counts:
        remaining_counts, chunk_source_code = _get_chunk_source_code(subchunk_token_counts, overlap, max_tokens)
        made_progress = len(remaining_counts) < len(subchunk_token_counts)
        subchunk_token_counts = remaining_counts

        if chunk_source_code:
            current_chunk = SourceCodeChunk(
                id=uuid5(NAMESPACE_OID, chunk_source_code),
                code_chunk_of=code_file_part,
                source_code=chunk_source_code,
                previous_chunk=previous_chunk
            )
            yield current_chunk
            previous_chunk = current_chunk

        if not made_progress:
            # Safety net: if the chunker ever fails to shrink the remainder,
            # stop instead of yielding the same chunk (or spinning) forever.
            logger.warning("Chunking made no progress for CodePart %s; stopping early.", code_file_part.id)
            break


async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="text-embedding-3-large") -> \
        AsyncGenerator[list[DataPoint], None]:
    """Process code graph datapoints, creating SourceCodeChunk datapoints.

    Passes every incoming datapoint through unchanged; for each CodeFile it
    additionally yields the CodeParts it contains and the SourceCodeChunk
    datapoints derived from each part's source code.
    """
    # TODO: Add support for other embedding models, with max_token mapping
    for data_point in data_points:
        yield data_point
        if not isinstance(data_point, CodeFile):
            continue
        if not data_point.contains:
            continue
        for code_part in data_point.contains:
            yield code_part
            for source_code_chunk in get_source_code_chunks_from_code_part(code_part, model_name=embedding_model):
                yield source_code_chunk
8 changes: 7 additions & 1 deletion cognee/tasks/storage/index_data_points.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from litellm.exceptions import BadRequestError
from litellm.llms.OpenAI.openai import OpenAIError

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint

Expand Down Expand Up @@ -30,7 +33,10 @@ async def index_data_points(data_points: list[DataPoint]):

for index_name, indexable_points in index_points.items():
index_name, field_name = index_name.split(".")
await vector_engine.index_data_points(index_name, field_name, indexable_points)
try:
await vector_engine.index_data_points(index_name, field_name, indexable_points)
except (OpenAIError, BadRequestError) as e:
lxobr marked this conversation as resolved.
Show resolved Hide resolved
print(f"Failed to index data points for {index_name}.{field_name}: {e}")

return data_points

Expand Down
6 changes: 4 additions & 2 deletions cognee/tasks/summarization/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Union

from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models import DocumentChunk
from cognee.shared.CodeGraphEntities import CodeFile
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk


class TextSummary(DataPoint):
Expand All @@ -17,7 +19,7 @@ class TextSummary(DataPoint):
class CodeSummary(DataPoint):
__tablename__ = "code_summary"
text: str
made_from: CodeFile
summarizes: Union[CodeFile, CodePart, SourceCodeChunk]

_metadata: dict = {
"index_fields": ["text"],
Expand Down
8 changes: 4 additions & 4 deletions cognee/tasks/summarization/summarize_code.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
from typing import AsyncGenerator, Union
from uuid import uuid5
from typing import Type

from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.extraction.extract_summary import extract_code_summary

from .models import CodeSummary


Expand All @@ -21,7 +21,7 @@ async def summarize_code(
)

file_summaries_map = {
code_data_point.extracted_id: str(file_summary)
code_data_point.id: str(file_summary)
for code_data_point, file_summary in zip(code_data_points, file_summaries)
}

Expand All @@ -35,6 +35,6 @@ async def summarize_code(

yield CodeSummary(
id=uuid5(node.id, "CodeSummary"),
made_from=node,
text=file_summaries_map[node.extracted_id],
summarizes=node,
text=file_summaries_map[node.id],
)
2 changes: 1 addition & 1 deletion examples/python/code_graph_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ async def main(repo_path, include_docs):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--repo_path", type=str, required=True, help="Path to the repository")
parser.add_argument("--include_docs", type=bool, default=True, help="Whether or not to process non-code files")
parser.add_argument("--include_docs", type=lambda x: x.lower() in ("true", "1"), default=True, help="Whether or not to process non-code files")
args = parser.parse_args()
asyncio.run(main(args.repo_path, args.include_docs))
Loading