From 827adcc5d9430c1c133a9514fd24265f70ddb497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Tue, 10 Oct 2023 11:29:28 +0200 Subject: [PATCH] fix: further improvements to main loop --- gptme/cli.py | 123 ++++++++++++++++++++++++--------------------------- 1 file changed, 57 insertions(+), 66 deletions(-) diff --git a/gptme/cli.py b/gptme/cli.py index 6127741d..f1f2bef1 100644 --- a/gptme/cli.py +++ b/gptme/cli.py @@ -284,43 +284,37 @@ def main( logfile = get_logfile(name, interactive=not prompts and interactive) print(f"Using logdir {logfile.parent}") - logmanager = LogManager.load( - logfile, initial_msgs=promptmsgs, show_hidden=show_hidden - ) + log = LogManager.load(logfile, initial_msgs=promptmsgs, show_hidden=show_hidden) # print log - logmanager.print() + log.print() print("--- ^^^ past messages ^^^ ---") # check if any prompt is a full path, if so, replace it with the contents of that file + # TODO: add support for directories + # TODO: maybe do this for all prompts, not just those passed on cli prompts = [ f"```{p}\n{Path(p).expanduser().read_text()}\n```" - if Path(p).expanduser().exists() + if Path(p).expanduser().exists() and Path(p).expanduser().is_file() else p for p in prompts ] - cli_prompted = bool(prompts) + # join prompts, grouped by `-` if present, since that's the separator for multiple-round prompts + prompts = [p.strip() for p in "\n\n".join(prompts).split("\n\n-") if p] # main loop - ctx = loop(logmanager, no_confirm, model, llm) while True: - # if prompts given on cli: - # - insert prompt into logmanager - # - if a prompt is `-`, wait for reply before sending next prompt - # - set cli_prompted - while prompts: - if prompts[0] == "-": - prompts.pop(0) - break - logmanager.append(Message("user", prompts.pop(0))) + # if prompts given on cli, insert next prompt into log + if prompts: + prompt = prompts.pop(0) + log.append(Message("user", prompt)) - msg = next(ctx) - logmanager.append(msg) + for msg in loop(log, no_confirm, model, llm, stream=stream): + log.append(msg) - # if prompts have been ran and is non-interactive, exit - # this is used in testing - if cli_prompted and not prompts and not interactive: - logger.info("Command triggered and not in TTY, exiting") + # if non-interactive and prompts have been exhausted, exit + if not interactive and not prompts: + logger.info("Non-interactive and exhausted prompts, exiting") exit(0) @@ -331,54 +325,51 @@ def loop( llm: LLMChoice, stream: bool = True, ) -> Generator[Message, None, None]: + """Runs a single pass of the chat.""" + # if last message was from assistant, try to run tools again if log[-1].role == "assistant": yield from execute_msg(log[-1], ask=not no_confirm) - while True: - # execute user command - if log[-1].role == "user": - inquiry = log[-1].content - # if message starts with ., treat as command - # when command has been run, - if inquiry.startswith(".") or inquiry.startswith("$"): - yield from handle_cmd(inquiry, log, no_confirm=no_confirm) - # we need to re-assign `log` here since it may be replaced by `handle_cmd` - # FIXME: this is pretty bad hack to get things working, needs to be refactored - if inquiry != ".continue": - continue - - # If last message was a response, ask for input. - # If last message was from the user (such as from crash/edited log), - # then skip asking for input and generate response - last_msg = log[-1] if log else None - if not last_msg or ((last_msg.role in ["system", "assistant"])): - inquiry = prompt_user() - if not inquiry: - # Empty command, ask for input again - print() - continue - yield Message("user", inquiry, quiet=True) - - # print response - try: - # performs reduction/context trimming, if necessary - msgs = log.prepare_messages() - - # append temporary message with current context, right before user message - # NOTE: in my experience, this confused the model more than it helped - # msgs = msgs[:-1] + [_gen_context_msg()] + msgs[-1:] - - # generate response - msg_response = reply(msgs, model, stream) - - # log response and run tools - if msg_response: - msg_response.quiet = True - yield msg_response - yield from execute_msg(msg_response, ask=not no_confirm) - except KeyboardInterrupt: - yield Message("system", "Interrupted") + # execute user command + if log[-1].role == "user": + inquiry = log[-1].content + # if message starts with ., treat as command + # when command has been run, + if inquiry.startswith(".") or inquiry.startswith("$"): + yield from handle_cmd(inquiry, log, no_confirm=no_confirm) + # we need to re-assign `log` here since it may be replaced by `handle_cmd` + # FIXME: this is pretty bad hack to get things working, needs to be refactored + if inquiry != ".continue": + return + + # If last message was a response, ask for input. + # If last message was from the user (such as from crash/edited log), + # then skip asking for input and generate response + last_msg = log[-1] if log else None + if not last_msg or (last_msg.role in ["system", "assistant"]): + inquiry = prompt_user() + if not inquiry: + # Empty command, ask for input again + print() + return + yield Message("user", inquiry, quiet=True) + + # print response + try: + # performs reduction/context trimming, if necessary + msgs = log.prepare_messages() + + # generate response + msg_response = reply(msgs, model, stream) + + # log response and run tools + if msg_response: + msg_response.quiet = True + yield msg_response + yield from execute_msg(msg_response, ask=not no_confirm) + except KeyboardInterrupt: + yield Message("system", "Interrupted") def get_name(name: str) -> Path: