Skip to content

Commit

Permalink
refactor: extracted chat and step from cli.py into new chat.py (#158
Browse files Browse the repository at this point in the history
)

* refactor: extracted `chat` and `step` from cli.py into new chat.py

* Apply suggestions from code review
  • Loading branch information
ErikBjare authored Oct 6, 2024
1 parent 1a7e4ef commit 5081c89
Show file tree
Hide file tree
Showing 5 changed files with 433 additions and 423 deletions.
378 changes: 378 additions & 0 deletions gptme/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,378 @@
import errno
import logging
import os
import re
import readline
import sys
import urllib.parse
from collections.abc import Generator
from pathlib import Path

from .commands import action_descriptions, execute_cmd
from .config import get_workspace_prompt
from .constants import PROMPT_USER
from .init import init
from .interrupt import clear_interruptible, set_interruptible
from .llm import reply
from .logmanager import LogManager
from .message import Message
from .models import get_model
from .tools import ToolUse, execute_msg, has_tool
from .tools.browser import read_url
from .util import (
console,
path_with_tilde,
print_bell,
rich_to_str,
)

logger = logging.getLogger(__name__)


def chat(
prompt_msgs: list[Message],
initial_msgs: list[Message],
logdir: Path,
model: str | None,
stream: bool = True,
no_confirm: bool = False,
interactive: bool = True,
show_hidden: bool = False,
workspace: Path | None = None,
tool_allowlist: list[str] | None = None,
):
"""
Run the chat loop.
prompt_msgs: list of messages to execute in sequence.
initial_msgs: list of history messages.
workspace: path to workspace directory, or @log to create one in the log directory.
Callable from other modules.
"""
# init
init(model, interactive, tool_allowlist)

if model and model.startswith("openai/o1") and stream:
logger.info("Disabled streaming for OpenAI's O1 (not supported)")
stream = False

console.log(f"Using logdir {path_with_tilde(logdir)}")
log = LogManager.load(
logdir, initial_msgs=initial_msgs, show_hidden=show_hidden, create=True
)

# change to workspace directory
# use if exists, create if @log, or use given path
log_workspace = logdir / "workspace"
if log_workspace.exists():
assert not workspace or (
workspace == log_workspace
), f"Workspace already exists in {log_workspace}, wont override."
workspace = log_workspace
else:
if not workspace:
workspace = Path.cwd()
assert workspace.exists(), f"Workspace path {workspace} does not exist"
console.log(f"Using workspace at {path_with_tilde(workspace)}")
os.chdir(workspace)

workspace_prompt = get_workspace_prompt(str(workspace))
# check if message is already in log, such as upon resume
if (
workspace_prompt
and workspace_prompt not in [m.content for m in log]
and "user" not in [m.role for m in log]
):
log.append(Message("system", workspace_prompt, hide=True, quiet=True))

# print log
log.print()
console.print("--- ^^^ past messages ^^^ ---")

# main loop
while True:
# if prompt_msgs given, process each prompt fully before moving to the next
if prompt_msgs:
while prompt_msgs:
msg = prompt_msgs.pop(0)
if not msg.content.startswith("/"):
msg = _include_paths(msg)
log.append(msg)
# if prompt is a user-command, execute it
if execute_cmd(msg, log):
continue

# Generate and execute response for this prompt
while True:
set_interruptible()
try:
response_msgs = list(step(log, no_confirm, stream=stream))
except KeyboardInterrupt:
console.log("Interrupted. Stopping current execution.")
log.append(Message("system", "Interrupted"))
break
finally:
clear_interruptible()

for response_msg in response_msgs:
log.append(response_msg)
# run any user-commands, if msg is from user
if response_msg.role == "user" and execute_cmd(
response_msg, log
):
break

# Check if there are any runnable tools left
last_content = next(
(m.content for m in reversed(log) if m.role == "assistant"), ""
)
if not any(
tooluse.is_runnable
for tooluse in ToolUse.iter_from_content(last_content)
):
break

# All prompts processed, continue to next iteration
continue

# if:
# - prompts exhausted
# - non-interactive
# - no executable block in last assistant message
# then exit
elif not interactive:
logger.debug("Non-interactive and exhausted prompts, exiting")
break

# ask for input if no prompt, generate reply, and run tools
clear_interruptible() # Ensure we're not interruptible during user input
for msg in step(log, no_confirm, stream=stream): # pragma: no cover
log.append(msg)
# run any user-commands, if msg is from user
if msg.role == "user" and execute_cmd(msg, log):
break


def step(
log: LogManager,
no_confirm: bool,
stream: bool = True,
) -> Generator[Message, None, None]:
"""Runs a single pass of the chat."""
# If last message was a response, ask for input.
# If last message was from the user (such as from crash/edited log),
# then skip asking for input and generate response
last_msg = log[-1] if log else None
if (
not last_msg
or (last_msg.role in ["assistant"])
or last_msg.content == "Interrupted"
or last_msg.pinned
or not any(role == "user" for role in [m.role for m in log])
): # pragma: no cover
inquiry = prompt_user()
if not inquiry:
# Empty command, ask for input again
return
msg = Message("user", inquiry, quiet=True)
msg = _include_paths(msg)
yield msg

# generate response and run tools
set_interruptible()
try:
# performs reduction/context trimming, if necessary
msgs = log.prepare_messages()

for m in msgs:
logger.debug(f"Prepared message: {m}")

