Skip to content

Commit

Permalink
refactor: refactor how confirmation works, enabling LLM-guided confirmation and simplifying confirmation support in server
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Oct 15, 2024
1 parent d54df51 commit b843e88
Show file tree
Hide file tree
Showing 13 changed files with 177 additions and 180 deletions.
25 changes: 15 additions & 10 deletions gptme/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from .message import Message
from .models import get_model
from .tools import ToolUse, execute_msg, has_tool
from .tools.base import ConfirmFunc
from .tools.browser import read_url
from .util import (
ask_execute,
console,
path_with_tilde,
print_bell,
Expand Down Expand Up @@ -89,6 +91,11 @@ def chat(
manager.log.print(show_hidden=show_hidden)
console.print("--- ^^^ past messages ^^^ ---")

def confirm_func(msg) -> bool:
    """Confirm a tool execution: auto-approve under --no-confirm, else prompt the user."""
    return True if no_confirm else ask_execute(msg)

# main loop
while True:
# if prompt_msgs given, process each prompt fully before moving to the next
Expand All @@ -99,16 +106,14 @@ def chat(
msg = _include_paths(msg)
manager.append(msg)
# if prompt is a user-command, execute it
if execute_cmd(msg, manager):
if execute_cmd(msg, manager, confirm_func):
continue

# Generate and execute response for this prompt
while True:
set_interruptible()
try:
response_msgs = list(
step(manager.log, no_confirm, stream=stream)
)
response_msgs = list(step(manager.log, stream, confirm_func))
except KeyboardInterrupt:
console.log("Interrupted. Stopping current execution.")
manager.append(Message("system", "Interrupted"))
Expand All @@ -120,7 +125,7 @@ def chat(
manager.append(response_msg)
# run any user-commands, if msg is from user
if response_msg.role == "user" and execute_cmd(
response_msg, manager
response_msg, manager, confirm_func
):
break

Expand Down Expand Up @@ -153,17 +158,17 @@ def chat(

# ask for input if no prompt, generate reply, and run tools
clear_interruptible() # Ensure we're not interruptible during user input
for msg in step(manager.log, no_confirm, stream=stream): # pragma: no cover
for msg in step(manager.log, stream, confirm_func): # pragma: no cover
manager.append(msg)
# run any user-commands, if msg is from user
if msg.role == "user" and execute_cmd(msg, manager):
if msg.role == "user" and execute_cmd(msg, manager, confirm_func):
break


def step(
log: Log | list[Message],
no_confirm: bool,
stream: bool = True,
stream: bool,
confirm: ConfirmFunc,
) -> Generator[Message, None, None]:
"""Runs a single pass of the chat."""
if isinstance(log, list):
Expand Down Expand Up @@ -200,7 +205,7 @@ def step(
# log response and run tools
if msg_response:
yield msg_response.replace(quiet=True)
yield from execute_msg(msg_response, ask=not no_confirm)
yield from execute_msg(msg_response, confirm=confirm)
except KeyboardInterrupt:
clear_interruptible()
yield Message("system", "Interrupted")
Expand Down
28 changes: 14 additions & 14 deletions gptme/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
)
from .models import get_model
from .tools import ToolUse, execute_msg, loaded_tools
from .tools.base import ConfirmFunc
from .useredit import edit_text_with_editor
from .util import ask_execute

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,21 +54,23 @@
COMMANDS = list(action_descriptions.keys())


def execute_cmd(msg: Message, log: LogManager, confirm: ConfirmFunc) -> bool:
    """Execute any user-command in ``msg``.

    Returns True if a command was recognized and executed, False otherwise.
    Responses produced by the command are appended to ``log``.
    """
    assert msg.role == "user"

    # if message starts with /, treat as command
    # when command has been run,
    if msg.content.startswith("/"):
        for resp in handle_cmd(msg.content, log, confirm):
            log.append(resp)
        return True
    return False


def handle_cmd(
cmd: str, manager: LogManager, no_confirm: bool
cmd: str,
manager: LogManager,
confirm: ConfirmFunc,
) -> Generator[Message, None, None]:
"""Handles a command."""
cmd = cmd.lstrip("/")
Expand All @@ -85,7 +87,7 @@ def handle_cmd(
# rename the conversation
print("Renaming conversation (enter empty name to auto-generate)")
new_name = args[0] if args else input("New name: ")
rename(manager, new_name, ask=not no_confirm)
rename(manager, new_name, confirm)
case "fork":
# fork the conversation
new_name = args[0] if args else input("New name: ")
Expand Down Expand Up @@ -116,13 +118,13 @@ def handle_cmd(
print("Replaying conversation...")
for msg in manager.log:
if msg.role == "assistant":
for reply_msg in execute_msg(msg, ask=True):
for reply_msg in execute_msg(msg, confirm):
print_msg(reply_msg, oneline=False)
case "impersonate":
content = full_args if full_args else input("[impersonate] Assistant: ")
msg = Message("assistant", content)
yield msg
yield from execute_msg(msg, ask=not no_confirm)
yield from execute_msg(msg, confirm)
case "tokens":
manager.undo(1, quiet=True)
n_tokens = len_tokens(manager.log.messages)
Expand All @@ -146,7 +148,7 @@ def handle_cmd(
# the case for python, shell, and other block_types supported by tools
tooluse = ToolUse(name, [], full_args)
if tooluse.is_runnable:
yield from tooluse.execute(ask=not no_confirm)
yield from tooluse.execute(confirm)
else:
if manager.log[-1].content.strip() == "/help":
# undo the '/help' command itself
Expand Down Expand Up @@ -176,16 +178,14 @@ def edit(manager: LogManager) -> Generator[Message, None, None]: # pragma: no c
print("Applied edited messages, write /log to see the result")


def rename(manager: LogManager, new_name: str, confirm: ConfirmFunc) -> None:
    """Rename the conversation.

    If ``new_name`` is empty or "auto", an LLM-generated name is proposed and
    applied only after ``confirm`` approves it (keeping the date prefix);
    otherwise the given name is applied as-is without the date prefix.
    """
    if new_name in ["", "auto"]:
        new_name = llm.generate_name(prepare_messages(manager.log.messages))
        # generated names must be single tokens (used in filenames/paths)
        assert " " not in new_name
        print(f"Generated name: {new_name}")
        if not confirm("Confirm?"):
            print("Aborting")
            return
        manager.rename(new_name, keep_date=True)
    else:
        manager.rename(new_name, keep_date=False)
Expand Down
3 changes: 3 additions & 0 deletions gptme/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def to_xml(self) -> str:
def format(self, oneline: bool = False, highlight: bool = False) -> str:
    """Render this message as a formatted string via format_msgs."""
    rendered = format_msgs([self], oneline=oneline, highlight=highlight)
    return rendered[0]

def print(self, oneline: bool = False, highlight: bool = True) -> None:
    """Pretty-print this message to the terminal via print_msg."""
    print_msg(self, oneline=oneline, highlight=highlight)

def to_toml(self) -> str:
"""Converts a message to a TOML string, for easy editing by hand in editor to then be parsed back."""
flags = []
Expand Down
9 changes: 7 additions & 2 deletions gptme/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def api_conversation_post(logfile: str):
return {"status": "ok"}


# TODO: add support for confirmation
def confirm_func(msg: str) -> bool:
    """Server-side confirmation stub: auto-approves every request."""
    return True


# generate response
@api.route("/api/conversations/<path:logfile>/generate", methods=["POST"])
def api_conversation_generate(logfile: str):
Expand All @@ -100,7 +105,7 @@ def api_conversation_generate(logfile: str):
f = io.StringIO()
print("Begin capturing stdout, to pass along command output.")
with redirect_stdout(f):
resp = execute_cmd(manager.log[-1], manager)
resp = execute_cmd(manager.log[-1], manager, confirm_func)
print("Done capturing stdout.")
if resp:
manager.write()
Expand All @@ -121,7 +126,7 @@ def api_conversation_generate(logfile: str):
resp_msgs = []
manager.append(msg)
resp_msgs.append(msg)
for reply_msg in execute_msg(msg, ask=False):
for reply_msg in execute_msg(msg, confirm_func):
manager.append(reply_msg)
resp_msgs.append(reply_msg)

Expand Down
8 changes: 4 additions & 4 deletions gptme/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import lru_cache

from ..message import Message
from .base import ToolSpec, ToolUse
from .base import ConfirmFunc, ToolSpec, ToolUse
from .browser import tool as browser_tool
from .chats import tool as chats_tool
from .gh import tool as gh_tool
Expand All @@ -12,12 +12,12 @@
from .python import tool as python_tool
from .read import tool as tool_read
from .save import tool_append, tool_save
from .screenshot import tool as screenshot_tool
from .shell import tool as shell_tool
from .subagent import tool as subagent_tool
from .tmux import tool as tmux_tool
from .vision import tool as vision_tool
from .youtube import tool as youtube_tool
from .screenshot import tool as screenshot_tool

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,12 +82,12 @@ def load_tool(tool: ToolSpec) -> None:
loaded_tools.append(tool)


def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None, None]:
    """Uses any tools called in a message and returns the response.

    Iterates every tool-use block found in the assistant message's content and
    yields the resulting messages; ``confirm`` gates each execution.
    """
    assert msg.role == "assistant", "Only assistant messages can be executed"

    for tooluse in ToolUse.iter_from_content(msg.content):
        yield from tooluse.execute(confirm)


# Called often when checking streaming output for executable blocks,
Expand Down
15 changes: 12 additions & 3 deletions gptme/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,18 @@
exclusive_mode = False


class ConfirmFunc(Protocol):
    """Callable asking for confirmation of the action described by ``msg``; returns True to proceed."""

    def __call__(self, msg: str) -> bool: ...


def ask_confirm(msg: str) -> bool:
    """Asks the user for confirmation."""
    answer = input(f"{msg} [y/n] ")
    return answer.lower().startswith("y")


class ExecuteFunc(Protocol):
    """Signature of a tool's execute function: run ``code`` with ``args``, gated by ``confirm``."""

    def __call__(
        self, code: str, args: list[str], confirm: ConfirmFunc
    ) -> Generator[Message, None, None]: ...


Expand Down Expand Up @@ -88,15 +97,15 @@ class ToolUse:
content: str
start: int | None = None

def execute(self, ask: bool) -> Generator[Message, None, None]:
def execute(self, confirm: ConfirmFunc) -> Generator[Message, None, None]:
"""Executes a tool-use tag and returns the output."""
# noreorder
from . import get_tool # fmt: skip

tool = get_tool(self.tool)
if tool and tool.execute:
try:
yield from tool.execute(self.content, ask, self.args)
yield from tool.execute(self.content, self.args, confirm)
except Exception as e:
# if we are testing, raise the exception
if "pytest" in globals():
Expand Down
20 changes: 11 additions & 9 deletions gptme/tools/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from pathlib import Path

from ..message import Message
from ..util import ask_execute, print_preview
from .base import ToolSpec, ToolUse
from ..util import print_preview
from .base import ConfirmFunc, ToolSpec, ToolUse

instructions = f"""
To patch/modify files, we use an adapted version of git conflict markers.
Expand Down Expand Up @@ -153,7 +153,9 @@ def apply(codeblock: str, content: str) -> str:


def execute_patch(
code: str, ask: bool, args: list[str]
code: str,
args: list[str],
confirm: ConfirmFunc,
) -> Generator[Message, None, None]:
"""
Applies the patch.
Expand All @@ -175,13 +177,13 @@ def execute_patch(
yield Message("system", f"Patch failed: {e.args[0]}")
return

# TODO: display minimal patches
# TODO: include patch headers to delimit multiple patches
print_preview(patches_str, lang="diff")
if ask:
# TODO: display minimal patches
confirm = ask_execute(f"Apply patch to {fn}?")
if not confirm:
print("Patch not applied")
return

if not confirm(f"Apply patch to {fn}?"):
print("Patch not applied")
return

try:
with open(path) as f:
Expand Down
23 changes: 10 additions & 13 deletions gptme/tools/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
)

from ..message import Message
from ..util import ask_execute, print_preview
from .base import ToolSpec, ToolUse
from ..util import print_preview
from .base import ConfirmFunc, ToolSpec, ToolUse

if TYPE_CHECKING:
from IPython.terminal.embed import InteractiveShellEmbed # fmt: skip
Expand Down Expand Up @@ -93,19 +93,16 @@ def _get_ipython():
return _ipython


def execute_python(code: str, ask: bool, args=None) -> Generator[Message, None, None]:
def execute_python(
code: str, args: list[str], confirm: ConfirmFunc = lambda _: True
) -> Generator[Message, None, None]:
"""Executes a python codeblock and returns the output."""
code = code.strip()
if ask:
print_preview(code, "python")
confirm = ask_execute()
print()
if not confirm:
# early return
yield Message("system", "Aborted, user chose not to run command.")
return
else:
print("Skipping confirmation")
print_preview(code, "python")
if not confirm(f"{code}\n\nExecute this code?"):
# early return
yield Message("system", "Aborted, user chose not to run command.")
return

# Create an IPython instance if it doesn't exist yet
_ipython = _get_ipython()
Expand Down
Loading

0 comments on commit b843e88

Please sign in to comment.