Skip to content

Commit

Permalink
Get the embedding engine inside the code-chunking tasks instead of passing it in as a parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
alekszievr committed Jan 8, 2025
1 parent 34a9267 commit 97814e3
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
6 changes: 1 addition & 5 deletions cognee/api/v1/cognify/code_graph_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from pathlib import Path

from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.embeddings import \
get_embedding_engine
from cognee.modules.cognify.config import get_cognify_config
from cognee.modules.pipelines import run_tasks
from cognee.modules.pipelines.tasks.Task import Task
Expand Down Expand Up @@ -51,16 +49,14 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
await cognee.prune.prune_system(metadata=True)
await create_db_and_tables()

embedding_engine = get_embedding_engine()

cognee_config = get_cognify_config()
user = await get_default_user()

tasks = [
Task(get_repo_file_dependencies),
Task(enrich_dependency_graph),
Task(expand_dependency_graph, task_config={"batch_size": 50}),
Task(get_source_code_chunks, embedding_model=embedding_engine.model, task_config={"batch_size": 50}),
Task(get_source_code_chunks, task_config={"batch_size": 50}),
Task(summarize_code, task_config={"batch_size": 50}),
Task(add_data_points, task_config={"batch_size": 50}),
]
Expand Down
4 changes: 2 additions & 2 deletions cognee/tasks/chunks/chunk_by_paragraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def chunk_by_paragraph(

vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model

embedding_model = embedding_model.split("/")[-1]

for paragraph_id, sentence, word_count, end_type in chunk_by_sentence(data, maximum_length=paragraph_length):
# Check if this sentence would exceed length limit

embedding_model = embedding_model.split("/")[-1]
tokenizer = tiktoken.encoding_for_model(embedding_model)
token_count = len(tokenizer.encode(sentence))

Expand Down
9 changes: 6 additions & 3 deletions cognee/tasks/repo_processor/get_source_code_chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import parso
import tiktoken

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart, SourceCodeChunk

Expand Down Expand Up @@ -115,13 +116,15 @@ def get_source_code_chunks_from_code_part(
max_tokens: int = 8192,
overlap: float = 0.25,
granularity: float = 0.1,
model_name: str = "text-embedding-3-large"
) -> Generator[SourceCodeChunk, None, None]:
"""Yields source code chunks from a CodePart object, with configurable token limits and overlap."""
if not code_file_part.source_code:
logger.error(f"No source code in CodeFile {code_file_part.id}")
return

vector_engine = get_vector_engine()
embedding_model = vector_engine.embedding_engine.model
model_name = embedding_model.split("/")[-1]
tokenizer = tiktoken.encoding_for_model(model_name)
max_subchunk_tokens = max(1, int(granularity * max_tokens))
subchunk_token_counts = _get_subchunk_token_counts(tokenizer, code_file_part.source_code, max_subchunk_tokens)
Expand All @@ -141,7 +144,7 @@ def get_source_code_chunks_from_code_part(
previous_chunk = current_chunk


async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="text-embedding-3-large") -> \
async def get_source_code_chunks(data_points: list[DataPoint]) -> \
AsyncGenerator[list[DataPoint], None]:
"""Processes code graph datapoints and creates SourceCodeChunk datapoints."""
# TODO: Add support for other embedding models, with max_token mapping
Expand All @@ -156,7 +159,7 @@ async def get_source_code_chunks(data_points: list[DataPoint], embedding_model="
for code_part in data_point.contains:
try:
yield code_part
for source_code_chunk in get_source_code_chunks_from_code_part(code_part, model_name=embedding_model):
for source_code_chunk in get_source_code_chunks_from_code_part(code_part):
yield source_code_chunk
except Exception as e:
logger.error(f"Error processing code part: {e}")
Expand Down

0 comments on commit 97814e3

Please sign in to comment.