diff --git a/.env.sample b/.env.sample index fac3dc3..e125f52 100644 --- a/.env.sample +++ b/.env.sample @@ -1,6 +1,6 @@ -DJANGO_DEBUG=True -DJANGO_ALLOWED_HOSTS=localhost -DJANGO_SECRET_KEY= +DEBUG=True +ALLOWED_HOSTS=127.0.0.1,localhost +SECRET_KEY= DATABASE_HOST=localhost DATABASE_USER=postgres diff --git a/docker-compose.yml b/docker-compose.yml index acee052..87dfda7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,9 +15,6 @@ services: build: context: . dockerfile: ./Dockerfile - environment: - - CELERY_BROKER_URL=redis://redis:6379/0 - - CELERY_RESULT_BD=redis://redis:6379/0 ports: - "8000:8000" depends_on: @@ -54,11 +51,6 @@ services: context: . dockerfile: ./Dockerfile command: celery -A config worker --beat --scheduler redbeat.RedBeatScheduler --loglevel=INFO --concurrency=1 - environment: - - CELERY_BROKER_URL=redis://redis:6379/0 - - CELERY_RESULT_BACKEND=redis://redis:6379/0 - - CELERY_TASK_ALWAYS_EAGER=False - - DATABASE_URL=postgresql://postgres:postgres@labs-db:5432/postgres depends_on: - redis - labs-db diff --git a/docs/diagram/class_diagram b/docs/diagram/class_diagram index a18d478..7702230 100644 --- a/docs/diagram/class_diagram +++ b/docs/diagram/class_diagram @@ -12,8 +12,8 @@ digraph { color=lightgrey label="./code_examples/calculator.py" style=filled Calculator [label="{ Calculator | + add(self, x, y)\l+ subtract(self, x, y)\l }" shape=record] } - subgraph "cluster_./litellm_service/request.py" { - color=lightgrey label="./litellm_service/request.py" style=filled + subgraph "cluster_./llm/request.py" { + color=lightgrey label="./llm/request.py" style=filled RequestLiteLLM [label="{ RequestLiteLLM | + __init__(self, litellm_api_key)\l+ completion(self, messages, model)\l+ completion_without_proxy(self, messages, model)\l }" shape=record] } } diff --git a/labs/api/codemonkey_endpoints.py b/labs/api/codemonkey_endpoints.py index 1cd5c42..949c1a6 100644 --- a/labs/api/codemonkey_endpoints.py +++ b/labs/api/codemonkey_endpoints.py @@ -56,7 +56,7 @@ async def run_on_repository_endpoint(request: HttpRequest, run_on_repository: Gi async def run_on_local_repository_endpoint(request: HttpRequest, run_on_local_repository: LocalRepositoryShema): try: run_on_local_repository_task( - repository_path=run_on_local_repository.repository_path, issue_text=run_on_local_repository.prompt + repository_path=run_on_local_repository.repository_path, issue_body=run_on_local_repository.prompt ) except Exception as ex: logger.exception("Internal server error") diff --git a/labs/api/github_endpoints.py b/labs/api/github_endpoints.py index 61517d2..4c52d6d 100644 --- a/labs/api/github_endpoints.py +++ b/labs/api/github_endpoints.py @@ -25,7 +25,7 @@ async def list_issues_endpoint(request: HttpRequest, params: ListIssuesSchema): try: github_requests = GithubRequests( - github_token=params.token, + token=params.token, repository_owner=params.repository_owner, repository_name=params.repository_name, username=params.username, @@ -41,7 +41,7 @@ async def list_issues_endpoint(request: HttpRequest, params: ListIssuesSchema): async def get_issue_endpoint(request: HttpRequest, params: IssueSchema): try: github_requests = GithubRequests( - github_token=params.token, + token=params.token, repository_owner=params.repository_owner, repository_name=params.repository_name, username=params.username, @@ -57,7 +57,7 @@ async def get_issue_endpoint(request: HttpRequest, params: IssueSchema): async def create_branch_endpoint(request: HttpRequest, params: BranchSchema): try: github_requests = 
GithubRequests( - github_token=params.token, + token=params.token, repository_owner=params.repository_owner, repository_name=params.repository_name, username=params.username, @@ -73,7 +73,7 @@ async def create_branch_endpoint(request: HttpRequest, params: BranchSchema): async def change_issue_status_endpoint(request: HttpRequest, params: IssueStatusSchema): try: github_requests = GithubRequests( - github_token=params.token, + token=params.token, repository_owner=params.repository_owner, repository_name=params.repository_name, username=params.username, @@ -89,7 +89,7 @@ async def change_issue_status_endpoint(request: HttpRequest, params: IssueStatus async def commit_changes_endpoint(request: HttpRequest, params: CommitSchema): try: github_requests = GithubRequests( - github_token=params.token, + token=params.token, repository_owner=params.repository_owner, repository_name=params.repository_name, username=params.username, @@ -107,7 +107,7 @@ async def commit_changes_endpoint(request: HttpRequest, params: CommitSchema): async def create_pull_request_endpoint(request: HttpRequest, params: PullRequestSchema): try: github_requests = GithubRequests( - github_token=params.token, + token=params.token, repository_owner=params.repository_owner, repository_name=params.repository_name, username=params.username, @@ -125,7 +125,7 @@ async def create_pull_request_endpoint(request: HttpRequest, params: PullRequest async def clone_repository_endpoint(request: HttpRequest, params: GithubSchema): try: github_requests = GithubRequests( - github_token=params.token, + token=params.token, repository_owner=params.repository_owner, repository_name=params.repository_name, username=params.username, diff --git a/labs/core/models.py b/labs/core/models.py index f40caca..9925060 100644 --- a/labs/core/models.py +++ b/labs/core/models.py @@ -6,12 +6,12 @@ from embeddings.embedder import Embedder from embeddings.ollama import OllamaEmbedder from embeddings.openai import OpenAIEmbedder -from embeddings.vectorizers.base import Vectorizer from embeddings.vectorizers.chunk_vectorizer import ChunkVectorizer from embeddings.vectorizers.python_vectorizer import PythonVectorizer -from litellm_service.llm_requester import Requester -from litellm_service.ollama import OllamaRequester -from litellm_service.openai import OpenAIRequester +from embeddings.vectorizers.vectorizer import Vectorizer +from llm.ollama import OllamaRequester +from llm.openai import OpenAIRequester +from llm.requester import Requester provider_model_class = { "OPENAI": {"embedding": OpenAIEmbedder, "llm": OpenAIRequester}, diff --git a/labs/embeddings/embedder.py b/labs/embeddings/embedder.py index 0336c59..08a664c 100644 --- a/labs/embeddings/embedder.py +++ b/labs/embeddings/embedder.py @@ -21,7 +21,7 @@ def embed(self, prompt, *args, **kwargs) -> Embeddings: def retrieve_embeddings( self, query: str, repository: str, similarity_threshold: int = 0.7, number_of_results: int = 10 - ) -> List[Embeddings]: + ) -> List[Embedding]: query = query.replace("\n", "") embedded_query = self.embed(prompt=query).embeddings if not embedded_query: diff --git a/labs/embeddings/vectorizers/base.py b/labs/embeddings/vectorizers/vectorizer.py similarity index 100% rename from labs/embeddings/vectorizers/base.py rename to labs/embeddings/vectorizers/vectorizer.py diff --git a/labs/github/github.py b/labs/github/github.py index c856e11..98f3f6c 100644 --- a/labs/github/github.py +++ b/labs/github/github.py @@ -14,8 +14,8 @@ class GithubRequests: """Class to handle Github API 
requests""" - def __init__(self, github_token, repository_owner, repository_name, username=None): - self.github_token = github_token + def __init__(self, token, repository_owner, repository_name, username=None): + self.token = token self.repository_owner = repository_owner self.repository_name = repository_name self.username = username @@ -67,7 +67,7 @@ def list_issues(self, assignee=None, state="open", per_page=100): url = f"{self.github_api_url}/issues" headers = { - "Authorization": f"token {self.github_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", } params = { @@ -84,7 +84,7 @@ def get_issue(self, issue_number): # issue_number is the actual number of the issue, not the id. url = f"{self.github_api_url}/issues/{issue_number}" headers = { - "Authorization": f"token {self.github_token}", + "Authorization": f"token {self.token}", "Accept": "application/vnd.github.v3+json", } response_json, _ = self._get(url, headers, {}) @@ -93,7 +93,7 @@ def get_issue(self, issue_number): def create_branch(self, branch_name, original_branch="main"): url = f"{self.github_api_url}/git/refs/heads/{original_branch}" headers = { - "Authorization": f"Bearer {self.github_token}", + "Authorization": f"Bearer {self.token}", "X-Accepted-GitHub-Permissions": "contents=write", } response_json, status_code = self._get(url, headers=headers) @@ -110,7 +110,7 @@ def change_issue_status(self, issue_number, status): url = f"{self.github_api_url}/issues/{issue_number}" headers = { - "Authorization": f"token {self.github_token}", + "Authorization": f"token {self.token}", "user-agent": "request", } data = {"state": status} @@ -121,7 +121,7 @@ def commit_changes(self, message, branch_name, files): # Step 1: Get the latest commit SHA on the specified branch url = f"{self.github_api_url}/git/refs/heads/{branch_name}" headers = { - "Authorization": f"token {self.github_token}", + "Authorization": f"token {self.token}", "Content-Type": "application/json", } response_json, _ = self._get(url, headers) @@ -184,7 +184,7 @@ def commit_changes(self, message, branch_name, files): def create_pull_request(self, head, base="main", title="New Pull Request", body=""): url = f"{self.github_api_url}/pulls" - headers = {"Authorization": f"token {self.github_token}"} + headers = {"Authorization": f"token {self.token}"} data = {"title": title, "body": body, "head": head, "base": base} return self._post(url, headers, data) diff --git a/labs/llm.py b/labs/llm.py deleted file mode 100644 index 33c0126..0000000 --- a/labs/llm.py +++ /dev/null @@ -1,153 +0,0 @@ -import logging - -from core.models import Model, VectorizerModel -from decorators import time_and_log_function -from embeddings.embedder import Embedder -from embeddings.vectorizers.base import Vectorizer -from litellm_service.llm_requester import Requester -from parsers.response_parser import is_valid_json, parse_llm_output - -logger = logging.getLogger(__name__) - - -def get_prompt(issue_summary): - return f""" - You're a diligent software engineer AI. You can't see, draw, or interact with a - browser, but you can read and write files, and you can think. - You've been given the following task: {issue_summary}. - Any imports will be at the beggining of the file. - Add tests for the new functionalities, considering any existing test files. - The file paths provided are **absolute paths relative to the project root**, - and **must not be changed**. Ensure the paths you output match the paths provided exactly. - Do not prepend or modify the paths. 
- Please provide a json response in the following format: {{"steps": [...]}} - Where steps is a list of objects where each object contains three fields: - type, which is either 'create' to add a new file or 'modify' to edit an existing one; - If the file is to be modified send the finished version of the entire file. - path, which is the absolute path of the file to create/modify; - content, which is the content to write to the file. - """ - - -def prepare_context(context, prompt): - prepared_context = [] - for file in context: - prepared_context.append( - { - "role": "system", - "content": f"File: {file[1]} Content: {file[2]}", - } - ) - prepared_context.append( - { - "role": "user", - "content": prompt, - } - ) - return prepared_context - - -def check_length_issue(llm_response): - finish_reason = getattr(llm_response["choices"][0]["message"], "finish_reason", None) - if finish_reason == "length": - return ( - True, - "Conversation was too long for the context window, resulting in incomplete JSON.", - ) - return False, "" - - -def check_content_filter(llm_response): - finish_reason = getattr(llm_response["choices"][0]["message"], "finish_reason", None) - if finish_reason == "content_filter": - return ( - True, - "Model's output included restricted content. Generation of JSON was halted and may be partial.", - ) - return False, "" - - -def check_refusal(llm_response): - refusal_reason = getattr(llm_response["choices"][0]["message"], "refusal", None) - if refusal_reason: - return ( - True, - f"OpenAI safety system refused the request. Reason: {refusal_reason}", - ) - return False, "" - - -def check_invalid_json_response(llm_response): - response_string = llm_response["choices"][0]["message"]["content"] - if not is_valid_json(response_string): - return True, "Invalid JSON response." - else: - if not parse_llm_output(response_string): - return True, "Invalid JSON response." - return False, "" - - -validation_checks = [ - check_length_issue, - check_content_filter, - check_refusal, - check_invalid_json_response, -] - - -def validate_llm_response(llm_response): - for check in validation_checks: - logger.debug(llm_response) - is_invalid, message = check(llm_response[1]) - if is_invalid: - return True, message - return False, "" - - -def get_llm_response(prepared_context): - llm_requester, *llm_requester_args = Model.get_active_llm_model() - - retries, max_retries = 0, 5 - redo, redo_reason = True, None - requester = Requester(llm_requester, *llm_requester_args) - - while redo and retries < max_retries: - try: - llm_response = requester.completion_without_proxy(prepared_context) - logger.debug(f"LLM Response: {llm_response}") - redo, redo_reason = validate_llm_response(llm_response) - except Exception: - redo, redo_reason = True, "Error calling LLM." 
- logger.exception(redo_reason) - - if redo: - retries = retries + 1 - logger.info(f"Redoing request due to {redo_reason}") - - if retries == max_retries: - logger.info("Max retries reached.") - return False, None - return True, llm_response - - -@time_and_log_function -def call_llm_with_context(repository_path, issue_summary): - if not issue_summary: - logger.error("issue_summary cannot be empty.") - raise ValueError("issue_summary cannot be empty.") - - embedder_class, *embeder_args = Model.get_active_embedding_model() - embedder = Embedder(embedder_class, *embeder_args) - - vectorizer_class = VectorizerModel.get_active_vectorizer() - Vectorizer(vectorizer_class, embedder).vectorize_to_database(None, repository_path) - - # find_similar_embeddings narrows down codebase to files that matter for the issue at hand. - context = embedder.retrieve_embeddings(issue_summary, repository_path) - - prompt = get_prompt(issue_summary) - prepared_context = prepare_context(context, prompt) - - logger.debug(f"Issue Summary: {issue_summary} - LLM Model: {embeder_args[0]}") - - return get_llm_response(prepared_context) diff --git a/labs/litellm_service/__init__.py b/labs/llm/__init__.py similarity index 100% rename from labs/litellm_service/__init__.py rename to labs/llm/__init__.py diff --git a/labs/litellm_service/ollama.py b/labs/llm/ollama.py similarity index 100% rename from labs/litellm_service/ollama.py rename to labs/llm/ollama.py diff --git a/labs/litellm_service/openai.py b/labs/llm/openai.py similarity index 100% rename from labs/litellm_service/openai.py rename to labs/llm/openai.py diff --git a/labs/litellm_service/llm_requester.py b/labs/llm/requester.py similarity index 100% rename from labs/litellm_service/llm_requester.py rename to labs/llm/requester.py diff --git a/labs/logger.py b/labs/logger.py index 63dbf16..476aabe 100644 --- a/labs/logger.py +++ b/labs/logger.py @@ -44,5 +44,6 @@ def setup_logger(): ) handler.setFormatter(formatter) logger.addHandler(handler) + except Exception: pass diff --git a/labs/parsers/response_parser.py b/labs/parsers/response.py similarity index 87% rename from labs/parsers/response_parser.py rename to labs/parsers/response.py index 65d2b18..b1656d7 100644 --- a/labs/parsers/response_parser.py +++ b/labs/parsers/response.py @@ -12,12 +12,12 @@ class Step(BaseModel): content: str -class PullRequest(BaseModel): +class Response(BaseModel): steps: list[Step] -def parse_llm_output(text_output): - return PullRequest.model_validate_json(text_output) +def parse_llm_output(text_output) -> Response: + return Response.model_validate_json(text_output) def create_file(path, content): @@ -47,6 +47,6 @@ def modify_file(path, content): def is_valid_json(text): try: json.loads(text) - return True except ValueError: return False + return True diff --git a/labs/repo.py b/labs/repo.py deleted file mode 100644 index 54dcddf..0000000 --- a/labs/repo.py +++ /dev/null @@ -1,90 +0,0 @@ -import logging -import subprocess - -from decorators import time_and_log_function -from github.github import GithubRequests -from parsers.response_parser import create_file, modify_file, parse_llm_output - -logger = logging.getLogger(__name__) - - -def clone_repository(repository_url, local_path): - logger.debug(f"Cloning repository from {repository_url}") - subprocess.run(["git", "clone", repository_url, local_path]) - - -@time_and_log_function -def get_issue(token, repository_owner, repository_name, username, issue_number): - github_request = GithubRequests( - github_token=token, - 
repository_owner=repository_owner, - repository_name=repository_name, - username=username, - ) - return github_request.get_issue(issue_number) - - -@time_and_log_function -def create_branch( - token, - repository_owner, - repository_name, - username, - issue_number, - issue_title, - original_branch="main", -): - github_request = GithubRequests( - github_token=token, - repository_owner=repository_owner, - repository_name=repository_name, - username=username, - ) - branch_name = f"{issue_number}-{issue_title}" - github_request.create_branch(branch_name=branch_name, original_branch=original_branch) - return branch_name - - -@time_and_log_function -def change_issue_to_in_progress(): - pass - - -@time_and_log_function -def commit_changes(token, repository_owner, repository_name, username, branch_name, file_list, message="Fix"): - github_request = GithubRequests( - github_token=token, - repository_owner=repository_owner, - repository_name=repository_name, - username=username, - ) - return github_request.commit_changes(message, branch_name, file_list) - - -@time_and_log_function -def create_pull_request(token, repository_owner, repository_name, username, original_branch, branch_name): - github_request = GithubRequests( - github_token=token, - repository_owner=repository_owner, - repository_name=repository_name, - username=username, - ) - return github_request.create_pull_request(branch_name, base=original_branch) - - -@time_and_log_function -def change_issue_to_in_review(): - pass - - -@time_and_log_function -def call_agent_to_apply_code_changes(llm_response): - pull_request = parse_llm_output(llm_response) - - files = [] - for step in pull_request.steps: - if step.type == "create": - files.append(create_file(step.path, step.content)) - elif step.type == "modify": - files.append(modify_file(step.path, step.content)) - return files diff --git a/labs/run.py b/labs/run.py deleted file mode 100644 index 538998d..0000000 --- a/labs/run.py +++ /dev/null @@ -1,59 +0,0 @@ -import logging - -import config.configuration_variables as settings -from decorators import time_and_log_function -from llm import call_llm_with_context -from repo import ( - call_agent_to_apply_code_changes, - clone_repository, - commit_changes, - create_branch, - create_pull_request, - get_issue, -) - -logger = logging.getLogger(__name__) - - -@time_and_log_function -def run_on_repository(token, repository_owner, repository_name, username, issue_number, original_branch="main"): - issue = get_issue(token, repository_owner, repository_name, username, issue_number) - issue_title = issue["title"].replace(" ", "-") - issue_summary = issue["body"] - - branch_name = create_branch( - token, - repository_owner, - repository_name, - username, - issue_number, - issue_title, - original_branch, - ) - - repository_url = f"https://github.com/{repository_owner}/{repository_name}" - logger.debug(f"Cloning repository from {repository_url}") - - repository_path = f"{settings.CLONE_DESTINATION_DIR}{repository_owner}/{repository_name}" - clone_repository(repository_url, repository_path) - - success, llm_response = call_llm_with_context(repository_path, issue_summary) - if not success: - logger.error("Failed to get a response from LLM, aborting run.") - return - - response_output = call_agent_to_apply_code_changes(llm_response[1].choices[0].message.content) - - commit_changes(token, repository_owner, repository_name, username, branch_name, response_output) - create_pull_request(token, repository_owner, repository_name, username, branch_name) - - 
-@time_and_log_function -def run_on_local_repo(repository_path, issue_text): - success, llm_response = call_llm_with_context(repository_path, issue_text) - if not success: - logger.error("Failed to get a response from LLM, aborting run.") - return - - response_output = call_agent_to_apply_code_changes(llm_response[1].choices[0].message.content) - return True, response_output diff --git a/labs/tasks/checks.py b/labs/tasks/checks.py new file mode 100644 index 0000000..fc0c13b --- /dev/null +++ b/labs/tasks/checks.py @@ -0,0 +1,66 @@ +import logging + +from parsers.response import is_valid_json, parse_llm_output +from pydantic import ValidationError as PydanticValidationError + +logger = logging.getLogger(__name__) + + +class ValidationError(ValueError): + pass + + +def check_length(llm_response): + finish_reason = getattr(llm_response["choices"][0]["message"], "finish_reason", None) + if finish_reason == "length": + raise ValidationError("Conversation was too long for the context window, resulting in incomplete JSON.") + + +def check_content_filter(llm_response): + finish_reason = getattr(llm_response["choices"][0]["message"], "finish_reason", None) + if finish_reason == "content_filter": + raise ValidationError( + "Model's output included restricted content. Generation of JSON was halted and may be partial." + ) + + +def check_refusal(llm_response): + refusal_reason = getattr(llm_response["choices"][0]["message"], "refusal", None) + if refusal_reason: + raise ValidationError(f"OpenAI safety system refused the request. Reason: {refusal_reason}") + + +def check_invalid_json(llm_response): + response_content = llm_response["choices"][0]["message"]["content"] + if not is_valid_json(response_content): + raise ValidationError("Malformed JSON LLM response.") + + try: + parse_llm_output(response_content) + + except PydanticValidationError: + raise ValidationError("JSON response from LLM does not match the expected format.") + + +check_list = [ + check_length, + check_content_filter, + check_refusal, + check_invalid_json, +] + + +def run_response_checks(llm_response): + for check in check_list: + logger.debug(f"Running LLM response check {check.__name__}") + + try: + check(llm_response[1]) + + except ValidationError as validation_error: + return True, str(validation_error) + + except Exception as error: + return True, str(error) + + return False, "" diff --git a/labs/tasks/llm.py b/labs/tasks/llm.py index d920d29..d075e71 100644 --- a/labs/tasks/llm.py +++ b/labs/tasks/llm.py @@ -1,19 +1,77 @@ import json +import logging import config.configuration_variables as settings -import redis from config.celery import app from core.models import Model, VectorizerModel from embeddings.embedder import Embedder -from embeddings.vectorizers.base import Vectorizer -from llm import get_llm_response, get_prompt, prepare_context +from embeddings.vectorizers.vectorizer import Vectorizer +from llm.requester import Requester +from tasks.checks import run_response_checks +from tasks.redis_client import RedisStrictClient, RedisVariable -redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=0, decode_responses=True) +logger = logging.getLogger(__name__) +redis_client = RedisStrictClient(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=0, decode_responses=True) + + +def get_prompt(issue_summary): + return f""" + You're a diligent software engineer AI. You can't see, draw, or interact with a + browser, but you can read and write files, and you can think. 
+ You've been given the following task: {issue_summary}. + Any imports will be at the beginning of the file. + Add tests for the new functionalities, considering any existing test files. + The file paths provided are **absolute paths relative to the project root**, + and **must not be changed**. Ensure the paths you output match the paths provided exactly. + Do not prepend or modify the paths. + Please provide a json response in the following format: {{"steps": [...]}} + Where steps is a list of objects where each object contains three fields: + type, which is either 'create' to add a new file or 'modify' to edit an existing one; + If the file is to be modified send the finished version of the entire file. + path, which is the absolute path of the file to create/modify; + content, which is the content to write to the file. + """ + + +def get_context(embeddings, prompt): + context = [] + for file in embeddings: + context.append(dict(role="system", content=f"File: {file[1]}, Content: {file[2]}")) + + context.append(dict(role="user", content=prompt)) + return context + + +def get_llm_response(prompt): + llm_requester, *llm_requester_args = Model.get_active_llm_model() + requester = Requester(llm_requester, *llm_requester_args) + + retries, max_retries = 0, 5 + is_invalid, reason = True, None + + llm_response = None + while is_invalid and retries < max_retries: + try: + llm_response = requester.completion_without_proxy(prompt) + logger.debug(f"LLM response: {llm_response}") + + is_invalid, reason = run_response_checks(llm_response) + + except Exception as e: + is_invalid, reason = True, str(e) + logger.error("Invalid LLM response:", exc_info=e) + + if is_invalid: + retries += 1 + llm_response = None + logger.info(f"Redoing LLM response request due to error (retries: {retries} of {max_retries}): {reason}") + + return True, llm_response @app.task def vectorize_repository_task(prefix="", repository_path=""): - repository_path = redis_client.get(f"{prefix}_repository_path") if prefix else repository_path + repository_path = redis_client.get(RedisVariable.REPOSITORY_PATH, prefix=prefix, default=repository_path) embedder_class, *embeder_args = Model.get_active_embedding_model() embedder = Embedder(embedder_class, *embeder_args) @@ -30,42 +88,43 @@ def find_embeddings_task(prefix="", issue_body="", repository_path=""): embedder_class, *embeder_args = Model.get_active_embedding_model() embeddings_results = Embedder(embedder_class, *embeder_args).retrieve_embeddings( - redis_client.get(f"{prefix}_issue_body") if prefix else issue_body, - redis_client.get(f"f{prefix}_repository_path") if prefix else repository_path, + redis_client.get(RedisVariable.ISSUE_BODY, prefix=prefix, default=issue_body), + redis_client.get(RedisVariable.REPOSITORY_PATH, prefix=prefix, default=repository_path), ) similar_embeddings = [ (embedding.repository, embedding.file_path, embedding.text) for embedding in embeddings_results ] if prefix: - redis_client.set(f"{prefix}_similar_embeddings", json.dumps(similar_embeddings)) + redis_client.set(RedisVariable.EMBEDDINGS, prefix=prefix, value=json.dumps(similar_embeddings)) return prefix return similar_embeddings @app.task def prepare_prompt_and_context_task(prefix="", issue_body="", embeddings=[]): - prompt = get_prompt(redis_client.get(f"{prefix}_issue_body") if prefix else issue_body) - redis_client.set(f"{prefix}_prompt", prompt) + prompt = get_prompt(redis_client.get(RedisVariable.ISSUE_BODY, prefix=prefix, 
default=issue_body)) + redis_client.set(RedisVariable.PROMPT, prefix=prefix, value=prompt) - embeddings = json.loads(redis_client.get(f"{prefix}_similar_embeddings")) if prefix else embeddings - prepared_context = prepare_context(embeddings, prompt) + embeddings = json.loads(redis_client.get(RedisVariable.EMBEDDINGS, prefix=prefix, default=embeddings)) + prepared_context = get_context(embeddings, prompt) if prefix: - redis_client.set(f"{prefix}_prepared_context", json.dumps(prepared_context)) + redis_client.set(RedisVariable.CONTEXT, prefix=prefix, value=json.dumps(prepared_context)) return prefix return prepared_context @app.task def get_llm_response_task(prefix="", context={}): - context = json.loads(redis_client.get(f"{prefix}_prepared_context")) if prefix else context + context = json.loads(redis_client.get(RedisVariable.CONTEXT, prefix=prefix, default=context)) llm_response = get_llm_response(context) if prefix: redis_client.set( - f"{prefix}_llm_response", - llm_response[1][1]["choices"][0]["message"]["content"], + RedisVariable.LLM_RESPONSE, + prefix=prefix, + value=llm_response[1][1]["choices"][0]["message"]["content"], ) return prefix return llm_response diff --git a/labs/tasks/redis_client.py b/labs/tasks/redis_client.py new file mode 100644 index 0000000..176b60f --- /dev/null +++ b/labs/tasks/redis_client.py @@ -0,0 +1,56 @@ +from enum import Enum +from typing import Union + +from redis import StrictRedis +from redis.typing import EncodableT, ResponseT + + +class RedisVariable(Enum): + BRANCH_NAME = "branch_name" + CONTEXT = "context" + EMBEDDINGS = "embeddings" + FILES_MODIFIED = "files_modified" + ISSUE_BODY = "issue_body" + ISSUE_NUMBER = "issue_number" + ISSUE_TITLE = "issue_title" + LLM_RESPONSE = "llm_response" + ORIGINAL_BRANCH_NAME = "original_branch_name" + PROMPT = "prompt" + REPOSITORY_NAME = "repository_name" + REPOSITORY_OWNER = "repository_owner" + REPOSITORY_PATH = "repository_path" + TOKEN = "token" + USERNAME = "username" + + +class RedisStrictClient(StrictRedis): + def get( + self, variable: Union[str, RedisVariable], prefix: str = None, default: str | list | dict = None + ) -> ResponseT: + name = variable + if isinstance(variable, RedisVariable): + name = variable.value + + if prefix: + name = f"{prefix}_{name}" + + if self.exists(name): + return super().get(name) + return default + + def set( + self, + variable: Union[str, RedisVariable], + value: EncodableT, + prefix: str = None, + *args, + **kwargs, + ) -> ResponseT: + name = variable + if isinstance(variable, RedisVariable): + name = variable.value + + if prefix: + name = f"{prefix}_{name}" + + return super().set(name, value, *args, **kwargs) diff --git a/labs/tasks/repository.py b/labs/tasks/repository.py index 496748a..78fc250 100644 --- a/labs/tasks/repository.py +++ b/labs/tasks/repository.py @@ -1,32 +1,59 @@ import json -import logging import config.configuration_variables as settings -import redis from config.celery import app -from repo import call_agent_to_apply_code_changes, clone_repository -from run import commit_changes, create_branch, create_pull_request, get_issue +from decorators import time_and_log_function +from github.github import GithubRequests +from parsers.response import create_file, modify_file, parse_llm_output +from tasks.redis_client import RedisStrictClient, RedisVariable -logger = logging.getLogger(__name__) +redis_client = RedisStrictClient(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=0, decode_responses=True) -redis_client = 
redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=0, decode_responses=True) + +def github_repository_data(prefix, token="", repository_owner="", repository_name="", username="") -> dict: + return dict( + token=redis_client.get(RedisVariable.TOKEN, prefix, default=token), + repository_owner=redis_client.get(RedisVariable.REPOSITORY_OWNER, prefix, default=repository_owner), + repository_name=redis_client.get(RedisVariable.REPOSITORY_NAME, prefix, default=repository_name), + username=redis_client.get(RedisVariable.USERNAME, prefix, default=username), + ) + + +@time_and_log_function +def apply_code_changes(llm_response): + response = parse_llm_output(llm_response) + + files = [] + for step in response.steps: + if step.type == "create": + files.append(create_file(step.path, step.content)) + elif step.type == "modify": + files.append(modify_file(step.path, step.content)) + + return files @app.task def get_issue_task(prefix="", token="", repository_owner="", repository_name="", issue_number="", username=""): - token = redis_client.get(f"{prefix}_token") if prefix else token - repository_owner = redis_client.get(f"{prefix}_repository_owner") if prefix else repository_owner - repository_name = redis_client.get(f"{prefix}_repository_name") if prefix else repository_name - username = redis_client.get(f"{prefix}_username") if prefix else username - issue_number = redis_client.get(f"{prefix}_issue_number") if prefix else issue_number + repository = github_repository_data(prefix, token, repository_owner, repository_name, username) + issue_number = redis_client.get(RedisVariable.ISSUE_NUMBER, prefix, default=issue_number) - issue = get_issue(token, repository_owner, repository_name, username, issue_number) + github_request = GithubRequests(**repository) + issue = github_request.get_issue(issue_number) if prefix: - issue_title = issue["title"].replace(" ", "-") - issue_body = issue["body"] - redis_client.set(f"{prefix}_issue_title", issue_title, ex=300) - redis_client.set(f"{prefix}_issue_body", issue_body, ex=300) + redis_client.set( + RedisVariable.ISSUE_TITLE, + prefix=prefix, + value=issue["title"].replace(" ", "-"), + ex=300, + ) + redis_client.set( + RedisVariable.ISSUE_BODY, + prefix=prefix, + value=issue["body"], + ex=300, + ) return prefix return issue @@ -42,52 +69,44 @@ def create_branch_task( original_branch="", issue_title="", ): - token = redis_client.get(f"{prefix}_token") if prefix else token - repository_owner = redis_client.get(f"{prefix}_repository_owner") if prefix else repository_owner - repository_name = redis_client.get(f"{prefix}_repository_name") if prefix else repository_name - username = redis_client.get(f"{prefix}_username") if prefix else username - issue_number = redis_client.get(f"{prefix}_issue_number") if prefix else issue_number - original_branch = redis_client.get(f"{prefix}_original_branch") if prefix else original_branch - issue_title = redis_client.get("issue_title") if prefix else issue_title - - branch_name = create_branch( - token, - repository_owner, - repository_name, - username, - issue_number, - issue_title, - original_branch, - ) + repository = github_repository_data(prefix, token, repository_owner, repository_name, username) + issue_number = redis_client.get(RedisVariable.ISSUE_NUMBER, prefix, default=issue_number) + original_branch = redis_client.get(RedisVariable.ORIGINAL_BRANCH_NAME, prefix, default=original_branch) + issue_title = redis_client.get(RedisVariable.ISSUE_TITLE, prefix, default=issue_title) + + branch_name = 
f"{issue_number}-{issue_title}" + + github_request = GithubRequests(**repository) + github_request.create_branch(branch_name=branch_name, original_branch=original_branch) if prefix: - redis_client.set(f"{prefix}_branch_name", branch_name, ex=300) + redis_client.set(RedisVariable.BRANCH_NAME, prefix=prefix, value=branch_name, ex=300) return prefix return branch_name @app.task def clone_repository_task(prefix="", repository_owner="", repository_name=""): - repository_owner = redis_client.get(f"{prefix}_repository_owner") if prefix else repository_owner - repository_name = redis_client.get(f"{prefix}_repository_name") if prefix else repository_name - repository_path = f"{settings.CLONE_DESTINATION_DIR}{repository_owner}/{repository_name}" - clone_repository(f"https://github.com/{repository_owner}/{repository_name}", repository_path) + repository = github_repository_data(prefix, repository_owner=repository_owner, repository_name=repository_name) + + github_request = GithubRequests(**repository) + repository_path = github_request.clone() if prefix: - redis_client.set(f"{prefix}_repository_path", repository_path, ex=300) + redis_client.set(RedisVariable.REPOSITORY_PATH, prefix=prefix, value=repository_path, ex=300) return prefix return True @app.task def apply_code_changes_task(prefix="", llm_response=""): - llm_response = redis_client.get(f"{prefix}_llm_response") if prefix else llm_response - files_modified = call_agent_to_apply_code_changes(llm_response) + llm_response = redis_client.get(RedisVariable.LLM_RESPONSE, prefix, default=llm_response) + modified_files = apply_code_changes(llm_response) if prefix: - redis_client.set(f"{prefix}_files_modified", json.dumps(files_modified)) + redis_client.set(RedisVariable.FILES_MODIFIED, prefix=prefix, value=json.dumps(modified_files)) return prefix - return files_modified + return modified_files @app.task @@ -98,15 +117,17 @@ def commit_changes_task( repository_name="", username="", branch_name="", - files_modified=[], + files_modified=None, ): - commit_changes( - token=redis_client.get(f"{prefix}_token") if prefix else token, - repository_owner=redis_client.get(f"{prefix}_repository_owner") if prefix else repository_owner, - repository_name=redis_client.get(f"{prefix}_repository_name") if prefix else repository_name, - username=redis_client.get(f"{prefix}_username") if prefix else username, - branch_name=(redis_client.get(f"{prefix}_branch_name") if prefix else branch_name), - file_list=(json.loads(redis_client.get(f"{prefix}_files_modified")) if prefix else files_modified), + if not files_modified: + files_modified = [] + + repository = github_repository_data(prefix, token, repository_owner, repository_name, username) + github_request = GithubRequests(**repository) + github_request.commit_changes( + message="Fix", + branch_name=redis_client.get(RedisVariable.BRANCH_NAME, prefix, default=branch_name), + files=json.loads(redis_client.get(RedisVariable.FILES_MODIFIED, prefix, default=files_modified)), ) if prefix: @@ -124,13 +145,11 @@ def create_pull_request_task( branch_name="", original_branch="", ): - create_pull_request( - token=redis_client.get(f"{prefix}_token") if prefix else token, - repository_owner=redis_client.get(f"{prefix}_repository_owner") if prefix else repository_owner, - repository_name=redis_client.get(f"{prefix}_repository_name") if prefix else repository_name, - username=redis_client.get(f"{prefix}_username") if prefix else username, - original_branch=(redis_client.get(f"{prefix}_original_branch") if prefix else 
original_branch), - branch_name=(redis_client.get(f"{prefix}_branch_name") if prefix else branch_name), + repository = github_repository_data(prefix, token, repository_owner, repository_name, username) + github_request = GithubRequests(**repository) + github_request.create_pull_request( + head=redis_client.get(RedisVariable.BRANCH_NAME, prefix, default=branch_name), + base=redis_client.get(RedisVariable.ORIGINAL_BRANCH_NAME, prefix, default=original_branch), ) if prefix: diff --git a/labs/tasks/run.py b/labs/tasks/run.py index f48b53f..d72ba71 100644 --- a/labs/tasks/run.py +++ b/labs/tasks/run.py @@ -1,8 +1,6 @@ -import logging import os.path import config.configuration_variables as settings -import redis from celery import chain from config.celery import app from tasks import ( @@ -17,20 +15,20 @@ prepare_prompt_and_context_task, vectorize_repository_task, ) +from tasks.redis_client import RedisStrictClient, RedisVariable -logger = logging.getLogger(__name__) - -redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=0, decode_responses=True) +redis_client = RedisStrictClient(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=0, decode_responses=True) @app.task(bind=True) def init_task(self, **kwargs): - if "repository_path" in kwargs: - if not os.path.exists(kwargs["repository_path"]): - raise FileNotFoundError(f"Directory {kwargs['repository_path']} does not exist") + path = RedisVariable.REPOSITORY_PATH.value + if path in kwargs and not os.path.exists(kwargs[path]): + raise FileNotFoundError(f"Directory {kwargs[path]} does not exist") + prefix = self.request.id for k, v in kwargs.items(): - redis_client.set(f"{prefix}_{k}", v, ex=3600) + redis_client.set(k, v, prefix=prefix, ex=3600) return prefix @@ -44,12 +42,12 @@ def run_on_repository_task( original_branch: str = "main", ): data = { - "token": token, - "repository_owner": repository_owner, - "repository_name": repository_name, - "username": username, - "issue_number": issue_number, - "original_branch": original_branch, + RedisVariable.TOKEN.value: token, + RedisVariable.REPOSITORY_OWNER.value: repository_owner, + RedisVariable.REPOSITORY_NAME.value: repository_name, + RedisVariable.USERNAME.value: username, + RedisVariable.ISSUE_NUMBER.value: issue_number, + RedisVariable.ORIGINAL_BRANCH_NAME.value: original_branch, } chain( init_task.s(**data), @@ -67,11 +65,10 @@ def run_on_repository_task( @app.task -def run_on_local_repository_task(repository_path, issue_text): +def run_on_local_repository_task(repository_path, issue_body): data = { - "issue_text": issue_text, - "issue_body": issue_text, - "repository_path": repository_path, + RedisVariable.ISSUE_BODY.value: issue_body, + RedisVariable.REPOSITORY_PATH.value: repository_path, } chain( init_task.s(**data), diff --git a/labs/tests/test_github_requests.py b/labs/tests/test_github_requests.py index 94cfdbc..c8aaed3 100644 --- a/labs/tests/test_github_requests.py +++ b/labs/tests/test_github_requests.py @@ -16,11 +16,11 @@ def test_list_issues_default_parameters(self, mocker): mock_get.return_value.status_code = 200 mock_get.return_value.json.return_value = sample_response - github_token = "valid_token" + token = "valid_token" repository_owner = "owner_username" repository_name = "repository_name" username = "your_username" - github_requests = GithubRequests(github_token, repository_owner, repository_name, username) + github_requests = GithubRequests(token, repository_owner, repository_name, username) issues = github_requests.list_issues() @@ 
-28,7 +28,7 @@ def test_list_issues_default_parameters(self, mocker): mock_get.assert_called_once_with( f"https://api.github.com/repos/{repository_owner}/{repository_name}/issues", headers={ - "Authorization": f"token {github_token}", + "Authorization": f"token {token}", "Accept": "application/vnd.github.v3+json", }, params={ @@ -43,7 +43,7 @@ def test_list_issues_http_failure(self, mocker): mock_response.raise_for_status.side_effect = requests.exceptions.RequestException("HTTP Error") mocker.patch("requests.get", return_value=mock_response) - github_requests = GithubRequests(github_token="fake_token", repository_owner="owner", repository_name="repo") + github_requests = GithubRequests(token="fake_token", repository_owner="owner", repository_name="repo") issues = github_requests.list_issues() @@ -56,11 +56,11 @@ def test_get_issue_returns_correct_details(self, mocker): mock_get.return_value.status_code = 200 mock_get.return_value.json.return_value = sample_response - github_token = "valid_token" + token = "valid_token" repository_owner = "owner_username" repository_name = "repository_name" username = "your_username" - github_requests = GithubRequests(github_token, repository_owner, repository_name, username) + github_requests = GithubRequests(token, repository_owner, repository_name, username) issue = github_requests.get_issue(1) @@ -71,7 +71,7 @@ def test_handle_http_request_failure_get_issue(self, mocker): mock_response.raise_for_status.side_effect = requests.exceptions.RequestException("Mocked Request Exception") mocker.patch("requests.get", return_value=mock_response) - github_requests = GithubRequests(github_token="fake_token", repository_owner="owner", repository_name="repo") + github_requests = GithubRequests(token="fake_token", repository_owner="owner", repository_name="repo") issue = github_requests.get_issue(1) @@ -82,7 +82,7 @@ def test_change_issue_status(self, mocker): mock_response.json.return_value = {"status": "closed"} mocker.patch("requests.patch", return_value=mock_response) - github_requests = GithubRequests(github_token="fake_token", repository_owner="owner", repository_name="repo") + github_requests = GithubRequests(token="fake_token", repository_owner="owner", repository_name="repo") response = github_requests.change_issue_status(issue_number=1, status="closed") @@ -120,7 +120,7 @@ def test_commit_changes_successfully(self, mocker): mock_response_patch.json.return_value = {"sha": "fake_update_sha"} mocker.patch("requests.patch", return_value=mock_response_patch) - github_requests = GithubRequests(github_token="fake_token", repository_owner="owner", repository_name="repo") + github_requests = GithubRequests(token="fake_token", repository_owner="owner", repository_name="repo") result = github_requests.commit_changes( message="Commit message", @@ -137,7 +137,7 @@ def test_create_pull_request_default_parameters(self, mocker): mock_response.raise_for_status.return_value = None mocker.patch("requests.post", return_value=mock_response) - github_requests = GithubRequests(github_token="fake_token", repository_owner="owner", repository_name="repo") + github_requests = GithubRequests(token="fake_token", repository_owner="owner", repository_name="repo") pull_request = github_requests.create_pull_request(head="feature_branch") diff --git a/labs/tests/test_llm.py b/labs/tests/test_llm.py index 6f0f6f2..0875410 100644 --- a/labs/tests/test_llm.py +++ b/labs/tests/test_llm.py @@ -1,13 +1,16 @@ -from unittest import skip +from unittest import TestCase, skip from unittest.mock import 
patch import pytest -from core.models import Model +from core.models import Model, VectorizerModel +from embeddings.embedder import Embedder from embeddings.ollama import OllamaEmbedder from embeddings.openai import OpenAIEmbedder -from litellm_service.ollama import OllamaRequester -from litellm_service.openai import OpenAIRequester -from llm import call_llm_with_context, check_invalid_json_response +from embeddings.vectorizers.vectorizer import Vectorizer +from llm.ollama import OllamaRequester +from llm.openai import OpenAIRequester +from tasks.checks import ValidationError, check_invalid_json +from tasks.llm import get_context, get_llm_response, get_prompt from tests.constants import ( OLLAMA_EMBEDDING_MODEL_NAME, OLLAMA_LLM_MODEL_NAME, @@ -16,6 +19,25 @@ ) +def call_llm_with_context(repository_path, issue_summary): + if not issue_summary: + raise ValueError("issue_summary cannot be empty.") + + embedder_class, *embeder_args = Model.get_active_embedding_model() + embedder = Embedder(embedder_class, *embeder_args) + + vectorizer_class = VectorizerModel.get_active_vectorizer() + Vectorizer(vectorizer_class, embedder).vectorize_to_database(None, repository_path) + + # find_similar_embeddings narrows down codebase to files that matter for the issue at hand. + context = embedder.retrieve_embeddings(issue_summary, repository_path) + + prompt = get_prompt(issue_summary) + prepared_context = get_context(context, prompt) + + return get_llm_response(prepared_context) + + class TestCallLLMWithContext: def test_empty_summary(self): repository_path = "repository_path" @@ -27,7 +49,7 @@ def test_empty_summary(self): assert "issue_summary cannot be empty" in str(excinfo.value) -class TestCheckInvalidJsonResponse: +class TestCheckInvalidJsonResponse(TestCase): def test_valid_json_response(self): llm_response = { "choices": [ @@ -38,9 +60,8 @@ def test_valid_json_response(self): } ] } - is_invalid, message = check_invalid_json_response(llm_response) - assert not is_invalid - assert message == "" + + check_invalid_json(llm_response) def test_invalid_json_response(self): llm_response = { @@ -52,15 +73,15 @@ def test_invalid_json_response(self): } ] } - is_invalid, message = check_invalid_json_response(llm_response) - assert is_invalid - assert message == "Invalid JSON response." + + with self.assertRaises(ValidationError, msg="Malformed JSON LLM response."): + check_invalid_json(llm_response) def test_invalid_json_structure(self): - llm_response = {"choices": [{"message": {"content": '{"invalid_key": invalid_value"}'}}]} - is_invalid, message = check_invalid_json_response(llm_response) - assert is_invalid - assert message == "Invalid JSON response." 
+ llm_response = {"choices": [{"message": {"content": '{"invalid_key": "invalid_value"}'}}]} + + with self.assertRaises(ValidationError, msg="JSON response from LLM not match the expected format."): + check_invalid_json(llm_response) class TestLocalLLM: @@ -75,28 +96,28 @@ def test_local_llm_connection(self, mocked_context, mocked_vectorize_to_database assert success - @patch("llm.validate_llm_response") - @patch("embeddings.vectorizers.base.Vectorizer.vectorize_to_database") - @patch("litellm_service.ollama.OllamaRequester.completion_without_proxy") + @patch("tasks.llm.run_response_checks") + @patch("embeddings.vectorizers.vectorizer.Vectorizer.vectorize_to_database") + @patch("llm.ollama.OllamaRequester.completion_without_proxy") @patch("embeddings.embedder.Embedder.retrieve_embeddings") @pytest.mark.django_db def test_local_llm_redirect( self, - mocked_context, - mocked_local_llm, + mocked_retrieve_embeddings, + mocked_completion_without_proxy, mocked_vectorize_to_database, - mocked_validate_llm_reponse, + mocked_run_response_checks, create_test_ollama_llm_config, create_test_ollama_embedding_config, create_test_chunk_vectorizer_config, ): - mocked_context.return_value = [["file1", "/path/to/file1", "content"]] - mocked_validate_llm_reponse.return_value = False, "" - repository_destination = "repo" + mocked_retrieve_embeddings.return_value = [["file1", "/path/to/file1", "content"]] + mocked_run_response_checks.return_value = False, "" + repository_path = "repo" issue_summary = "Fix the bug in the authentication module" - call_llm_with_context(repository_destination, issue_summary) + call_llm_with_context(repository_path, issue_summary) - mocked_local_llm.assert_called_once() + mocked_completion_without_proxy.assert_called_once() class TestLLMRequester: diff --git a/notebooks/embeddings_exploring.ipynb b/notebooks/embeddings_exploring.ipynb index 1d60995..b3e5300 100644 --- a/notebooks/embeddings_exploring.ipynb +++ b/notebooks/embeddings_exploring.ipynb @@ -253,9 +253,13 @@ }, { "cell_type": "code", - "execution_count": null, "id": "296ba8adc2da03d5", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-02T14:43:29.658280Z", + "start_time": "2024-12-02T14:43:29.648171Z" + } + }, "source": [ "import os\n", "\n", @@ -272,6 +276,7 @@ " parser.visit(tree)\n", " return parser.get_structure()" ], + "execution_count": 1, "outputs": [] }, { diff --git a/notebooks/qwen+nomic.ipynb b/notebooks/qwen+nomic.ipynb index 8a2dd67..341ff43 100644 --- a/notebooks/qwen+nomic.ipynb +++ b/notebooks/qwen+nomic.ipynb @@ -9,6 +9,7 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, "source": [ "import os\n", @@ -16,8 +17,7 @@ "\n", "sys.path.append(os.path.abspath(\"../labs\"))" ], - "outputs": [], - "execution_count": null + "outputs": [] }, { "cell_type": "markdown", @@ -28,6 +28,7 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, "source": [ "from django.core.wsgi import get_wsgi_application\n", @@ -40,11 +41,11 @@ "\n", "application = get_wsgi_application()" ], - "outputs": [], - "execution_count": null + "outputs": [] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, "source": [ "EMBEDDER_MODEL_NAME = \"nomic-embed-text:latest\"\n", @@ -53,8 +54,7 @@ "REPO = \"REPLACE THIS WITH REPO PATH\"\n", "ISSUE = \"Add created_at and updated_at field to User model.\"" ], - "outputs": [], - "execution_count": null + "outputs": [] }, { "cell_type": "markdown", @@ -65,6 +65,7 @@ }, { "cell_type": "code", + "execution_count": null, 
"metadata": {}, "source": [ "from embeddings.embedder import Embedder\n", @@ -72,8 +73,7 @@ "\n", "embedder = Embedder(OllamaEmbedder, EMBEDDER_MODEL_NAME)" ], - "outputs": [], - "execution_count": null + "outputs": [] }, { "cell_type": "markdown", @@ -84,15 +84,15 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, "source": [ - "from embeddings.vectorizers.base import Vectorizer\n", "from embeddings.vectorizers.chunk_vectorizer import ChunkVectorizer\n", + "from embeddings.vectorizers.vectorizer import Vectorizer\n", "\n", - "Vectorizer(ChunkVectorizer, embedder).vectorize_to_database(None, repo_destination=REPO)" + "Vectorizer(ChunkVectorizer, embedder).vectorize_to_database(None, repository_path=REPO)" ], - "outputs": [], - "execution_count": null + "outputs": [] }, { "cell_type": "markdown", @@ -103,23 +103,23 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, "source": [ "embeddings_results = embedder.retrieve_embeddings(ISSUE, REPO)\n", "\n", "similar_embeddings = [(embedding.repository, embedding.file_path, embedding.text) for embedding in embeddings_results]" ], - "outputs": [], - "execution_count": null + "outputs": [] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, "source": [ "similar_embeddings" ], - "outputs": [], - "execution_count": null + "outputs": [] }, { "cell_type": "markdown", @@ -130,15 +130,15 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, "source": [ - "from llm import get_prompt, prepare_context\n", + "from tasks.llm import get_context, get_prompt\n", "\n", "prompt = get_prompt(ISSUE)\n", - "prepared_context = prepare_context(similar_embeddings, prompt)" + "prepared_context = get_context(similar_embeddings, prompt)" ], - "outputs": [], - "execution_count": null + "outputs": [] }, { "cell_type": "markdown", @@ -149,21 +149,22 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, "source": [ - "from litellm_service.llm_requester import Requester\n", - "from litellm_service.ollama import OllamaRequester\n", - "from llm import validate_llm_response\n", + "from llm.ollama import OllamaRequester\n", + "from llm.requester import Requester\n", + "from tasks.checks import run_response_checks\n", "\n", "requester = Requester(OllamaRequester, model=LLM_MODEL_NAME)\n", "llm_response = requester.completion_without_proxy(prepared_context)\n", - "redo, redo_reason = validate_llm_response(llm_response)" + "redo, redo_reason = run_response_checks(llm_response)" ], - "outputs": [], - "execution_count": null + "outputs": [] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, "source": [ "import json\n", @@ -175,18 +176,9 @@ "print(REPO)\n", "print(ISSUE)\n", "\n", - "pprint(json.loads(llm_response[1][\"choices\"][0][\"message\"][\"content\"].replace(\"\\n\", \" \")))\n", - "# pprint(llm_response[1][\"choices\"][0][\"message\"][\"content\"].replace(\"\\\\n\", \" \"))" + "pprint(json.loads(llm_response[1][\"choices\"][0][\"message\"][\"content\"].replace(\"\\n\", \" \")))" ], - "outputs": [], - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "code", - "source": "", - "outputs": [], - "execution_count": null + "outputs": [] } ], "metadata": { diff --git a/notebooks/sample.ipynb b/notebooks/sample.ipynb index c1a81eb..0404d35 100644 --- a/notebooks/sample.ipynb +++ b/notebooks/sample.ipynb @@ -16,13 +16,13 @@ "start_time": "2024-11-22T12:01:40.439141Z" } }, - "outputs": [], "source": [ "import os\n", "import sys\n", "\n", 
"sys.path.append(os.path.abspath(\"../labs\"))" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -40,7 +40,6 @@ "start_time": "2024-11-22T12:01:41.930888Z" } }, - "outputs": [], "source": [ "from django.core.wsgi import get_wsgi_application\n", "\n", @@ -51,7 +50,8 @@ "os.environ.setdefault(\"DATABASE_HOST\", \"localhost\")\n", "\n", "application = get_wsgi_application()" - ] + ], + "outputs": [] }, { "cell_type": "code", @@ -62,13 +62,13 @@ "start_time": "2024-11-22T12:02:10.637010Z" } }, - "outputs": [], "source": [ "ISSUE = \"ADD YOUR ISSUE TEXT HERE\"\n", "EMBEDDER_MODEL_NAME = \"nomic-embed-text:latest\"\n", "LLM_MODEL_NAME = \"llama3.2:latest\"\n", "REPO = \".\"" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -86,13 +86,13 @@ "start_time": "2024-11-22T11:53:02.469957Z" } }, - "outputs": [], "source": [ "from embeddings.embedder import Embedder\n", "from embeddings.ollama import OllamaEmbedder\n", "\n", "embedder = Embedder(OllamaEmbedder, EMBEDDER_MODEL_NAME)" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -110,13 +110,13 @@ "start_time": "2024-11-22T11:53:03.690512Z" } }, - "outputs": [], "source": [ - "from embeddings.vectorizers.base import Vectorizer\n", "from embeddings.vectorizers.chunk_vectorizer import ChunkVectorizer\n", + "from embeddings.vectorizers.vectorizer import Vectorizer\n", "\n", - "Vectorizer(ChunkVectorizer, embedder).vectorize_to_database(None, repo_destination=REPO)" - ] + "Vectorizer(ChunkVectorizer, embedder).vectorize_to_database(None, repository_path=REPO)" + ], + "outputs": [] }, { "cell_type": "markdown", @@ -129,21 +129,21 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "embeddings_results = embedder.retrieve_embeddings(ISSUE, REPO)\n", "\n", "similar_embeddings = [(embedding.repository, embedding.file_path, embedding.text) for embedding in embeddings_results]" - ] + ], + "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "similar_embeddings" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -156,13 +156,13 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "from llm import get_prompt, prepare_context\n", + "from tasks.llm import get_context, get_prompt\n", "\n", "prompt = get_prompt(ISSUE)\n", - "prepared_context = prepare_context(similar_embeddings, prompt)" - ] + "prepared_context = get_context(similar_embeddings, prompt)" + ], + "outputs": [] }, { "cell_type": "markdown", @@ -175,27 +175,27 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "from litellm_service.llm_requester import Requester\n", - "from litellm_service.ollama import OllamaRequester\n", - "from llm import validate_llm_response\n", + "from llm.ollama import OllamaRequester\n", + "from llm.requester import Requester\n", + "from tasks.checks import run_response_checks\n", "\n", "requester = Requester(OllamaRequester, model=LLM_MODEL_NAME)\n", "llm_response = requester.completion_without_proxy(prepared_context)\n", - "redo, redo_reason = validate_llm_response(llm_response)" - ] + "redo, redo_reason = run_response_checks(llm_response)" + ], + "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "from pprint import pprint\n", "\n", "pprint(llm_response[1][\"choices\"][0][\"message\"][\"content\"].replace(\"\\\\n\", \" \"))" - ] + ], + "outputs": [] } ], "metadata": { diff --git 
a/notebooks/starcoder2+nomic.ipynb b/notebooks/starcoder2+nomic.ipynb index b3de2c0..6ac0227 100644 --- a/notebooks/starcoder2+nomic.ipynb +++ b/notebooks/starcoder2+nomic.ipynb @@ -11,13 +11,13 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "import os\n", "import sys\n", "\n", "sys.path.append(os.path.abspath(\"../labs\"))" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -30,7 +30,6 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "from django.core.wsgi import get_wsgi_application\n", "\n", @@ -41,20 +40,21 @@ "os.environ.setdefault(\"DATABASE_HOST\", \"localhost\")\n", "\n", "application = get_wsgi_application()" - ] + ], + "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "EMBEDDER_MODEL_NAME = \"nomic-embed-text:latest\"\n", "LLM_MODEL_NAME = \"starcoder2:15b-instruct\"\n", "\n", "REPO = \"REPLACE THIS WITH REPO PATH\"\n", "ISSUE = \"Add created_at and updated_at field to User model.\"" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -67,13 +67,13 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "from embeddings.embedder import Embedder\n", "from embeddings.ollama import OllamaEmbedder\n", "\n", "embedder = Embedder(OllamaEmbedder, EMBEDDER_MODEL_NAME)" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -86,13 +86,13 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "from embeddings.vectorizers.base import Vectorizer\n", "from embeddings.vectorizers.chunk_vectorizer import ChunkVectorizer\n", + "from embeddings.vectorizers.vectorizer import Vectorizer\n", "\n", - "Vectorizer(ChunkVectorizer, embedder).vectorize_to_database(None, repo_destination=REPO)" - ] + "Vectorizer(ChunkVectorizer, embedder).vectorize_to_database(None, repository_path=REPO)" + ], + "outputs": [] }, { "cell_type": "markdown", @@ -105,21 +105,21 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "embeddings_results = embedder.retrieve_embeddings(ISSUE, REPO)\n", "\n", "similar_embeddings = [(embedding.repository, embedding.file_path, embedding.text) for embedding in embeddings_results]" - ] + ], + "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "similar_embeddings" - ] + ], + "outputs": [] }, { "cell_type": "markdown", @@ -132,13 +132,13 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "from llm import get_prompt, prepare_context\n", + "from tasks.llm import get_context, get_prompt\n", "\n", "prompt = get_prompt(ISSUE)\n", - "prepared_context = prepare_context(similar_embeddings, prompt)" - ] + "prepared_context = get_context(similar_embeddings, prompt)" + ], + "outputs": [] }, { "cell_type": "markdown", @@ -151,22 +151,21 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ - "from litellm_service.llm_requester import Requester\n", - "from litellm_service.ollama import OllamaRequester\n", - "from llm import validate_llm_response\n", + "from llm.ollama import OllamaRequester\n", + "from llm.requester import Requester\n", + "from tasks.checks import run_response_checks\n", "\n", "requester = Requester(OllamaRequester, model=LLM_MODEL_NAME)\n", "llm_response = requester.completion_without_proxy(prepared_context)\n", - "redo, redo_reason = 
validate_llm_response(llm_response)" - ] + "redo, redo_reason = run_response_checks(llm_response)" + ], + "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "import json\n", "from pprint import pprint\n", @@ -177,9 +176,9 @@ "print(REPO)\n", "print(ISSUE)\n", "\n", - "pprint(json.loads(llm_response[1][\"choices\"][0][\"message\"][\"content\"].replace(\"\\n\", \" \")))\n", - "# pprint(llm_response[1][\"choices\"][0][\"message\"][\"content\"].replace(\"\\\\n\", \" \"))" - ] + "pprint(json.loads(llm_response[1][\"choices\"][0][\"message\"][\"content\"].replace(\"\\n\", \" \")))" + ], + "outputs": [] } ], "metadata": {