Skip to content

Commit

Permalink
fix: expanduser() in patch tool, output warning suggesting save tool …
Browse files Browse the repository at this point in the history
…for inefficient patches
  • Loading branch information
ErikBjare committed Sep 27, 2024
1 parent 065d318 commit ca2d4e5
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions gptme/tools/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Gives the LLM agent the ability to patch text files, by using a adapted version git conflict markers.
"""

# TODO: support multiple patches in one codeblock (or make it clear that only one patch per codeblock is supported/applied)

import re
from collections.abc import Generator
from pathlib import Path
Expand Down Expand Up @@ -114,18 +112,6 @@ def apply(codeblock: str, content: str) -> str:
return new_content


def apply_file(codeblock, filename):
if not Path(filename).exists():
raise ValueError(f"file not found: {filename}")

with open(filename, "r+") as f:
content = f.read()
result = apply(codeblock, content)
f.seek(0)
f.truncate()
f.write(result)


def execute_patch(
code: str, ask: bool, args: list[str]
) -> Generator[Message, None, None]:
Expand All @@ -134,15 +120,36 @@ def execute_patch(
"""
fn = " ".join(args)
assert fn, "No filename provided"
path = Path(fn).expanduser()
if not path.exists():
raise ValueError(f"file not found: {fn}")
if ask:
confirm = ask_execute(f"Apply patch to {fn}?")
if not confirm:
print("Patch not applied")
return

try:
apply_file(code, fn)
yield Message("system", f"Patch applied to {fn}")
with open(path) as f:
original_content = f.read()

# Apply the patch
patched_content = apply(code, original_content)
with open(path, "w") as f:
f.write(patched_content)

# Compare token counts
patch_tokens = len(code)
full_file_tokens = len(patched_content)

warnings = []
if full_file_tokens < patch_tokens:
warnings.append(
"Note: The patch was larger than the file. Consider using the save tool instead."
)
warnings_str = ("\n" + "\n".join(warnings)) if warnings else ""

yield Message("system", f"Patch applied to {fn}{warnings_str}")
except (ValueError, FileNotFoundError) as e:
yield Message("system", f"Patch failed: {e.args[0]}")

Expand Down

0 comments on commit ca2d4e5

Please sign in to comment.