diff --git a/README.md b/README.md index 8e693263..2b453d46 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,37 @@ In the docker container, set the `OPENAI_KEY` env var to your [OpenAI key](https export OPENAI_KEY=xx-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx ``` -### Set up one or more tasks in SWE-bench +### (Fresh issue mode) Set up and run on new GitHub issues + +> [!NOTE] +> This section is for running AutoCodeRover on new GitHub issues. For running it on SWE-bench tasks, refer to [SWE-bench mode](#swe-bench-mode-set-up-and-run-on-swe-bench-tasks). + +If you want to use AutoCodeRover for new GitHub issues in a project, prepare the following: + +- Link to clone the project (used for `git clone ...`). +- Commit hash of the project version for AutoCodeRover to work on (used for `git checkout ...`). +- Link to the GitHub issue page. + +Then, in the docker container (or your local copy of AutoCodeRover), run the following commands to set up the target project +and generate patch: + +``` +cd /opt/auto-code-rover +conda activate auto-code-rover +PYTHONPATH=. python app/main.py --mode fresh_issue --output-dir output --setup-dir setup --model gpt-4-0125-preview --model-temperature 0.2 --fresh-task-id <task id> --clone-link <clone link> --commit-hash <commit hash> --issue-link <issue link> +``` + +The `<task id>` can be any string used to identify this issue. + +If patch generation is successful, the path to the generated patch will be printed in the end. + + +### (SWE-bench mode) Set up and run on SWE-bench tasks + +> [!NOTE] +> This section is for running AutoCodeRover on SWE-bench tasks. For running it on new GitHub issues, refer to [Fresh issue mode](#fresh-issue-mode-set-up-and-run-on-new-github-issues). + +#### Set up In the docker container, we need to first set up the tasks to run in SWE-bench (e.g., `django__django-11133`). The list of all tasks can be found in [`conf/swe_lite_tasks.txt`](conf/swe_lite_tasks.txt). @@ -108,7 +138,7 @@ A conda environment will also be created for this task instance. 
_If you want to set up multiple tasks together, put their ids in `tasks.txt` and follow the same steps._ -### Run a single task +#### Run a single task Before running the task (`django__django-11133` here), make sure it has been set up as mentioned [above](#set-up-one-or-more-tasks-in-swe-bench). @@ -120,7 +150,7 @@ PYTHONPATH=. python app/main.py --enable-layered --model gpt-4-0125-preview --se The output of the run can then be found in `output/`. For example, the patch generated for `django__django-11133` can be found at a location like this: `output/applicable_patch/django__django-11133_yyyy-MM-dd_HH-mm-ss/extracted_patch_1.diff` (the date-time field in the directory name will be different depending on when the experiment was run). -### Run multiple tasks +#### Run multiple tasks First, put the id's of all tasks to run in a file, one per line. Suppose this file is `tasks.txt`, the tasks can be run with diff --git a/app/api/manage.py b/app/api/manage.py index c2fe09a2..ced12487 100644 --- a/app/api/manage.py +++ b/app/api/manage.py @@ -70,15 +70,15 @@ def __init__( task_id: str, project_path: str, commit: str, - env_name: str, - repo_name: str, - pre_install_cmds: list[str], - install_cmd: str, - test_cmd: str, - test_patch: str, - testcases_passing: list[str], - testcases_failing: list[str], output_dir: str, + env_name: str | None = None, + repo_name: str | None = None, + pre_install_cmds: list[str] | None = None, + install_cmd: str | None = None, + test_cmd: str | None = None, + test_patch: str | None = None, + testcases_passing: list[str] | None = None, + testcases_failing: list[str] | None = None, do_install: bool = False, import_root: str = "src", ): @@ -90,16 +90,22 @@ def __init__( self.env_name = env_name self.repo_name = repo_name # additional installation commands after setup was done - self.pre_install_cmds: list[str] = pre_install_cmds + self.pre_install_cmds: list[str] = ( + [] if pre_install_cmds is None else pre_install_cmds + ) self.install_cmd: 
str = install_cmd # command to run tests self.test_cmd: str = test_cmd # the patch to testcases self.test_patch: str = test_patch # names of the passing testcases for this issue - self.testcases_passing: list[str] = testcases_passing + self.testcases_passing: list[str] = ( + [] if testcases_passing is None else testcases_passing + ) # names of the failing testcases for this issue - self.testcases_failing: list[str] = testcases_failing + self.testcases_failing: list[str] = ( + [] if testcases_failing is None else testcases_failing + ) # where to write our output self.output_dir = os.path.abspath(output_dir) @@ -118,11 +124,15 @@ def __init__( self.do_install() # apply the test modifications to this task - self.apply_test_patch() + if self.test_patch is not None: + self.apply_test_patch() # commit the current changes, so that resetting later do not erase them - with apputils.cd(self.project_path): - apputils.repo_commit_current_changes(self.logger) + if do_install or self.test_patch is not None: + # this means we have applied some changes to the repo before + # starting the actual workflow + with apputils.cd(self.project_path): + apputils.repo_commit_current_changes(self.logger) # build search manager self.search_manager = SearchManager(self.project_path) diff --git a/app/fresh_issue/common.py b/app/fresh_issue/common.py new file mode 100644 index 00000000..1dc4a6a6 --- /dev/null +++ b/app/fresh_issue/common.py @@ -0,0 +1,91 @@ +import json +import os +import shutil +from os.path import join as pjoin + +from app import utils as apputils +from app.fresh_issue import github + + +class FreshTask: + """ + Encapsulate everything required to run ACR on a fresh issue from the internet. 
+ """ + + def __init__( + self, + task_id: str, + clone_link: str, + commit_hash: str, + issue_link: str, + setup_dir: str, + task_output_dir: str, + ): + self.task_id = task_id + self.clone_link = clone_link + self.commit_hash = commit_hash + self.issue_link = issue_link + # where to store output of ACR + self.task_output_dir = task_output_dir + # where the project source code is located + self.project_dir = self.setup_task_local(setup_dir) + self.problem_stmt, self.created_at = self.prepare_issue() + self.write_meta_file() + + def setup_task_local(self, setup_dir: str): + """ + Clone and check out the target project locally. + """ + # we are going to clone to this path - make sure it is not there yet + cloned_path = pjoin(setup_dir, self.task_id) + if os.path.isdir(cloned_path): + print( + f"Path {cloned_path} already exists. Removing it to get a fresh clone." + ) + shutil.rmtree(cloned_path) + # really clone the repo + cloned_path = apputils.clone_repo_and_checkout( + self.clone_link, self.commit_hash, setup_dir, self.task_id + ) + print(f"Cloned source code to {cloned_path}.") + return cloned_path + + def prepare_issue(self): + """ + Prepare problem statement from the online issue report. + """ + if "github.com" in self.issue_link: + retrieved_issue = github.get_github_issue_info(self.issue_link) + if retrieved_issue is None: + raise Exception( + f"Failed to retrieve issue information from {self.issue_link}" + ) + else: + title, body, created_at = retrieved_issue + problem_stmt = f"{title}\n{body}" + # save this issue into a file for reference + problem_stmt_file = pjoin(self.task_output_dir, "problem_statement.txt") + with open(problem_stmt_file, "w") as f: + f.write(problem_stmt) + return problem_stmt, created_at + else: + raise NotImplementedError("Only GitHub issues are supported for now.") + + def write_meta_file(self): + """ + Write a meta file for compatibility reasons with the swe-bench mode. 
+ """ + meta_file = pjoin(self.task_output_dir, "meta.json") + meta = { + "task_info": { + "base_commit": self.commit_hash, + "created_at": self.created_at, + "problem_statement": self.problem_stmt, + "instance_id": self.task_id, + }, + "setup_info": { + "repo_path": self.project_dir, + }, + } + with open(meta_file, "w") as f: + json.dump(meta, f, indent=4) diff --git a/app/fresh_issue/github.py b/app/fresh_issue/github.py new file mode 100644 index 00000000..27a192a6 --- /dev/null +++ b/app/fresh_issue/github.py @@ -0,0 +1,22 @@ +import requests + + +def get_github_issue_info(issue_url: str) -> tuple[str, str, str] | None: + # Extract owner, repo, and issue number from the URL + # Example issue URL: https://github.com/owner/repo/issues/123 + _, owner, repo, _, issue_number = issue_url.rsplit("/", 4) + + api_url = f"https://api.github.com/repos/{owner}/{repo}/issues/{issue_number}" + response = requests.get(api_url) + + if response.status_code == 200: + issue_info = response.json() + # Extract relevant information from the issue + title = issue_info["title"] + body = issue_info["body"] + created_at = issue_info["created_at"] + + return title, body, created_at + else: + print(f"Failed to fetch issue information: {response.status_code}") + return None diff --git a/app/globals.py b/app/globals.py index cd71ce25..914d9941 100644 --- a/app/globals.py +++ b/app/globals.py @@ -5,13 +5,6 @@ # Overall output directory for results output_dir: str = "" -# whether to start conversation from fresh, or load from a conversation history. -# If None, start from fresh. -# If not None, continue from the conversation history stored in . -# is the value of this variable, and should points to a json file -# containing the past conversation history. 
-load_cache: str | None = None - # the model to use model: str = "gpt-3.5-turbo-0125" @@ -26,7 +19,7 @@ enable_sbfl: bool = False # whether to perform layered search -enable_layered: bool = False +enable_layered: bool = True # whether to perform our own validation enable_validation: bool = False diff --git a/app/main.py b/app/main.py index 477b1016..aea599ff 100644 --- a/app/main.py +++ b/app/main.py @@ -14,8 +14,10 @@ from app import globals, globals_mut, inference, log from app import utils as apputils from app.api.manage import ProjectApiManager +from app.fresh_issue.common import FreshTask from app.post_process import ( extract_organize_and_form_input, + get_final_patch_path, organize_and_form_input, reextract_organize_and_form_inputs, ) @@ -106,10 +108,12 @@ def run_one_task(task: Task) -> bool: ) try: + # create api manager and run project initialization routine in its init api_manager = ProjectApiManager( task_id, repo_path, base_commit, + task_output_dir, env_name, repo_name, pre_install_cmds, @@ -118,12 +122,13 @@ def run_one_task(task: Task) -> bool: test_patch, testcases_passing, testcases_failing, - task_output_dir, do_install=globals.do_install, ) except Exception as e: log.log_exception(logger, e) - run_status_message = f"Task {task_id} failed with exception: {e}." + run_status_message = ( + f"Task {task_id} failed with exception when creating API manager: {e}." 
+ ) logger.handlers.clear() return False @@ -146,15 +151,8 @@ def run_one_task(task: Task) -> bool: run_ok = False run_status_message = "" try: - # create api manager and run project initialization routine in its init - if globals.load_cache is not None: - # NOTE: although we start from a history state, still creating a new - # output folder to store results from this run - run_ok = inference.continue_task_from_cache( - globals.load_cache, task_output_dir, api_manager - ) - else: - run_ok = inference.run_one_task(task_output_dir, api_manager, problem_stmt) + + run_ok = inference.run_one_task(task_output_dir, api_manager, problem_stmt) if run_ok: run_status_message = f"Task {task_id} completed successfully." else: @@ -228,168 +226,20 @@ def run_task_group(task_group_id: str, task_group_items: list[Task]) -> None: ) -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--setup-map", - type=str, - help="Path to json file that contains the setup information of the projects.", - ) - parser.add_argument( - "--tasks-map", - type=str, - help="Path to json file that contains the tasks information.", - ) - ## where to store run results - parser.add_argument( - "--output-dir", - type=str, - help="Path to the directory that stores the run results.", - ) - ## which tasks to be run - parser.add_argument( - "--task-list-file", - type=str, - help="Path to the file that contains all tasks ids to be run.", - ) - parser.add_argument("--task", type=str, help="Task id to be run.") - parser.add_argument( - "--num-processes", - type=str, - default=1, - help="Number of processes to run the tasks in parallel.", - ) - parser.add_argument( - "--load-cache", - type=str, - help="(Deprecated) Point to a json file which contains past conversation history. " - "Restart conversation from this file instead of starting from scratch. 
" - "Only available when running a single task.", - ) - parser.add_argument( - "--enable-sbfl", action="store_true", default=False, help="Enable SBFL." - ) - parser.add_argument( - "--enable-layered", - action="store_true", - default=False, - help="Enable layered code search.", - ) - parser.add_argument( - "--enable-validation", - action="store_true", - default=False, - help="Enable validation in our workflow.", - ) - parser.add_argument( - "--enable-angelic", - action="store_true", - default=False, - help="(Experimental) Enable angelic debugging", - ) - parser.add_argument( - "--enable-perfect-angelic", - action="store_true", - default=False, - help="(Experimental) Enable perfect angelic debugging; overrides --enable-angelic", - ) - parser.add_argument( - "--no-print", - action="store_true", - default=False, - help="Do not print most messages to stdout.", - ) - parser.add_argument( - "--model", - type=str, - default="gpt-3.5-turbo-0125", - choices=globals.MODELS, - help="The model to use. Currently only OpenAI models are supported.", - ) - parser.add_argument( - "--model-temperature", - type=float, - default=0.0, - help="The model temperature to use, for OpenAI models.", - ) - parser.add_argument( - "--conv-round-limit", - type=int, - default=15, - help="Conversation round limit for the main agent.", - ) - parser.add_argument( - "--extract-patches", - type=str, - help="Only extract patches from the raw results dir. 
Voids all other arguments if this is used.", - ) - parser.add_argument( - "--re-extract-patches", - type=str, - help="same as --extract-patches, except that individual dirs are moved out of their categories first", - ) - parser.add_argument( - "--save-sbfl-result", - action="store_true", - default=False, - help="Special mode to only save SBFL results for future runs.", - ) - - args = parser.parse_args() - setup_map_file = args.setup_map - tasks_map_file = args.tasks_map - globals.output_dir = args.output_dir - if globals.output_dir is not None: - globals.output_dir = apputils.convert_dir_to_absolute(globals.output_dir) - task_list_file = args.task_list_file - task_id = args.task - num_processes: int = int(args.num_processes) - globals.load_cache = args.load_cache - globals.model = args.model - globals.model_temperature = args.model_temperature - # set whether brief or verbose log - print_stdout: bool = not args.no_print - log.print_stdout = print_stdout - globals.enable_sbfl = args.enable_sbfl - globals.enable_layered = args.enable_layered - globals.enable_validation = args.enable_validation - globals.enable_angelic = args.enable_angelic - globals.enable_perfect_angelic = args.enable_perfect_angelic - globals.conv_round_limit = args.conv_round_limit - - # special modes - extract_patches: str | None = args.extract_patches - globals.only_save_sbfl_result = args.save_sbfl_result - - if globals.only_save_sbfl_result and extract_patches is not None: - raise ValueError( - "Cannot save SBFL result and extract patches at the same time." 
- ) - - # special mode 1: extract patch, for this we can early exit - if args.re_extract_patches is not None: - extract_patches = apputils.convert_dir_to_absolute(args.re_extract_patches) - reextract_organize_and_form_inputs(args.re_extract_patches) - return - - if extract_patches is not None: - extract_patches = apputils.convert_dir_to_absolute(extract_patches) - extract_organize_and_form_input(extract_patches) - return - - globals.do_install = ( - globals.enable_sbfl - or globals.enable_validation - or globals.only_save_sbfl_result - ) - +def entry_swe_bench_mode( + task_id: str | None, + task_list_file: str | None, + setup_map_file: str, + tasks_map_file: str, + num_processes: int, +): + """ + Main entry for swe-bench mode. + """ # check parameters if task_id is not None and task_list_file is not None: raise ValueError("Cannot specify both task and task-list.") - if globals.load_cache is not None and task_id is None: - raise ValueError("Cannot load cache when not in single-task mode.") - all_task_ids = [] if task_list_file is not None: all_task_ids = parse_task_list_file(task_list_file) @@ -484,5 +334,304 @@ def main(): log.print_with_time("SWE-Bench input file created: " + swe_input_file) +def entry_fresh_issue_mode( + task_id: str, clone_link: str, commit_hash: str, issue_link: str, setup_dir: str +): + """ + Main entry for fresh issue mode. 
+ """ + # create setup and output directories + apputils.create_dir_if_not_exists(setup_dir) + start_time = datetime.datetime.now() + start_time_s = start_time.strftime("%Y-%m-%d_%H-%M-%S") + task_output_dir = pjoin(globals.output_dir, task_id + "_" + start_time_s) + apputils.create_dir_if_not_exists(task_output_dir) + + fresh_task = FreshTask( + task_id, clone_link, commit_hash, issue_link, setup_dir, task_output_dir + ) + logger = log.create_new_logger(task_id, task_output_dir) + log.log_and_always_print( + logger, + f"============= Running fresh issue {task_id} =============", + ) + + try: + api_manager = ProjectApiManager( + task_id, fresh_task.project_dir, commit_hash, task_output_dir + ) + except Exception as e: + log.log_exception(logger, e) + run_status_message = f"Fresh issue {task_id} failed with exception when creating API manager: {e}." + return False + + run_ok = False + run_status_message = "" + try: + run_ok = inference.run_one_task( + task_output_dir, api_manager, fresh_task.problem_stmt + ) + if run_ok: + run_status_message = f"Fresh issue {task_id} completed successfully." + else: + run_status_message = f"Fresh issue {task_id} failed without exception." + except Exception as e: + log.log_exception(logger, e) + run_status_message = f"Fresh issue {task_id} failed with exception: {e}." 
+ finally: + # dump recorded tool call sequence into a file + end_time = datetime.datetime.now() + + api_manager.dump_tool_call_sequence_to_file() + api_manager.dump_tool_call_layers_to_file() + + input_cost_per_token = globals.MODEL_COST_PER_INPUT[globals.model] + output_cost_per_token = globals.MODEL_COST_PER_OUTPUT[globals.model] + with open(pjoin(task_output_dir, "cost.json"), "w") as f: + json.dump( + { + "model": globals.model, + "commit": commit_hash, + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "total_input_tokens": api_manager.input_tokens, + "total_output_tokens": api_manager.output_tokens, + "total_tokens": api_manager.input_tokens + + api_manager.output_tokens, + "total_cost": api_manager.cost, + "start_epoch": start_time.timestamp(), + "end_epoch": end_time.timestamp(), + "elapsed_seconds": (end_time - start_time).total_seconds(), + }, + f, + indent=4, + ) + + # at the end of each task, reset everything in the task repo to clean state + with apputils.cd(fresh_task.project_dir): + apputils.repo_reset_and_clean_checkout(commit_hash, logger) + log.log_and_always_print(logger, run_status_message) + final_patch_path = get_final_patch_path(task_output_dir) + if final_patch_path is not None: + log.log_and_always_print( + logger, f"Please find the generated patch at: {final_patch_path}" + ) + else: + log.log_and_always_print( + logger, "No patch generated. You can try to run ACR again." 
+ ) + return run_ok + + +def main(): + parser = argparse.ArgumentParser() + ## Common options + # where to store run results + parser.add_argument( + "--mode", + default="swe_bench", + choices=["swe_bench", "fresh_issue"], + help="Choose to run tasks in SWE-bench, or a fresh issue from the internet.", + ) + parser.add_argument( + "--output-dir", + type=str, + help="Path to the directory that stores the run results.", + ) + parser.add_argument( + "--num-processes", + type=str, + default=1, + help="Number of processes to run the tasks in parallel.", + ) + parser.add_argument( + "--no-print", + action="store_true", + default=False, + help="Do not print most messages to stdout.", + ) + parser.add_argument( + "--model", + type=str, + default="gpt-3.5-turbo-0125", + choices=globals.MODELS, + help="The model to use. Currently only OpenAI models are supported.", + ) + parser.add_argument( + "--model-temperature", + type=float, + default=0.0, + help="The model temperature to use, for OpenAI models.", + ) + parser.add_argument( + "--conv-round-limit", + type=int, + default=15, + help="Conversation round limit for the main agent.", + ) + parser.add_argument( + "--extract-patches", + type=str, + help="Only extract patches from the raw results dir. Voids all other arguments if this is used.", + ) + parser.add_argument( + "--re-extract-patches", + type=str, + help="same as --extract-patches, except that individual dirs are moved out of their categories first", + ) + parser.add_argument( + "--enable-layered", + action="store_true", + default=True, + help="Enable layered code search.", + ) + + swe_group = parser.add_argument_group( + "swe_bench", description="Arguments for running on SWE-bench tasks." 
+ ) + ## task info when running instances in SWE-bench + swe_group.add_argument( + "--setup-map", + type=str, + help="Path to json file that contains the setup information of the projects.", + ) + swe_group.add_argument( + "--tasks-map", + type=str, + help="Path to json file that contains the tasks information.", + ) + swe_group.add_argument( + "--task-list-file", + type=str, + help="Path to the file that contains all tasks ids to be run.", + ) + swe_group.add_argument("--task", type=str, help="Task id to be run.") + ## Only support test-based options for SWE-bench tasks for now + swe_group.add_argument( + "--enable-sbfl", action="store_true", default=False, help="Enable SBFL." + ) + swe_group.add_argument( + "--enable-validation", + action="store_true", + default=False, + help="Enable validation in our workflow.", + ) + swe_group.add_argument( + "--enable-angelic", + action="store_true", + default=False, + help="(Experimental) Enable angelic debugging", + ) + swe_group.add_argument( + "--enable-perfect-angelic", + action="store_true", + default=False, + help="(Experimental) Enable perfect angelic debugging; overrides --enable-angelic", + ) + swe_group.add_argument( + "--save-sbfl-result", + action="store_true", + default=False, + help="Special mode to only save SBFL results for future runs.", + ) + + fresh_group = parser.add_argument_group( + "fresh_issue", + description="Arguments for running on fresh issues from the internet.", + ) + ## task info when running on new issues from GitHub + fresh_group.add_argument( + "--fresh-task-id", + type=str, + help="Assign an id to the current fresh issue task.", + ) + fresh_group.add_argument( + "--clone-link", + type=str, + help="[Fresh issue] The link to the repository to clone.", + ) + fresh_group.add_argument( + "--commit-hash", type=str, help="[Fresh issue] The commit hash to checkout." + ) + fresh_group.add_argument( + "--issue-link", type=str, help="[Fresh issue] The link to the issue." 
+ ) + fresh_group.add_argument( + "--setup-dir", + type=str, + help="[Fresh issue] The directory where repositories should be cloned to.", + ) + + args = parser.parse_args() + ## common options + mode = args.mode + globals.output_dir = args.output_dir + if globals.output_dir is not None: + globals.output_dir = apputils.convert_dir_to_absolute(globals.output_dir) + num_processes: int = int(args.num_processes) + # set whether brief or verbose log + print_stdout: bool = not args.no_print + log.print_stdout = print_stdout + globals.model = args.model + globals.model_temperature = args.model_temperature + globals.conv_round_limit = args.conv_round_limit + extract_patches: str | None = args.extract_patches + re_extract_patches: str | None = args.re_extract_patches + globals.enable_layered = args.enable_layered + + ## options for swe-bench mode + setup_map_file = args.setup_map + tasks_map_file = args.tasks_map + task_list_file: str | None = args.task_list_file + task_id: str | None = args.task + globals.enable_sbfl = args.enable_sbfl + globals.enable_validation = args.enable_validation + globals.enable_angelic = args.enable_angelic + globals.enable_perfect_angelic = args.enable_perfect_angelic + globals.only_save_sbfl_result = args.save_sbfl_result + + ## options for fresh_issue mode + fresh_task_id = args.fresh_task_id + clone_link = args.clone_link + commit_hash = args.commit_hash + issue_link = args.issue_link + setup_dir = args.setup_dir + if setup_dir is not None: + setup_dir = apputils.convert_dir_to_absolute(setup_dir) + + ## Firstly deal with special modes + if globals.only_save_sbfl_result and extract_patches is not None: + raise ValueError( + "Cannot save SBFL result and extract patches at the same time." 
+ ) + + # special mode 1: extract patch, for this we can early exit + if re_extract_patches is not None: + extract_patches = apputils.convert_dir_to_absolute(re_extract_patches) + reextract_organize_and_form_inputs(re_extract_patches) + return + + if extract_patches is not None: + extract_patches = apputils.convert_dir_to_absolute(extract_patches) + extract_organize_and_form_input(extract_patches) + return + + # we do not do install for fresh issue now + globals.do_install = (mode == "swe_bench") and ( + globals.enable_sbfl + or globals.enable_validation + or globals.only_save_sbfl_result + ) + + if mode == "swe_bench": + entry_swe_bench_mode( + task_id, task_list_file, setup_map_file, tasks_map_file, num_processes + ) + else: + entry_fresh_issue_mode( + fresh_task_id, clone_link, commit_hash, issue_link, setup_dir + ) + + if __name__ == "__main__": main() diff --git a/app/post_process.py b/app/post_process.py index 0577e9f3..4fe89094 100644 --- a/app/post_process.py +++ b/app/post_process.py @@ -113,9 +113,12 @@ def record_extract_status(individual_expr_dir: str, extract_status: ExtractStatu json.dump(record, f, indent=4) -def read_extract_status(individual_expr_dir: str) -> ExtractStatus: +def read_extract_status(individual_expr_dir: str) -> tuple[ExtractStatus, int]: """ Read extract status from file. If there are multiple status recorded, read the best one. + Returns: + - The best extract status + - The index of the best status in the list of all statuses. 
(0-based) """ # we should read from the all the record record_file = pjoin(individual_expr_dir, "extract_status.json") @@ -128,7 +131,23 @@ def read_extract_status(individual_expr_dir: str) -> ExtractStatus: # convert string to enum type all_status = [ExtractStatus(s) for s in record["extract_status"]] best_status = ExtractStatus.max(all_status) - return best_status + idx = all_status.index(best_status) + return best_status, idx + + +def get_final_patch_path(individual_expr_dir: str) -> str | None: + """ + Get the final patch path from the individual experiment directory. + If there are multiple extracted patches, need to figure out which one is the best based + on the patch extraction history. + """ + _, best_index = read_extract_status(individual_expr_dir) + best_patch_name = f"extracted_patch_{best_index + 1}.diff" + final_patch_path = pjoin(individual_expr_dir, best_patch_name) + if os.path.isfile(final_patch_path): + return final_patch_path + else: + return None def extract_diff_one_instance( @@ -270,7 +289,7 @@ def organize_experiment_results(expr_dir: str): os.makedirs(extract_status.to_dir_name(expr_dir), exist_ok=True) for task_dir in task_exp_dirs: - extract_status = read_extract_status(task_dir) + extract_status, _ = read_extract_status(task_dir) corresponding_dir = extract_status.to_dir_name(expr_dir) shutil.move(task_dir, corresponding_dir) diff --git a/app/search/search_manage.py b/app/search/search_manage.py index febdb289..ff9e76c0 100644 --- a/app/search/search_manage.py +++ b/app/search/search_manage.py @@ -37,21 +37,26 @@ def __build_index(self): value is a list of tuples. This is for fast lookup whenever we receive a query. 
""" - self.all_py_files = search_utils.get_all_py_files(self.project_path) + temp_all_py_file = search_utils.get_all_py_files(self.project_path) + # holds the parsable subset of all py files + parsed_all_py_file = [] + for py_file in temp_all_py_file: + file_info = search_utils.get_all_info_from_file(py_file) + if file_info is None: + # parsing of this file failed + continue + parsed_all_py_file.append(py_file) + # extract from file info, and form search index + classes, class_to_funcs, top_level_funcs = file_info - for py_file in self.all_py_files: - # print(py_file) # (1) build class index - classes = search_utils.get_all_classes_in_file(py_file) - # now put the class result in one file into the dict for c, start, end in classes: if c not in self.class_index: self.class_index[c] = [] self.class_index[c].append((py_file, start, end)) # (2) build class-function index - for c, _, _ in classes: - class_funcs = search_utils.get_all_funcs_in_class_in_file(py_file, c) + for c, class_funcs in class_to_funcs.items(): if c not in self.class_func_index: self.class_func_index[c] = dict() for f, start, end in class_funcs: @@ -60,12 +65,13 @@ def __build_index(self): self.class_func_index[c][f].append((py_file, start, end)) # (3) build (top-level) function index - functions = search_utils.get_top_level_functions(py_file) - for f, start, end in functions: + for f, start, end in top_level_funcs: if f not in self.function_index: self.function_index[f] = [] self.function_index[f].append((py_file, start, end)) + self.all_py_files = parsed_all_py_file + def file_line_to_class_and_func( self, file_path: str, line_no: int ) -> tuple[str | None, str | None]: diff --git a/app/search/search_utils.py b/app/search/search_utils.py index 7792f73a..8c849019 100644 --- a/app/search/search_utils.py +++ b/app/search/search_utils.py @@ -129,82 +129,53 @@ def get_all_py_files(dir_path: str) -> list[str]: return res -def get_all_classes_in_file(file_full_path: str) -> list[tuple[str, int, int]]: - 
"""Get all classes defined in one .py file. - - Args: - file_path (str): Path to the .py file. - Returns: - List of classes in this file. +def get_all_info_from_file(file_full_path: str) -> tuple[list, dict, list] | None: """ + Main method to parse AST and build search index. + Handles complication where python ast module cannot parse a file. + """ + try: + with open(file_full_path) as f: + file_content = f.read() + tree = ast.parse(file_content) + except Exception: + # failed to read/parse one file, we should ignore it + return None - with open(file_full_path) as f: - file_content = f.read() - + # (1) get all classes defined in the file classes = [] - # print(file_path) - tree = ast.parse(file_content) + # (2) for each class in the file, get all functions defined in the class. + class_to_funcs = dict() + # (3) get top-level functions in the file (excludes functions defined in classes) + top_level_funcs = [] + for node in ast.walk(tree): if isinstance(node, ast.ClassDef): + ## class part (1): collect class info class_name = node.name start_lineno = node.lineno end_lineno = node.end_lineno # line numbers are 1-based classes.append((class_name, start_lineno, end_lineno)) - return classes - - -def get_top_level_functions(file_full_path: str) -> list[tuple[str, int, int]]: - """Get top-level functions defined in one .py file. - - This excludes functions defined in any classes. - - Args: - file_path (str): Path to the .py file. - Returns: - List of top-level functions in this file. 
- """ - with open(file_full_path) as f: - file_content = f.read() - - functions = [] - tree = ast.parse(file_content) - for node in tree.body: - if isinstance(node, ast.FunctionDef): - function_name = node.name - start_lineno = node.lineno - end_lineno = node.end_lineno - # line numbers are 1-based - functions.append((function_name, start_lineno, end_lineno)) - return functions - - -# mainly used for building index -def get_all_funcs_in_class_in_file( - file_full_path: str, class_name: str -) -> list[tuple[str, int, int]]: - """ - For a class in a file, get all functions defined in the class. - Assumption: - - the given function exists, and is defined in the given file. - Returns: - - List of tuples, each tuple is (function_name, start_lineno, end_lineno). - """ - with open(file_full_path) as f: - file_content = f.read() - functions = [] - tree = ast.parse(file_content) - for node in ast.walk(tree): - if isinstance(node, ast.ClassDef) and node.name == class_name: + ## class part (2): collect function info inside this class + class_funcs = [] for n in ast.walk(node): if isinstance(n, ast.FunctionDef): function_name = n.name start_lineno = n.lineno end_lineno = n.end_lineno - functions.append((function_name, start_lineno, end_lineno)) + class_funcs.append((function_name, start_lineno, end_lineno)) + class_to_funcs[class_name] = class_funcs + + elif isinstance(node, ast.FunctionDef): + function_name = node.name + start_lineno = node.lineno + end_lineno = node.end_lineno + # line numbers are 1-based + top_level_funcs.append((function_name, start_lineno, end_lineno)) - return functions + return classes, class_to_funcs, top_level_funcs def get_func_snippet_in_class( diff --git a/app/utils.py b/app/utils.py index 69258e0b..e26146e3 100644 --- a/app/utils.py +++ b/app/utils.py @@ -42,6 +42,25 @@ def run_command(logger, cmd: list[str], **kwargs) -> subprocess.CompletedProcess return cp +def clone_repo_and_checkout( + clone_link: str, commit_hash: str, dest_dir: str, 
cloned_name: str +): + """ + Clone a repo to dest_dir, and checkout to commit `commit_hash`. + + Returns: + - path to the newly cloned directory. + """ + clone_cmd = ["git", "clone", clone_link, cloned_name] + checkout_cmd = ["git", "checkout", commit_hash] + with cd(dest_dir): + run_command(None, clone_cmd) + cloned_dir = pjoin(dest_dir, cloned_name) + with cd(cloned_dir): + run_command(None, checkout_cmd) + return cloned_dir + + def repo_commit_current_changes(logger=None): """ Commit the current active changes so that it's safer to do git reset later on. diff --git a/environment.yml b/environment.yml index e94328d5..18b79148 100644 --- a/environment.yml +++ b/environment.yml @@ -26,6 +26,7 @@ dependencies: - xz=5.4.5=h5eee18b_0 - zlib=1.2.13=h5eee18b_0 - unidiff + - requests - pip: - annotated-types==0.6.0 - anyio==4.2.0 @@ -42,6 +43,7 @@ dependencies: - pydantic==2.5.3 - pydantic-core==2.14.6 - python-dotenv==1.0.0 + - requests==2.31.0 - sniffio==1.3.0 - soupsieve==2.5 - tenacity==8.2.3