From 64816bea00c587bb53f31b59e1be0327f84e1592 Mon Sep 17 00:00:00 2001 From: jfrverdasca Date: Fri, 24 Jan 2025 11:18:02 +0000 Subject: [PATCH] Fix python vectorizer, prompt and other little things (#191) --- .../vectorizers/chunk_vectorizer.py | 16 ++- .../vectorizers/python_vectorizer.py | 66 ++++++----- labs/file_handler.py | 7 +- labs/llm/context.py | 56 +++++++--- labs/llm/prompt.py | 4 +- labs/parsers/python.py | 9 +- labs/tasks/llm.py | 5 +- labs/tasks/logging.py | 3 +- labs/tasks/repository.py | 10 +- labs/tasks/run.py | 5 +- labs/tests/test_file_handler.py | 103 ++++++++++++++++++ 11 files changed, 204 insertions(+), 80 deletions(-) create mode 100644 labs/tests/test_file_handler.py diff --git a/labs/embeddings/vectorizers/chunk_vectorizer.py b/labs/embeddings/vectorizers/chunk_vectorizer.py index 120eb37..2e7c3aa 100644 --- a/labs/embeddings/vectorizers/chunk_vectorizer.py +++ b/labs/embeddings/vectorizers/chunk_vectorizer.py @@ -64,12 +64,20 @@ def split_docs(self, docs): def vectorize_to_database(self, include_file_extensions, repository_path, *args, **kwargs): logger.debug("Loading and splitting all documents into chunks.") docs = self.load_docs(repository_path, include_file_extensions) - texts = self.split_docs(docs) - files_texts = [(text.metadata["source"], text.page_content) for text in texts] - texts = [files_text[1] for files_text in files_texts] + docs_parts = self.split_docs(docs) - logger.debug("Embedding all repository documents.") + texts = [] + files_texts = [] + for part in docs_parts: + # Add the file path to the content to allow retrieving the file by its path + # if it is mentioned in the prompt. + file_path = part.metadata["source"] + content = f"{file_path}\n\n{part.page_content}" + + files_texts.append((file_path, content)) + texts.append(content) + logger.debug("Embedding all repository documents.") embeddings = self.embedder.embed(prompt=texts) logger.debug("Storing all embeddings.") diff --git a/labs/embeddings/vectorizers/python_vectorizer.py b/labs/embeddings/vectorizers/python_vectorizer.py index 22e7bcc..eb92cf6 100644 --- a/labs/embeddings/vectorizers/python_vectorizer.py +++ b/labs/embeddings/vectorizers/python_vectorizer.py @@ -1,6 +1,8 @@ import logging import os +from pathlib import Path from types import SimpleNamespace +from typing import Union import pathspec from langchain_community.document_loaders import TextLoader @@ -14,24 +16,13 @@ class PythonVectorizer: def __init__(self, embedder): self.embedder = embedder - def prepare_doc_content(self, metadata, code_snippet): - metadata = SimpleNamespace(**metadata) + def load_file_chunks(self, file_path: Union[str, Path]): + try: + loader = TextLoader(file_path, encoding="utf-8") + return loader.load_and_split() - result = ( - f"Source: {metadata.source}\n" - f"Name: {metadata.name}\n" - f"Start line: {metadata.start_line}\n" - f"End line: {metadata.end_line}\n" - ) - - if hasattr(metadata, "parameters"): - result += f"Parameters: {', '.join(metadata.parameters)}\n" - - if hasattr(metadata, "returns"): - result += f"Returns: {metadata.returns}\n" - - result += f"\n\n{code_snippet}" - return result + except Exception: + logger.exception(f"Error loading file {file_path}") def load_docs(self, root_dir, file_extensions=None): docs = [] @@ -58,20 +49,20 @@ def load_docs(self, root_dir, file_extensions=None): if file_extensions and os.path.splitext(file_path)[1] not in file_extensions: continue - # only python files if os.path.splitext(file_path)[1] != ".py": - try: - loader = TextLoader(file_path, encoding="utf-8") - docs.extend(loader.load_and_split()) + docs.extend(self.load_file_chunks(file_path)) + continue - except Exception: - logger.exception("Failed to load repository documents into memory.") + # Python files with syntax errors will be loaded in chunks + try: + python_file_structure = parse_python_file(file_path) + except SyntaxError as e: + logger.error(f"Syntax error at {file_path}. File will be loaded without Python parsing: {e}") + docs.extend(self.load_file_chunks(file_path)) continue - python_file_structure = parse_python_file(file_path) - - # functions + # Functions for func in python_file_structure.get("functions", []): func_ns = SimpleNamespace(**func) @@ -85,10 +76,9 @@ def load_docs(self, root_dir, file_extensions=None): returns=func_ns.returns, ) - doc_content = self.prepare_doc_content(metadata, function_snippet) - docs.append(Document(doc_content, metadata=metadata)) + docs.append(Document(function_snippet, metadata=metadata)) - # classes + # Classes for cls in python_file_structure.get("classes", []): cls_ns = SimpleNamespace(**cls) @@ -100,8 +90,7 @@ def load_docs(self, root_dir, file_extensions=None): end_line=cls_ns.end_line, ) - doc_content = self.prepare_doc_content(metadata, class_snippet) - docs.append(Document(doc_content, metadata=metadata)) + docs.append(Document(class_snippet, metadata=metadata)) for method in cls.get("methods"): method_ns = SimpleNamespace(**method) @@ -116,8 +105,7 @@ def load_docs(self, root_dir, file_extensions=None): returns=method_ns.returns, ) - doc_content = self.prepare_doc_content(metadata, method_snippet) - docs.append(Document(doc_content, metadata=metadata)) + docs.append(Document(method_snippet, metadata=metadata)) return docs @@ -125,8 +113,16 @@ def vectorize_to_database(self, include_file_extensions, repository_path, *args, docs = self.load_docs(repository_path, include_file_extensions) logger.debug(f"Loading {len(docs)} documents...") - files_texts = [(doc.metadata["source"], doc.page_content) for doc in docs] - texts = [file_and_text[1] for file_and_text in files_texts] + texts = [] + files_texts = [] + for doc in docs: + # Add the file path to the content to allow retrieving the file by its path + # if it is mentioned in the prompt. + file_path = doc.metadata["source"] + content = f"{file_path}\n\n{doc.page_content}" + + files_texts.append((file_path, content)) + texts.append(content) embeddings = self.embedder.embed(prompt=texts) diff --git a/labs/file_handler.py b/labs/file_handler.py index 8f45652..a9557ef 100644 --- a/labs/file_handler.py +++ b/labs/file_handler.py @@ -51,8 +51,9 @@ def modify_file_line(file_path: str, content: Union[str | List[str]], line_numbe temp_file.write(line) - if line_number > current_line_number: - temp_file.writelines(content) + if not overwrite and line_number > current_line_number: + # If we reach this condition line_number + 1 does not apply, so we add a `\n` before + temp_file.writelines(f"\n{content}") except Exception as e: logger.error(f"Error modifying file {file_path}: {e}") @@ -85,7 +86,7 @@ def get_file_content(file_path: str) -> str: content = "" with open(file_path, "r") as file: for line_number, line in enumerate(file, start=1): - content += f"{line_number}: {line}" + content += f"{str(line_number).rjust(6)} | {line}" return content diff --git a/labs/llm/context.py b/labs/llm/context.py index 55f527f..6c2b576 100644 --- a/labs/llm/context.py +++ b/labs/llm/context.py @@ -13,24 +13,46 @@ "text/x-python": "python", } -CONTENT_TEMPLATE = "The following is the code in `{file}`:\n\n````{mimetype}\n{content}\n```" +CONTENT_TEMPLATE = "The following is the code in `{file}`:\n\n```{mimetype}\n{content}\n```" PERSONA_CONTEXT = """ -You are an advanced software engineer assistant designed to resolve code-based tasks. -You will receive: - 1. A description of the task. - 2. File names and their contents as context. - 3. Constraints such as not modifying migrations unless explicitly required. - -You should: - - Analyze the provided task description and associated context. - - Generate the necessary code changes to resolve the task. - - Ensure adherence to best practices for the programming language used. - - Avoid changes to migrations or unrelated files unless specified. - - Provide clean, organized, and ready-to-review code changes. - - Group related logic together to ensure clarity and cohesion. - - Add meaningful comments to explain non-obvious logic or complex operations. - - Ensure the code integrates seamlessly into the existing structure of the project. - - Perform the 'delete' operations in **reverse line number order** to avoid line shifting. +# Overview +This is a Python repository for Revent, a photo contest API built with Django. + +# Guidance +You are an advanced software engineering assistant tasked with resolving code-related tasks. You will receive: +- Task descriptions. +- File names with line numbers and content as context. +- Constraints such as not modifying migrations unless explicitly required. + +Your environment is fully set up—no need to install packages. + +# Task Guidelines +- Solve problems with minimal, clean, and efficient code. +- Avoid complexity: minimize branching logic, error handling, and unnecessary lines. + +# Code Style +- Use Python's standard library/core packages when possible. +- Prioritize readability, maintainability, and computational efficiency. +- Avoid excessive loops, chains, and nested logic. +- Use: + - Explicit imports (no `import *`). + - List comprehensions over loops (but keep them simple). + - f-strings for formatting. + - Type hints for all function signatures. + - Dataclasses for simple classes. + - Avoid: + - Excessive try/except blocks or nested error handling. + - Installing new packages or using make commands unless specified. + +# Testing +- Place tests in the `tests` directory. +- Use the `unittest` framework. +- Write unit tests for all functions, covering edge cases and error conditions. +- Aim for 100% test coverage with concise, comprehensive tests. + +# Resources +- README.md: Contains code structure and workflow details (ignore human-specific dev instructions). +- Makefile: Useful commands (refer to instructions in this file). """ diff --git a/labs/llm/prompt.py b/labs/llm/prompt.py index cd5e382..c17facd 100644 --- a/labs/llm/prompt.py +++ b/labs/llm/prompt.py @@ -39,9 +39,9 @@ { "steps": [ { - "type": "Operation type: 'create', 'update', 'overwrite', or 'delete'", + "type": "Operation type: 'create', 'insert', 'overwrite', or 'delete'", "path": "Absolute file path", - "content": "Content to write (required for 'create', 'update', or 'overwrite')", + "content": "Content to write (required for 'create', 'insert', or 'overwrite')", "line": "Initial line number where the content should be written (or erased if 'delete')", } ] diff --git a/labs/parsers/python.py b/labs/parsers/python.py index 1a03dea..ee30d3a 100644 --- a/labs/parsers/python.py +++ b/labs/parsers/python.py @@ -209,14 +209,9 @@ def parse_python_file(file_path: str) -> str | dict: file_content = source.read() parser = PythonFileParser(file_name=file_path) - try: - tree = ast.parse(file_content, file_path) - - except SyntaxError as e: - print(f"Syntax error at {file_path}: {e}") - return dict() - + tree = ast.parse(file_content, file_path) parser.visit(tree) + return parser.get_structure() diff --git a/labs/tasks/llm.py b/labs/tasks/llm.py index 3c0b75a..e878364 100644 --- a/labs/tasks/llm.py +++ b/labs/tasks/llm.py @@ -2,6 +2,8 @@ import logging from typing import List, Optional +from config.celery import app +from config.redis_client import RedisVariable, redis_client from core.models import Model, VectorizerModel from django.conf import settings from embeddings.embedder import Embedder @@ -11,9 +13,6 @@ from llm.prompt import get_prompt from llm.requester import Requester -from config.celery import app -from config.redis_client import RedisVariable, redis_client - logger = logging.getLogger(__name__) diff --git a/labs/tasks/logging.py b/labs/tasks/logging.py index f995768..ed7098a 100644 --- a/labs/tasks/logging.py +++ b/labs/tasks/logging.py @@ -1,7 +1,6 @@ -from core.models import Model, WorkflowResult - from config.celery import app from config.redis_client import RedisVariable, redis_client +from core.models import Model, WorkflowResult @app.task diff --git a/labs/tasks/repository.py b/labs/tasks/repository.py index 80fd3c5..e69331b 100644 --- a/labs/tasks/repository.py +++ b/labs/tasks/repository.py @@ -2,14 +2,13 @@ import logging from typing import List, cast +from config.celery import app +from config.redis_client import RedisVariable, redis_client from decorators import time_and_log_function from file_handler import create_file, delete_file_line, modify_file_line from github.github import GithubRequests from parsers.response import parse_llm_output -from config.celery import app -from config.redis_client import RedisVariable, redis_client - logger = logging.getLogger(__name__) @@ -26,8 +25,11 @@ def github_repository_data(prefix, token="", repository_owner="", repository_nam def apply_code_changes(llm_response): response = parse_llm_output(llm_response) + # We will sort the operations by file and by line number + sorted_steps = sorted(response.steps, key=lambda s: (s.path, s.line), reverse=True) + files: List[str | None] = [] - for step in response.steps: + for step in sorted_steps: match step.type: case "create": create_file(step.path, step.content) diff --git a/labs/tasks/run.py b/labs/tasks/run.py index 24f690a..2f30959 100644 --- a/labs/tasks/run.py +++ b/labs/tasks/run.py @@ -1,6 +1,8 @@ import os.path from celery import chain +from config.celery import app +from config.redis_client import RedisVariable, redis_client from tasks import ( apply_code_changes_task, clone_repository_task, @@ -15,9 +17,6 @@ vectorize_repository_task, ) -from config.celery import app -from config.redis_client import RedisVariable, redis_client - @app.task(bind=True) def init_task(self, **kwargs): diff --git a/labs/tests/test_file_handler.py b/labs/tests/test_file_handler.py new file mode 100644 index 0000000..37bd696 --- /dev/null +++ b/labs/tests/test_file_handler.py @@ -0,0 +1,103 @@ +import os +from tempfile import NamedTemporaryFile +from unittest import TestCase + +from file_handler import create_file, delete_file_line, get_file_content, modify_file_line + + +class TestFileHandler(TestCase): + TEMPORARY_CONTENT = """ + Line 1 + Line 2 + Line 4 + """ + NEW_CONTENT = " Line 3\n" + # Read variables + READ_FILE_EXPECTED_CONTENT = """ 1 | + 2 | Line 1 + 3 | Line 2 + 4 | Line 4 + 5 | """ + # Insert variables + INSERT_LINE = 3 + INSERT_LINE_EXPECTED_CONTENT = """ + Line 1 + Line 2 + Line 3 + Line 4 + """ + INSERT_EOF_LINE = 6 + INSERT_EOF_EXPECTED_CONTENT = """ + Line 1 + Line 2 + Line 4 + + Line 3 +""" + # Overwrite variables + OVERWRITE_LINE = 4 + OVERWRITE_LINE_EXPECTED_CONTENT = """ + Line 1 + Line 2 + Line 3 + """ + OVERWRITE_EOF_LINE = 6 + OVERWRITE_EOF_EXPECTED_CONTENT = """ + Line 1 + Line 2 + Line 4 + """ + # Delete variables + DELETE_LINE = 4 + DELETE_LINE_EXPECTED_CONTENT = """ + Line 1 + Line 2 + """ + + def assertFileContentEqual(self, file_path, expected_content): + with open(file_path, "r") as file_handler: + self.assertEqual(expected_content, file_handler.read()) + + def setUp(self): + self.temporary_file = NamedTemporaryFile(mode="w+", delete=False) + self.temporary_file.write(self.TEMPORARY_CONTENT) + self.temporary_file.flush() + self.temporary_file.close() + + def test_get_file_content(self): + content = get_file_content(self.temporary_file.name) + self.assertEqual(content, self.READ_FILE_EXPECTED_CONTENT) + + def test_create_file(self): + # Remove the existing file created in setUp + try: + os.remove(self.temporary_file.name) + + except FileNotFoundError: + pass + + create_file(self.temporary_file.name, self.TEMPORARY_CONTENT) + self.assertFileContentEqual(self.temporary_file.name, self.TEMPORARY_CONTENT) + + def test_insert_file_line(self): + modify_file_line(self.temporary_file.name, self.NEW_CONTENT, self.INSERT_LINE) + self.assertFileContentEqual(self.temporary_file.name, self.INSERT_LINE_EXPECTED_CONTENT) + + def test_insert_end_of_file(self): + modify_file_line(self.temporary_file.name, self.NEW_CONTENT, self.INSERT_EOF_LINE) + self.assertFileContentEqual(self.temporary_file.name, self.INSERT_EOF_EXPECTED_CONTENT) + + def test_overwrite_file_line(self): + modify_file_line(self.temporary_file.name, self.NEW_CONTENT, self.OVERWRITE_LINE, overwrite=True) + self.assertFileContentEqual(self.temporary_file.name, self.OVERWRITE_LINE_EXPECTED_CONTENT) + + def test_overwrite_end_of_file(self): + modify_file_line(self.temporary_file.name, self.NEW_CONTENT, self.OVERWRITE_EOF_LINE, overwrite=True) + self.assertFileContentEqual(self.temporary_file.name, self.OVERWRITE_EOF_EXPECTED_CONTENT) + + def test_delete_file_line(self): + delete_file_line(self.temporary_file.name, self.DELETE_LINE) + self.assertFileContentEqual(self.temporary_file.name, self.DELETE_LINE_EXPECTED_CONTENT) + + def tearDown(self): + os.remove(self.temporary_file.name)