From 9627b73bdd372c045fedcb4ddfca7c610beb6b22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Wed, 2 Oct 2024 18:59:18 +0200 Subject: [PATCH] fix: more patch tool refactor --- gptme/tools/patch.py | 61 +++++++++++++++++++-------------------- tests/test_tools_patch.py | 4 +++ 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/gptme/tools/patch.py b/gptme/tools/patch.py index d9e3a4ad..020a217e 100644 --- a/gptme/tools/patch.py +++ b/gptme/tools/patch.py @@ -65,7 +65,14 @@ class Patch: updated: str def apply(self, content: str) -> str: - return content.replace(self.original, self.updated, 1) + if self.original not in content: + raise ValueError("original chunk not found in file") + if content.count(self.original) > 1: + raise ValueError("original chunk not unique") + new_content = content.replace(self.original, self.updated, 1) + if new_content == content: + raise ValueError("patch did not change the file") + return new_content def diff_minimal(self, strip_context=False) -> str: """ @@ -81,8 +88,6 @@ def diff_minimal(self, strip_context=False) -> str: self.original.splitlines(), self.updated.splitlines(), lineterm="", - fromfile="original", - tofile="updated", ) )[3:] if strip_context: @@ -96,12 +101,11 @@ def diff_minimal(self, strip_context=False) -> str: markers[::-1].index("+") if "+" in markers else len(markers), markers[::-1].index("-") if "-" in markers else len(markers), ) - len(diff) - start - end diff = diff[start : len(diff) - end] return "\n".join(diff) @classmethod - def from_codeblock(cls, codeblock: str) -> Generator["Patch", None, None]: + def _from_codeblock(cls, codeblock: str) -> Generator["Patch", None, None]: codeblock = codeblock.strip() # Split the codeblock into multiple patches @@ -124,6 +128,25 @@ def from_codeblock(cls, codeblock: str) -> Generator["Patch", None, None]: _, original, modified, _ = parts yield Patch(original, modified) + @classmethod + def from_codeblock(cls, codeblock: str) -> Generator["Patch", None, None]: + for patch in cls._from_codeblock(codeblock): + original, updated = patch.original, patch.updated + re_placeholder = re.compile(r"^[ \t]*(#|//|\") \.\.\. ?.*$", re.MULTILINE) + if re_placeholder.search(original) or re_placeholder.search(updated): + originals = re_placeholder.split(original) + modifieds = re_placeholder.split(updated) + if len(originals) != len(modifieds): + raise ValueError( + "different number of placeholders in original and modified chunks" + ) + for orig, mod in zip(originals, modifieds): + if orig == mod: + continue + yield Patch(orig, mod) + else: + yield patch + def apply(codeblock: str, content: str) -> str: """ @@ -131,33 +154,7 @@ def apply(codeblock: str, content: str) -> str: """ new_content = content for patch in Patch.from_codeblock(codeblock): - original, updated = patch.original, patch.updated - re_placeholder = re.compile(r"^[ \t]*(#|//|\") \.\.\. ?.*$", re.MULTILINE) - if re_placeholder.search(original) or re_placeholder.search(updated): - # if placeholder found in content, then we cannot use placeholder-aware patching - if re_placeholder.search(content): - raise ValueError( - "placeholders found in content, cannot use placeholder-aware patching" - ) - - originals = re_placeholder.split(original) - modifieds = re_placeholder.split(updated) - if len(originals) != len(modifieds): - raise ValueError( - "different number of placeholders in original and modified chunks" - ) - for orig, mod in zip(originals, modifieds): - if orig == mod: - continue - new_content = Patch(orig, mod).apply(new_content) - else: - if original not in new_content: # pragma: no cover - raise ValueError("original chunk not found in file") - new_content = patch.apply(new_content) - - if new_content == content: # pragma: no cover - raise ValueError("patch did not change the file") - + new_content = patch.apply(new_content) return new_content diff --git a/tests/test_tools_patch.py b/tests/test_tools_patch.py index c0f82616..232d7a1a 100644 --- a/tests/test_tools_patch.py +++ b/tests/test_tools_patch.py @@ -85,9 +85,13 @@ def hello(): codeblock = """ <<<<<<< ORIGINAL + + ======= + + >>>>>>> UPDATED """ result = apply(codeblock, content)