Skip to content

Commit

Permalink
fix: support multiple patches in a single codeblock (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare authored Sep 16, 2024
1 parent 0787f59 commit ae3ea89
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 52 deletions.
85 changes: 41 additions & 44 deletions gptme/tools/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,54 +50,51 @@ def hello():

def apply(codeblock: str, content: str) -> str:
"""
Applies the patch in ``codeblock`` to ``content``.
Applies multiple patches in ``codeblock`` to ``content``.
"""
# TODO: support multiple patches in one codeblock,
# or make it clear that only one patch per codeblock is supported
codeblock = codeblock.strip()

# get the original and modified chunks
if ORIGINAL not in codeblock: # pragma: no cover
raise ValueError(f"invalid patch, no `{ORIGINAL.strip()}`", codeblock)
original = re.split(ORIGINAL, codeblock)[1]

if DIVIDER not in original: # pragma: no cover
raise ValueError(f"invalid patch, no `{DIVIDER.strip()}`", codeblock)
original, modified = re.split(DIVIDER, original)

if UPDATED not in "\n" + modified: # pragma: no cover
raise ValueError(f"invalid patch, no `{UPDATED.strip()}`", codeblock)
modified = re.split(UPDATED, modified)[0]

# TODO: maybe allow modified chunk to contain "// ..." to refer to chunks in the original,
# and then replace these with the original chunks?
re_placeholder = re.compile(r"^[ \t]*(#|//) \.\.\. ?.*$", re.MULTILINE)
if re_placeholder.search(original) or re_placeholder.search(modified):
# raise ValueError("placeholders in modified chunk")
# split them by lines starting with "# ..."
originals = re_placeholder.split(original)
modifieds = re_placeholder.split(modified)
if len(originals) != len(modifieds):
raise ValueError(
"different number of placeholders in original and modified chunks"
f"\n{originals}\n{modifieds}"
)
new = content
for orig, mod in zip(originals, modifieds):
if orig == mod:
continue
new = new.replace(orig, mod)
else:
if original not in content: # pragma: no cover
raise ValueError("original chunk not found in file", original)

# replace the original chunk with the modified chunk
new = content.replace(original, modified)

if new == content: # pragma: no cover
new_content = content

# Split the codeblock into multiple patches
patches = re.split(f"(?={re.escape(ORIGINAL)})", codeblock)

for patch in patches:
if not patch.strip():
continue

if ORIGINAL not in patch: # pragma: no cover
raise ValueError(f"invalid patch, no `{ORIGINAL.strip()}`", patch)

parts = re.split(
f"{re.escape(ORIGINAL)}|{re.escape(DIVIDER)}|{re.escape(UPDATED)}", patch
)
if len(parts) != 4: # pragma: no cover
raise ValueError("invalid patch format", patch)

_, original, modified, _ = parts

re_placeholder = re.compile(r"^[ \t]*(#|//) \.\.\. ?.*$", re.MULTILINE)
if re_placeholder.search(original) or re_placeholder.search(modified):
originals = re_placeholder.split(original)
modifieds = re_placeholder.split(modified)
if len(originals) != len(modifieds):
raise ValueError(
"different number of placeholders in original and modified chunks"
f"\n{originals}\n{modifieds}"
)
for orig, mod in zip(originals, modifieds):
if orig == mod:
continue
new_content = new_content.replace(orig, mod)
else:
if original not in new_content: # pragma: no cover
raise ValueError("original chunk not found in file", original)
new_content = new_content.replace(original, modified)

if new_content == content: # pragma: no cover
raise ValueError("patch did not change the file")

return new
return new_content


def apply_file(codeblock, filename):
Expand Down
55 changes: 47 additions & 8 deletions tests/test_tools_patch.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from gptme.tools.patch import apply

example_patch = """
```patch filename.py
<<<<<<< ORIGINAL
original lines
=======
modified lines
>>>>>>> UPDATED
```
"""


Expand All @@ -28,15 +26,13 @@ def hello():
"""

codeblock = """
```patch test.py
<<<<<<< ORIGINAL
def hello():
print("hello")
=======
def hello(name="world"):
print(f"hello {name}")
>>>>>>> UPDATED
```
"""

result = apply(codeblock, content)
Expand All @@ -60,14 +56,12 @@ def hello():

# NOTE: test fails if UPDATED block doesn't have an empty line
codeblock = """
```patch test.py
<<<<<<< ORIGINAL
def hello():
print("hello")
=======
>>>>>>> UPDATED
```
"""
print(content)
result = apply(codeblock, content)
Expand All @@ -89,14 +83,59 @@ def hello():
hello()
"""
codeblock = """
```patch test.py
<<<<<<< ORIGINAL
=======
>>>>>>> UPDATED
```
"""
result = apply(codeblock, content)
assert "\n\n\n" in result


def test_apply_multiple():
# tests multiple patches in a single codeblock, with placeholders in patches
# checks that whitespace is preserved
content = """
def hello():
print("hello")
if __name__ == "__main__":
hello()
"""
codeblock = """
<<<<<<< ORIGINAL
def hello():
=======
def hello_world():
>>>>>>> UPDATED
<<<<<<< ORIGINAL
hello()
=======
hello_world()
>>>>>>> UPDATED
"""
result = apply(codeblock, content)
assert " hello_world()" in result


def test_apply_with_placeholders():
# tests multiple patches in a single codeblock, with placeholders in patches
# checks that whitespace is preserved
content = """
def hello():
print("hello")
"""
codeblock = """
<<<<<<< ORIGINAL
def hello():
# ...
=======
def hello_world():
# ...
>>>>>>> UPDATED
"""
result = apply(codeblock, content)
assert "hello_world()" in result

0 comments on commit ae3ea89

Please sign in to comment.