From 059b4d49e2da96325474300384aa5c40a6e93050 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?=
Date: Fri, 1 Nov 2024 23:10:32 +0100
Subject: [PATCH] feat: migrate from readline to prompt_toolkit, with new
 features planned

---
 gptme/chat.py   | 28 ++++++++------
 gptme/dirs.py   |  4 ++
 gptme/prompt.py | 92 +++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 112 insertions(+), 12 deletions(-)
 create mode 100644 gptme/prompt.py

diff --git a/gptme/chat.py b/gptme/chat.py
index 35b1be8f..79a5cfa9 100644
--- a/gptme/chat.py
+++ b/gptme/chat.py
@@ -17,7 +17,7 @@
 from .logmanager import Log, LogManager, prepare_messages
 from .message import Message
 from .models import get_model
-from .readline import add_history
+from .prompt import add_history, get_input
 from .tools import ToolUse, execute_msg, has_tool
 from .tools.base import ConfirmFunc
 from .tools.browser import read_url
@@ -224,27 +224,31 @@ def prompt_user(value=None) -> str:  # pragma: no cover
     try:
         set_interruptible()
         response = prompt_input(PROMPT_USER, value)
+        if response:
+            add_history(response)
     except KeyboardInterrupt:
         print("\nInterrupted. Press Ctrl-D to exit.")
+    except EOFError:
+        print("\nGoodbye!")
+        sys.exit(0)
     clear_interruptible()
-    if response:
-        add_history(response)  # readline history
     return response
 
 
 def prompt_input(prompt: str, value=None) -> str:  # pragma: no cover
+    """Get input using prompt_toolkit with fish-style suggestions."""
     prompt = prompt.strip() + ": "
     if value:
         console.print(prompt + value)
-    else:
-        prompt = rich_to_str(prompt, color_system="256")
-
-        # https://stackoverflow.com/a/53260487/965332
-        original_stdout = sys.stdout
-        sys.stdout = sys.__stdout__
-        value = input(prompt.strip() + " ")
-        sys.stdout = original_stdout
-    return value
+        return value
+    prompt = rich_to_str(prompt, color_system="256")
+
+    # TODO: Implement LLM suggestions
+    def get_suggestions(text: str) -> list[str]:
+        # This would be replaced with actual LLM suggestions
+        return []
+
+    return get_input(prompt, llm_suggest_callback=get_suggestions)
 
 
 def _include_paths(msg: Message) -> Message:

diff --git a/gptme/dirs.py b/gptme/dirs.py
index 8af098b9..7e6fd5b5 100644
--- a/gptme/dirs.py
+++ b/gptme/dirs.py
@@ -13,6 +13,10 @@ def get_readline_history_file() -> Path:
     return get_config_dir() / "history"
 
 
+def get_pt_history_file() -> Path:
+    return get_data_dir() / "history.pt"
+
+
 def get_data_dir() -> Path:
     # used in testing, so must take precedence
     if "XDG_DATA_HOME" in os.environ:

diff --git a/gptme/prompt.py b/gptme/prompt.py
new file mode 100644
index 00000000..1b32cd26
--- /dev/null
+++ b/gptme/prompt.py
@@ -0,0 +1,92 @@
+import logging
+from collections.abc import Callable
+
+from prompt_toolkit import PromptSession
+from prompt_toolkit.completion import Completer, Completion, PathCompleter
+from prompt_toolkit.formatted_text import ANSI, HTML, to_formatted_text
+from prompt_toolkit.history import FileHistory
+
+from .commands import COMMANDS
+from .dirs import get_pt_history_file
+
+logger = logging.getLogger(__name__)
+
+
+class GptmeCompleter(Completer):
+    """Completer that combines command, path and LLM suggestions."""
+
+    def __init__(self, llm_suggest_callback: Callable[[str], list[str]] | None = None):
+        self.path_completer = PathCompleter()
+        self.llm_suggest_callback = llm_suggest_callback
+
+    def get_completions(self, document, complete_event):
+        text = document.text_before_cursor
+
+        # Command completion
+        if text.startswith("/"):
+            cmd_text = text[1:]
+            for cmd in COMMANDS:
+                if cmd.startswith(cmd_text):
+                    yield Completion(
+                        cmd,
+                        start_position=-len(cmd_text),
+                        display=HTML(f"/{cmd}"),
+                    )
+
+        # Path completion
+        elif any(text.startswith(prefix) for prefix in ["../", "~/", "./"]):
+            yield from self.path_completer.get_completions(document, complete_event)
+
+        # LLM suggestions
+        elif self.llm_suggest_callback and len(text) > 2:
+            try:
+                suggestions = self.llm_suggest_callback(text)
+                if suggestions:
+                    for suggestion in suggestions:
+                        if suggestion.startswith(text):
+                            yield Completion(
+                                suggestion,
+                                start_position=-len(text),
+                                display_meta="AI suggestion",
+                            )
+            except Exception:
+                # Fail silently if LLM suggestions time out or fail
+                pass
+
+
+def create_prompt_session(
+    llm_suggest_callback: Callable[[str], list[str]] | None = None,
+) -> PromptSession:
+    """Create a PromptSession with history and completion support."""
+    history = FileHistory(str(get_pt_history_file()))
+    completer = GptmeCompleter(llm_suggest_callback)
+
+    return PromptSession(
+        history=history,
+        completer=completer,
+        complete_while_typing=True,
+        enable_history_search=True,
+    )
+
+
+def get_input(
+    prompt: str = "Human: ",
+    llm_suggest_callback: Callable[[str], list[str]] | None = None,
+) -> str:
+    """Get input from user with completion support."""
+    session = create_prompt_session(llm_suggest_callback)
+    try:
+        logger.debug(f"Original prompt: {repr(prompt)}")
+        result = session.prompt(to_formatted_text(ANSI(prompt.rstrip() + " ")))
+        return result
+    except (EOFError, KeyboardInterrupt) as e:
+        # Re-raise EOFError so the caller can handle Ctrl-D (exit) properly
+        if isinstance(e, EOFError):
+            raise
+        return ""
+
+
+def add_history(line: str) -> None:
+    """Add a line to the prompt_toolkit history."""
+    session = create_prompt_session()
+    session.history.append_string(line)
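
Reviewer note (not part of the patch): a minimal sketch of how the new entry
points in gptme/prompt.py could be exercised once an LLM-backed suggester
exists. The `suggest_continuations` helper below is hypothetical, standing in
for the planned suggestions feature; the patch itself currently wires in a
stub that returns []. `get_input` and `add_history` are the functions this
patch adds.

    from gptme.prompt import add_history, get_input

    def suggest_continuations(text: str) -> list[str]:
        # Hypothetical: would ask an LLM for likely continuations of `text`.
        # GptmeCompleter only surfaces suggestions that start with the input.
        return [text + " --help"]

    line = get_input("Human: ", llm_suggest_callback=suggest_continuations)
    if line:
        add_history(line)  # persisted via FileHistory at get_pt_history_file()

Note that `get_input` re-raises EOFError on Ctrl-D, which `prompt_user` in
gptme/chat.py turns into a clean exit.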