feat: migrate from readline to prompt_toolkit, with new features planned
ErikBjare committed Nov 1, 2024
1 parent 9ca2e08 commit 059b4d4
Showing 3 changed files with 119 additions and 12 deletions.
28 changes: 16 additions & 12 deletions gptme/chat.py
@@ -17,7 +17,7 @@
 from .logmanager import Log, LogManager, prepare_messages
 from .message import Message
 from .models import get_model
-from .readline import add_history
+from .prompt import add_history, get_input
 from .tools import ToolUse, execute_msg, has_tool
 from .tools.base import ConfirmFunc
 from .tools.browser import read_url
@@ -224,27 +224,31 @@ def prompt_user(value=None) -> str:  # pragma: no cover
     try:
         set_interruptible()
         response = prompt_input(PROMPT_USER, value)
-        if response:
-            add_history(response)
     except KeyboardInterrupt:
         print("\nInterrupted. Press Ctrl-D to exit.")
     except EOFError:
         print("\nGoodbye!")
         sys.exit(0)
     clear_interruptible()
+    if response:
+        add_history(response)  # prompt_toolkit history
     return response
 
 
 def prompt_input(prompt: str, value=None) -> str:  # pragma: no cover
+    """Get input using prompt_toolkit with fish-style suggestions."""
     prompt = prompt.strip() + ": "
     if value:
         console.print(prompt + value)
-    else:
-        prompt = rich_to_str(prompt, color_system="256")
-
-        # https://stackoverflow.com/a/53260487/965332
-        original_stdout = sys.stdout
-        sys.stdout = sys.__stdout__
-        value = input(prompt.strip() + " ")
-        sys.stdout = original_stdout
-    return value
+        return value
+    prompt = rich_to_str(prompt, color_system="256")
+
+    # TODO: Implement LLM suggestions
+    def get_suggestions(text: str) -> list[str]:
+        # This would be replaced with actual LLM suggestions
+        return []
+
+    return get_input(prompt, llm_suggest_callback=get_suggestions)
 
 
 def _include_paths(msg: Message) -> Message:
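The get_suggestions stub passed to get_input returns no suggestions yet; the TODO marks where LLM-backed completions are planned. In the meantime, a callback with the same (str) -> list[str] signature could be backed by the new prompt_toolkit history file. A minimal sketch, assuming the module layout this commit introduces (not code from the commit itself):

from prompt_toolkit.history import FileHistory

from gptme.dirs import get_pt_history_file


def history_suggestions(text: str) -> list[str]:
    """Suggest previously entered lines that extend the current input."""
    history = FileHistory(str(get_pt_history_file()))
    # load_history_strings() yields entries newest-first
    return [line for line in history.load_history_strings() if line.startswith(text)][:3]

Passed as llm_suggest_callback, this would give fish-style recall of past inputs until the LLM integration lands.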
4 changes: 4 additions & 0 deletions gptme/dirs.py
@@ -13,6 +13,10 @@ def get_readline_history_file() -> Path:
     return get_config_dir() / "history"
 
 
+def get_pt_history_file() -> Path:
+    return get_data_dir() / "history.pt"
+
+
 def get_data_dir() -> Path:
     # used in testing, so must take precedence
     if "XDG_DATA_HOME" in os.environ:
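Since get_data_dir checks XDG_DATA_HOME first ("used in testing, so must take precedence"), the new history file can be redirected the same way. A quick illustration, assuming get_data_dir's truncated body returns a path under that directory:

import os

from gptme.dirs import get_pt_history_file

os.environ["XDG_DATA_HOME"] = "/tmp/gptme-test"
print(get_pt_history_file())  # a path ending in history.pt under /tmp/gptme-test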
99 changes: 99 additions & 0 deletions gptme/prompt.py
@@ -0,0 +1,99 @@
import logging
from collections.abc import Callable

from prompt_toolkit import PromptSession
from prompt_toolkit.completion import Completer, Completion, PathCompleter
from prompt_toolkit.formatted_text import ANSI, HTML, to_formatted_text
from prompt_toolkit.history import FileHistory

from .commands import COMMANDS
from .dirs import get_pt_history_file

logger = logging.getLogger(__name__)


class GptmeCompleter(Completer):
    """Completer that combines command, path and LLM suggestions."""

    def __init__(self, llm_suggest_callback: Callable[[str], list[str]] | None = None):
        self.path_completer = PathCompleter()
        self.llm_suggest_callback = llm_suggest_callback

    def get_completions(self, document, complete_event):
        text = document.text_before_cursor

        # Command completion
        if text.startswith("/"):
            cmd_text = text[1:]
            for cmd in COMMANDS:
                if cmd.startswith(cmd_text):
                    yield Completion(
                        cmd,
                        start_position=-len(cmd_text),
                        display=HTML(f"<blue>/{cmd}</blue>"),
                    )

        # Path completion
        elif any(text.startswith(prefix) for prefix in ["../", "~/", "./"]):
            yield from self.path_completer.get_completions(document, complete_event)

        # LLM suggestions
        elif self.llm_suggest_callback and len(text) > 2:
            try:
                suggestions = self.llm_suggest_callback(text)
                if suggestions:
                    for suggestion in suggestions:
                        if suggestion.startswith(text):
                            yield Completion(
                                suggestion,
                                start_position=-len(text),
                                display_meta="AI suggestion",
                            )
            except Exception:
                # Fail silently if LLM suggestions timeout/fail
                pass


def create_prompt_session(
    llm_suggest_callback: Callable[[str], list[str]] | None = None,
) -> PromptSession:
    """Create a PromptSession with history and completion support."""
    history = FileHistory(str(get_pt_history_file()))
    completer = GptmeCompleter(llm_suggest_callback)

    return PromptSession(
        history=history,
        completer=completer,
        complete_while_typing=True,
        enable_history_search=True,
    )


def get_input(
    prompt: str = "Human: ",
    llm_suggest_callback: Callable[[str], list[str]] | None = None,
) -> str:
    """Get input from user with completion support."""
    session = create_prompt_session(llm_suggest_callback)
    try:
        logger.debug(f"Original prompt: {repr(prompt)}")

        # https://stackoverflow.com/a/53260487/965332
        # original_stdout = sys.stdout
        # sys.stdout = sys.__stdout__
        # value = input(prompt.strip() + " ")
        result = session.prompt(to_formatted_text(ANSI(prompt.rstrip() + " ")))
        # sys.stdout = original_stdout
        return result
    except (EOFError, KeyboardInterrupt) as e:
        # Re-raise EOFError to handle Ctrl+D properly
        if isinstance(e, EOFError):
            raise
        return ""


def add_history(line: str) -> None:
    """Add a line to the prompt_toolkit history."""
    session = create_prompt_session()
    session.history.append_string(line)
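For orientation, a minimal usage sketch of the new module as it stands after this commit; the static callback merely stands in for the planned LLM suggestions:

from gptme.prompt import add_history, get_input


def my_suggestions(text: str) -> list[str]:
    # Stand-in for future LLM-backed suggestions
    candidates = ["summarize this file", "summarize the diff"]
    return [s for s in candidates if s.startswith(text)]


line = get_input("Human: ", llm_suggest_callback=my_suggestions)
add_history(line)

Note that add_history builds a throwaway PromptSession just to reach the shared FileHistory; since append_string writes through to the history file, the next session will pick the line up.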
