diff --git a/gptme/cli.py b/gptme/cli.py index 81dc1b40..76df2755 100644 --- a/gptme/cli.py +++ b/gptme/cli.py @@ -4,7 +4,9 @@ import os import re import readline # noqa: F401 +import signal import sys +import time import urllib.parse from collections.abc import Generator from datetime import datetime @@ -185,6 +187,9 @@ def main( prompts = [p.strip() for p in "\n\n".join(prompts).split(sep) if p] prompt_msgs = [Message("user", p) for p in prompts] + # register a handler for Ctrl-C + signal.signal(signal.SIGINT, handle_keyboard_interrupt) + chat( prompt_msgs, initial_msgs, @@ -198,6 +203,44 @@ def main( ) +# Set up a KeyboardInterrupt handler to handle Ctrl-C during the chat loop +interruptible = False +last_interrupt_time = 0.0 + + +def handle_keyboard_interrupt(signum, frame): + """ + This handler allows interruption of the assistant or tool execution when in an interruptible state, + while still providing a safeguard against accidental exits during user input. + """ + global last_interrupt_time + current_time = time.time() + timeout = 1 + + if interruptible: + raise KeyboardInterrupt + + if current_time - last_interrupt_time <= timeout: + console.log("Second interrupt received, exiting...") + sys.exit(0) + + last_interrupt_time = current_time + console.print() + console.log( + f"Interrupt received. Press Ctrl-C again within {timeout} seconds to exit." + ) + + +def set_interruptible(): + global interruptible + interruptible = True + + +def clear_interruptible(): + global interruptible + interruptible = False + + def chat( prompt_msgs: list[Message], initial_msgs: list[Message], @@ -277,7 +320,16 @@ def chat( # Generate and execute response for this prompt while True: - response_msgs = list(step(log, no_confirm, stream=stream)) + set_interruptible() + try: + response_msgs = list(step(log, no_confirm, stream=stream)) + except KeyboardInterrupt: + console.log("Interrupted. Stopping current execution.") + log.append(Message("system", "Interrupted")) + break + finally: + clear_interruptible() + for response_msg in response_msgs: log.append(response_msg) # run any user-commands, if msg is from user @@ -309,6 +361,7 @@ def chat( break # ask for input if no prompt, generate reply, and run tools + clear_interruptible() # Ensure we're not interruptible during user input for msg in step(log, no_confirm, stream=stream): # pragma: no cover log.append(msg) # run any user-commands, if msg is from user @@ -341,7 +394,8 @@ def step( msg = _include_paths(msg) yield msg - # print response + # generate response and run tools + set_interruptible() try: # performs reduction/context trimming, if necessary msgs = log.prepare_messages() @@ -357,7 +411,10 @@ def step( yield msg_response.replace(quiet=True) yield from execute_msg(msg_response, ask=not no_confirm) except KeyboardInterrupt: + clear_interruptible() yield Message("system", "Interrupted") + finally: + clear_interruptible() def get_name(name: str) -> Path: