From ae3ea89b97ebe5c0d007d9b182e7636c6711cd42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Mon, 16 Sep 2024 11:41:51 +0200 Subject: [PATCH] fix: support multiple patches in a single codeblock (#118) --- gptme/tools/patch.py | 85 +++++++++++++++++++-------------------- tests/test_tools_patch.py | 55 +++++++++++++++++++++---- 2 files changed, 88 insertions(+), 52 deletions(-) diff --git a/gptme/tools/patch.py b/gptme/tools/patch.py index 8f67251b..b838d9e8 100644 --- a/gptme/tools/patch.py +++ b/gptme/tools/patch.py @@ -50,54 +50,51 @@ def hello(): def apply(codeblock: str, content: str) -> str: """ - Applies the patch in ``codeblock`` to ``content``. + Applies multiple patches in ``codeblock`` to ``content``. """ - # TODO: support multiple patches in one codeblock, - # or make it clear that only one patch per codeblock is supported codeblock = codeblock.strip() - - # get the original and modified chunks - if ORIGINAL not in codeblock: # pragma: no cover - raise ValueError(f"invalid patch, no `{ORIGINAL.strip()}`", codeblock) - original = re.split(ORIGINAL, codeblock)[1] - - if DIVIDER not in original: # pragma: no cover - raise ValueError(f"invalid patch, no `{DIVIDER.strip()}`", codeblock) - original, modified = re.split(DIVIDER, original) - - if UPDATED not in "\n" + modified: # pragma: no cover - raise ValueError(f"invalid patch, no `{UPDATED.strip()}`", codeblock) - modified = re.split(UPDATED, modified)[0] - - # TODO: maybe allow modified chunk to contain "// ..." to refer to chunks in the original, - # and then replace these with the original chunks? - re_placeholder = re.compile(r"^[ \t]*(#|//) \.\.\. ?.*$", re.MULTILINE) - if re_placeholder.search(original) or re_placeholder.search(modified): - # raise ValueError("placeholders in modified chunk") - # split them by lines starting with "# ..." - originals = re_placeholder.split(original) - modifieds = re_placeholder.split(modified) - if len(originals) != len(modifieds): - raise ValueError( - "different number of placeholders in original and modified chunks" - f"\n{originals}\n{modifieds}" - ) - new = content - for orig, mod in zip(originals, modifieds): - if orig == mod: - continue - new = new.replace(orig, mod) - else: - if original not in content: # pragma: no cover - raise ValueError("original chunk not found in file", original) - - # replace the original chunk with the modified chunk - new = content.replace(original, modified) - - if new == content: # pragma: no cover + new_content = content + + # Split the codeblock into multiple patches + patches = re.split(f"(?={re.escape(ORIGINAL)})", codeblock) + + for patch in patches: + if not patch.strip(): + continue + + if ORIGINAL not in patch: # pragma: no cover + raise ValueError(f"invalid patch, no `{ORIGINAL.strip()}`", patch) + + parts = re.split( + f"{re.escape(ORIGINAL)}|{re.escape(DIVIDER)}|{re.escape(UPDATED)}", patch + ) + if len(parts) != 4: # pragma: no cover + raise ValueError("invalid patch format", patch) + + _, original, modified, _ = parts + + re_placeholder = re.compile(r"^[ \t]*(#|//) \.\.\. ?.*$", re.MULTILINE) + if re_placeholder.search(original) or re_placeholder.search(modified): + originals = re_placeholder.split(original) + modifieds = re_placeholder.split(modified) + if len(originals) != len(modifieds): + raise ValueError( + "different number of placeholders in original and modified chunks" + f"\n{originals}\n{modifieds}" + ) + for orig, mod in zip(originals, modifieds): + if orig == mod: + continue + new_content = new_content.replace(orig, mod) + else: + if original not in new_content: # pragma: no cover + raise ValueError("original chunk not found in file", original) + new_content = new_content.replace(original, modified) + + if new_content == content: # pragma: no cover raise ValueError("patch did not change the file") - return new + return new_content def apply_file(codeblock, filename): diff --git a/tests/test_tools_patch.py b/tests/test_tools_patch.py index 975a3e4e..18723561 100644 --- a/tests/test_tools_patch.py +++ b/tests/test_tools_patch.py @@ -1,13 +1,11 @@ from gptme.tools.patch import apply example_patch = """ -```patch filename.py <<<<<<< ORIGINAL original lines ======= modified lines >>>>>>> UPDATED -``` """ @@ -28,7 +26,6 @@ def hello(): """ codeblock = """ -```patch test.py <<<<<<< ORIGINAL def hello(): print("hello") @@ -36,7 +33,6 @@ def hello(): def hello(name="world"): print(f"hello {name}") >>>>>>> UPDATED -``` """ result = apply(codeblock, content) @@ -60,14 +56,12 @@ def hello(): # NOTE: test fails if UPDATED block doesn't have an empty line codeblock = """ -```patch test.py <<<<<<< ORIGINAL def hello(): print("hello") ======= >>>>>>> UPDATED -``` """ print(content) result = apply(codeblock, content) @@ -89,14 +83,59 @@ def hello(): hello() """ codeblock = """ -```patch test.py <<<<<<< ORIGINAL ======= >>>>>>> UPDATED -``` """ result = apply(codeblock, content) assert "\n\n\n" in result + + +def test_apply_multiple(): + # tests multiple patches in a single codeblock, with placeholders in patches + # checks that whitespace is preserved + content = """ +def hello(): + print("hello") + +if __name__ == "__main__": + hello() +""" + codeblock = """ +<<<<<<< ORIGINAL +def hello(): +======= +def hello_world(): +>>>>>>> UPDATED + +<<<<<<< ORIGINAL + hello() +======= + hello_world() +>>>>>>> UPDATED +""" + result = apply(codeblock, content) + assert " hello_world()" in result + + +def test_apply_with_placeholders(): + # tests multiple patches in a single codeblock, with placeholders in patches + # checks that whitespace is preserved + content = """ +def hello(): + print("hello") +""" + codeblock = """ +<<<<<<< ORIGINAL +def hello(): + # ... +======= +def hello_world(): + # ... +>>>>>>> UPDATED +""" + result = apply(codeblock, content) + assert "hello_world()" in result