Skip to content

Commit

Permalink
fix: fixed prompt chaining, added test (fixes #106)
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Sep 20, 2024
1 parent 838a898 commit deac8db
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 24 deletions.
64 changes: 40 additions & 24 deletions gptme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@


docstring = f"""
GPTMe, a chat-CLI for LLMs, enabling them to execute commands and code.
gptme is a chat-CLI for LLMs, empowering them with tools to run shell commands, execute code, read and manipulate files, and more.
If PROMPTS are provided, a new conversation will be started with it.
If one of the PROMPTS is '{MULTIPROMPT_SEPARATOR}', following prompts will run after the assistant is done answering the first one.
If one of the PROMPTS is '{MULTIPROMPT_SEPARATOR}', the PROMPTS will form a chain,
where following prompts will be submitted after the assistant is done answering the previous one.
The interface provides user commands that can be used to interact with the system.
Expand Down Expand Up @@ -176,7 +177,7 @@ def main(
if resume:
name = "resume" # magic string to load last conversation

# join prompts, grouped by `-` if present, since that's the separator for multiple-round prompts
# join prompts, grouped by `-` if present, since that's the separator for "chained"/multiple-round prompts
sep = "\n\n" + MULTIPROMPT_SEPARATOR
prompts = [p.strip() for p in "\n\n".join(prompts).split(sep) if p]
prompt_msgs = [Message("user", p) for p in prompts]
Expand Down Expand Up @@ -260,34 +261,49 @@ def chat(

# main loop
while True:
# if prompt_msgs given, insert next prompt into log
# if prompt_msgs given, process each prompt fully before moving to the next
if prompt_msgs:
msg = prompt_msgs.pop(0)
if not msg.content.startswith("/"):
msg = _include_paths(msg)
log.append(msg)
# if prompt is a user-command, execute it
if execute_cmd(msg, log):
continue
while prompt_msgs:
msg = prompt_msgs.pop(0)
if not msg.content.startswith("/"):
msg = _include_paths(msg)
log.append(msg)
# if prompt is a user-command, execute it
if execute_cmd(msg, log):
continue

# Generate and execute response for this prompt
while True:
response_msgs = list(step(log, no_confirm, stream=stream))
for response_msg in response_msgs:
log.append(response_msg)
# run any user-commands, if msg is from user
if response_msg.role == "user" and execute_cmd(
response_msg, log
):
break

# Check if there are any runnable tools left
last_content = next(
(m.content for m in reversed(log) if m.role == "assistant"), ""
)
if not any(
tooluse.is_runnable
for tooluse in ToolUse.iter_from_content(last_content)
):
break

# All prompts processed, continue to next iteration
continue

# if:
# - prompts exhausted
# - non-interactive
# - no executable block in last assistant message
# then exit
elif not interactive:
# noreorder

# continue if we can run tools on the last message
last_content = next(
(m.content for m in reversed(log) if m.role == "assistant"), ""
)
tooluses = list(ToolUse.iter_from_content(last_content))
runnable = False
if tooluses:
runnable = any(tooluse.is_runnable for tooluse in tooluses)
if not runnable:
logger.debug("Non-interactive and exhausted prompts, exiting")
break
logger.debug("Non-interactive and exhausted prompts, exiting")
break

# ask for input if no prompt, generate reply, and run tools
for msg in step(log, no_confirm, stream=stream): # pragma: no cover
Expand Down
23 changes: 23 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,29 @@ def test_stdin(args: list[str], runner: CliRunner):
assert result.exit_code == 0


@pytest.mark.slow
def test_chain(args: list[str], runner: CliRunner):
    """Check that a "-" argument chains prompts, with each chained prompt submitted only after the agent has exhausted the previous one."""
    # The first prompt must require two tools (save, then patch) so we can
    # verify that both have run before the chained prompt is submitted.
    args.extend(["write a test.txt file, then patch it", "-", "read the contents"])
    result = runner.invoke(gptme.cli.main, args)
    output = result.output
    print(output)

    # Locate the markers whose relative order matters. str.index/str.rindex
    # raise ValueError when a marker is absent, which also fails the test.
    user1_loc = output.index("User:")
    user2_loc = output.index("User:", user1_loc + 1)
    save_loc = output.index("```save")
    patch_loc = output.index("```patch")
    print_loc = output.rindex("cat test.txt")
    # Local names are part of the printed text because of the `=` specifier.
    print(f"{user1_loc=} {save_loc=} {patch_loc=} {user2_loc=} {print_loc=}")

    # Both user prompts appear in order, and the save and patch tool uses from
    # the first prompt precede the second prompt, whose output comes last.
    assert user1_loc < user2_loc
    assert save_loc < patch_loc < user2_loc < print_loc
    assert result.exit_code == 0


# TODO: move elsewhere
@pytest.mark.slow
def test_tmux(args: list[str], runner: CliRunner):
Expand Down

0 comments on commit deac8db

Please sign in to comment.