Skip to content

Commit

Permalink
fix: improvements after execute_with_confirmation refactor (#311)
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare authored Dec 8, 2024
1 parent 01d8052 commit 23f81cf
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 68 deletions.
21 changes: 21 additions & 0 deletions gptme/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import types
from collections.abc import Callable, Generator
from dataclasses import dataclass, field
from pathlib import Path
from textwrap import indent
from typing import (
Any,
Expand Down Expand Up @@ -451,3 +452,23 @@ def _to_json(self) -> str:
def _to_toolcall(self) -> str:
self._to_json()
return f"@{self.tool}: {json.dumps(self._to_params(), indent=2)}"


def get_path(
code: str | None, args: list[str] | None, kwargs: dict[str, str] | None
) -> Path:
"""Get the path from args/kwargs for save, append, and patch."""
if code is not None and args is not None:
fn = " ".join(args)
if (
fn.startswith("save ")
or fn.startswith("append ")
or fn.startswith("patch ")
):
fn = fn.split(" ", 1)[1]
elif kwargs is not None:
fn = kwargs.get("path", "")
else:
raise ValueError("No filename provided")

return Path(fn).expanduser()
23 changes: 4 additions & 19 deletions gptme/tools/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Parameter,
ToolSpec,
ToolUse,
get_path,
)

instructions = """
Expand Down Expand Up @@ -183,27 +184,11 @@ def apply(codeblock: str, content: str) -> str:
return new_content


def get_patch_path(
code: str | None, args: list[str] | None, kwargs: dict[str, str] | None
) -> Path:
"""Get the path from args/kwargs."""
if code is not None and args is not None:
fn = " ".join(args)
if not fn:
raise ValueError("No path provided")
elif kwargs is not None:
fn = kwargs.get("path", "")
else:
raise ValueError("No path provided")

return Path(fn).expanduser()


def preview_patch(content: str, path: Path | None) -> str | None:
"""Prepare preview content for patch operation."""
try:
patches = Patch.from_codeblock(content)
return "\n\n".join(p.diff_minimal() for p in patches)
return "\n@@@\n".join(p.diff_minimal() for p in patches)
except ValueError as e:
raise ValueError(f"Invalid patch: {e.args[0]}") from None

Expand Down Expand Up @@ -261,10 +246,10 @@ def execute_patch(
kwargs,
confirm,
execute_fn=execute_patch_impl,
get_path_fn=get_patch_path,
get_path_fn=get_path,
preview_fn=preview_patch,
preview_lang="diff",
confirm_msg=None, # use default
confirm_msg=f"Apply patch to {get_path(code, args, kwargs)}?",
allow_edit=True,
)

Expand Down
103 changes: 54 additions & 49 deletions gptme/tools/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Parameter,
ToolSpec,
ToolUse,
get_path,
)
from .patch import Patch

Expand Down Expand Up @@ -59,22 +60,6 @@ def examples_append(tool_format):
""".strip()


def get_save_path(
code: str | None, args: list[str] | None, kwargs: dict[str, str] | None
) -> Path:
"""Get the path from args/kwargs."""
if code is not None and args is not None:
fn = " ".join(args)
if fn.startswith("save "):
fn = fn[5:]
elif kwargs is not None:
fn = kwargs.get("path", "")
else:
raise ValueError("No filename provided")

return Path(fn).expanduser()


def preview_save(content: str, path: Path | None) -> str | None:
"""Prepare preview content for save operation."""
assert path
Expand All @@ -86,6 +71,19 @@ def preview_save(content: str, path: Path | None) -> str | None:
return content


def preview_append(content: str, path: Path | None) -> str | None:
"""Prepare preview content for append operation."""
assert path
if path.exists():
current = path.read_text()
if not current.endswith("\n"):
current += "\n"
else:
current = ""
new = current + content
return preview_save(new, path)


def execute_save_impl(
content: str, path: Path | None, confirm: ConfirmFunc
) -> Generator[Message, None, None]:
Expand Down Expand Up @@ -115,6 +113,32 @@ def execute_save_impl(
yield Message("system", f"Saved to {path}")


def execute_append_impl(
content: str, path: Path | None, confirm: ConfirmFunc
) -> Generator[Message, None, None]:
"""Actual append implementation."""
assert path
path_display = path
path = path.expanduser()
if not path.exists():
if not confirm(f"File {path_display} doesn't exist, create it?"):
yield Message("system", "Append cancelled.")
return

# strip leading newlines
# content = content.lstrip("\n")
# ensure it ends with a newline
if not content.endswith("\n"):
content += "\n"

before = path.read_text()
if not before.endswith("\n"):
content = "\n" + content
with open(path, "a") as f:
f.write(content)
yield Message("system", f"Appended to {path_display}")


def execute_save(
code: str | None,
args: list[str] | None,
Expand All @@ -128,10 +152,10 @@ def execute_save(
kwargs,
confirm,
execute_fn=execute_save_impl,
get_path_fn=get_save_path,
get_path_fn=get_path,
preview_fn=preview_save,
preview_lang="diff",
confirm_msg=None, # use default
confirm_msg=f"Save to {get_path(code, args, kwargs)}?",
allow_edit=True,
)

Expand All @@ -143,37 +167,18 @@ def execute_append(
confirm: ConfirmFunc,
) -> Generator[Message, None, None]:
"""Append code to a file."""

fn = ""
content = ""
if code is not None and args is not None:
fn = " ".join(args)
# strip leading newlines
content = code.lstrip("\n")
# ensure it ends with a newline
if not content.endswith("\n"):
content += "\n"
elif kwargs is not None:
content = kwargs["content"]
fn = kwargs["path"]

assert fn, "No filename provided"
assert content, "No content provided"

if not confirm(f"Append to {fn}?"):
# early return
yield Message("system", "Append cancelled.")
return

path = Path(fn).expanduser()

if not path.exists():
yield Message("system", f"File {fn} doesn't exist, can't append to it.")
return

with open(path, "a") as f:
f.write(content)
yield Message("system", f"Appended to {fn}")
yield from execute_with_confirmation(
code,
args,
kwargs,
confirm,
execute_fn=execute_append_impl,
get_path_fn=get_path,
preview_fn=preview_append,
preview_lang="diff",
confirm_msg=f"Append to {get_path(code, args, kwargs)}?",
allow_edit=True,
)


tool_save = ToolSpec(
Expand Down

0 comments on commit 23f81cf

Please sign in to comment.