# generate response
msg_response = reply(msgs, get_model().model, stream)

# log response and run tools
if msg_response:
yield msg_response.replace(quiet=True)
yield from execute_msg(msg_response, ask=not no_confirm)
except KeyboardInterrupt:
clear_interruptible()
yield Message("system", "Interrupted")
finally:
clear_interruptible()


def prompt_user(value=None) -> str: # pragma: no cover
print_bell()
set_interruptible()
try:
response = prompt_input(PROMPT_USER, value)
except KeyboardInterrupt:
print("\nInterrupted. Press Ctrl-D to exit.")
return ""
clear_interruptible()
if response:
readline.add_history(response)
return response


def prompt_input(prompt: str, value=None) -> str: # pragma: no cover
prompt = prompt.strip() + ": "
if value:
console.print(prompt + value)
else:
prompt = rich_to_str(prompt, color_system="256")

# https://stackoverflow.com/a/53260487/965332
original_stdout = sys.stdout
sys.stdout = sys.__stdout__
value = input(prompt.strip() + " ")
sys.stdout = original_stdout
return value


def _include_paths(msg: Message) -> Message:
"""
Searches the message for any valid paths and:
- appends the contents of such files as codeblocks.
- include images as files.
"""
# TODO: add support for directories?
assert msg.role == "user"

# list the current directory
cwd_files = [f.name for f in Path.cwd().iterdir()]

# match absolute, home, relative paths, and URLs anywhere in the message
# could be wrapped with spaces or backticks, possibly followed by a question mark
# don't look in codeblocks, and don't match paths that are already in codeblocks
# TODO: this will misbehave if there are codeblocks (or triple backticks) in codeblocks
content_no_codeblocks = re.sub(r"```.*?\n```", "", msg.content, flags=re.DOTALL)
append_msg = ""
for word in re.split(r"[\s`]", content_no_codeblocks):
# remove wrapping backticks
word = word.strip("`")
# remove trailing question mark
word = word.rstrip("?")
if not word:
continue
if (
# if word starts with a path character
any(word.startswith(s) for s in ["/", "~/", "./"])
# or word is a URL
or word.startswith("http")
# or word is a file in the current dir,
# or a path that starts in a folder in the current dir
or any(word.split("/", 1)[0] == file for file in cwd_files)
):
logger.debug(f"potential path/url: {word=}")
contents = _parse_prompt(word)
if contents:
# if we found a valid path, replace it with the contents of the file
append_msg += "\n\n" + contents

file = _parse_prompt_files(word)
if file:
msg.files.append(file)

# append the message with the file contents
if append_msg:
msg = msg.replace(content=msg.content + append_msg)

return msg


def _parse_prompt(prompt: str) -> str | None:
"""
Takes a string that might be a path,
and if so, returns the contents of that file wrapped in a codeblock.
"""
# if prompt is a command, exit early (as commands might take paths as arguments)
if any(
prompt.startswith(command)
for command in [f"/{cmd}" for cmd in action_descriptions.keys()]
):
return None

try:
# check if prompt is a path, if so, replace it with the contents of that file
f = Path(prompt).expanduser()
if f.exists() and f.is_file():
return f"```{prompt}\n{Path(prompt).expanduser().read_text()}\n```"
except OSError as oserr:
# some prompts are too long to be a path, so we can't read them
if oserr.errno != errno.ENAMETOOLONG:
pass
raise
except UnicodeDecodeError:
# some files are not text files (images, audio, PDFs, binaries, etc), so we can't read them
# TODO: but can we handle them better than just printing the path? maybe with metadata from `file`?
# logger.warning(f"Failed to read file {prompt}: not a text file")
return None

# check if any word in prompt is a path or URL,
# if so, append the contents as a code block
words = prompt.split()
paths = []
urls = []
for word in words:
f = Path(word).expanduser()
if f.exists() and f.is_file():
paths.append(word)
continue
try:
p = urllib.parse.urlparse(word)
if p.scheme and p.netloc:
urls.append(word)
except ValueError:
pass

result = ""
if paths or urls:
result += "\n\n"
if paths:
logger.debug(f"{paths=}")
if urls:
logger.debug(f"{urls=}")
for path in paths:
result += _parse_prompt(path) or ""

if not has_tool("browser"):
logger.warning("Browser tool not available, skipping URL read")
else:
for url in urls:
try:
content = read_url(url)
result += f"```{url}\n{content}\n```"
except Exception as e:
logger.warning(f"Failed to read URL {url}: {e}")

return result


def _parse_prompt_files(prompt: str) -> Path | None:
"""
Takes a string that might be a image path or PDF, to be attached to the message, and returns the path.
"""
allowed_exts = ["png", "jpg", "jpeg", "gif", "pdf"]

# if prompt is a command, exit early (as commands might take paths as arguments)
if any(
prompt.startswith(command)
for command in [f"/{cmd}" for cmd in action_descriptions.keys()]
):
return None

try:
# check if prompt is a path, if so, replace it with the contents of that file
p = Path(prompt)
if p.exists() and p.is_file() and p.suffix[1:] in allowed_exts:
logger.warning("Attaching file to message")
return p
else:
return None
except OSError as oserr: # pragma: no cover
# some prompts are too long to be a path, so we can't read them
if oserr.errno != errno.ENAMETOOLONG:
return None
raise
Loading

0 comments on commit 5081c89

Please sign in to comment.