From d421cc8c58e69f745a7c5e76b20e8614a990f075 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Tue, 15 Oct 2024 10:57:24 +0200 Subject: [PATCH] refactor: work on programmatic interface, refactored LogManager into mutable manager and immutable Log dataclass, added wip treeofthought script --- gptme/chat.py | 46 +++++---- gptme/cli.py | 4 +- gptme/commands.py | 66 +++++++------ gptme/logmanager.py | 169 ++++++++++++++++++---------------- gptme/message.py | 5 + gptme/server/api.py | 16 ++-- gptme/tools/chats.py | 2 +- scripts/list_user_messages.py | 7 +- scripts/treeofthoughts.py | 93 +++++++++++++++++++ 9 files changed, 262 insertions(+), 146 deletions(-) create mode 100644 scripts/treeofthoughts.py diff --git a/gptme/chat.py b/gptme/chat.py index aa752886..e4008d26 100644 --- a/gptme/chat.py +++ b/gptme/chat.py @@ -15,7 +15,7 @@ from .init import init from .interrupt import clear_interruptible, set_interruptible from .llm import reply -from .logmanager import LogManager +from .logmanager import Log, LogManager, prepare_messages from .message import Message from .models import get_model from .tools import ToolUse, execute_msg, has_tool @@ -59,9 +59,7 @@ def chat( stream = False console.log(f"Using logdir {path_with_tilde(logdir)}") - log = LogManager.load( - logdir, initial_msgs=initial_msgs, show_hidden=show_hidden, create=True - ) + manager = LogManager.load(logdir, initial_msgs=initial_msgs, create=True) # change to workspace directory # use if exists, create if @log, or use given path @@ -82,13 +80,13 @@ def chat( # check if message is already in log, such as upon resume if ( workspace_prompt - and workspace_prompt not in [m.content for m in log] - and "user" not in [m.role for m in log] + and workspace_prompt not in [m.content for m in manager.log] + and "user" not in [m.role for m in manager.log] ): - log.append(Message("system", workspace_prompt, hide=True, quiet=True)) + manager.append(Message("system", workspace_prompt, hide=True, quiet=True)) # print log - log.print() + manager.log.print(show_hidden=show_hidden) console.print("--- ^^^ past messages ^^^ ---") # main loop @@ -99,34 +97,39 @@ def chat( msg = prompt_msgs.pop(0) if not msg.content.startswith("/"): msg = _include_paths(msg) - log.append(msg) + manager.append(msg) # if prompt is a user-command, execute it - if execute_cmd(msg, log): + if execute_cmd(msg, manager): continue # Generate and execute response for this prompt while True: set_interruptible() try: - response_msgs = list(step(log, no_confirm, stream=stream)) + response_msgs = list(step(manager, no_confirm, stream=stream)) except KeyboardInterrupt: console.log("Interrupted. Stopping current execution.") - log.append(Message("system", "Interrupted")) + manager.append(Message("system", "Interrupted")) break finally: clear_interruptible() for response_msg in response_msgs: - log.append(response_msg) + manager.append(response_msg) # run any user-commands, if msg is from user if response_msg.role == "user" and execute_cmd( - response_msg, log + response_msg, manager ): break # Check if there are any runnable tools left last_content = next( - (m.content for m in reversed(log) if m.role == "assistant"), "" + ( + m.content + for m in reversed(manager.log) + if m.role == "assistant" + ), + "", ) if not any( tooluse.is_runnable @@ -148,19 +151,22 @@ def chat( # ask for input if no prompt, generate reply, and run tools clear_interruptible() # Ensure we're not interruptible during user input - for msg in step(log, no_confirm, stream=stream): # pragma: no cover - log.append(msg) + for msg in step(manager, no_confirm, stream=stream): # pragma: no cover + manager.append(msg) # run any user-commands, if msg is from user - if msg.role == "user" and execute_cmd(msg, log): + if msg.role == "user" and execute_cmd(msg, manager): break def step( - log: LogManager, + log: Log | LogManager, no_confirm: bool, stream: bool = True, ) -> Generator[Message, None, None]: """Runs a single pass of the chat.""" + if isinstance(log, LogManager): + log = log.log + # If last message was a response, ask for input. # If last message was from the user (such as from crash/edited log), # then skip asking for input and generate response @@ -184,7 +190,7 @@ def step( set_interruptible() try: # performs reduction/context trimming, if necessary - msgs = log.prepare_messages() + msgs = prepare_messages(log.messages) for m in msgs: logger.debug(f"Prepared message: {m}") diff --git a/gptme/cli.py b/gptme/cli.py index 38b253cf..6cb74a9d 100644 --- a/gptme/cli.py +++ b/gptme/cli.py @@ -17,7 +17,7 @@ from .dirs import get_logs_dir from .init import init_logging from .interrupt import handle_keyboard_interrupt -from .logmanager import Conversation, get_user_conversations +from .logmanager import ConversationMeta, get_user_conversations from .message import Message from .prompts import get_prompt from .tools import all_tools, init_tools @@ -273,7 +273,7 @@ def pick_log(limit=20) -> Path: # pragma: no cover NEW_CONV = "New conversation" LOAD_MORE = "Load more" gen_convs = get_user_conversations() - convs: list[Conversation] = [] + convs: list[ConversationMeta] = [] # load conversations convs.extend(islice(gen_convs, limit)) diff --git a/gptme/commands.py b/gptme/commands.py index 47b7fcd1..b25f3df1 100644 --- a/gptme/commands.py +++ b/gptme/commands.py @@ -6,7 +6,7 @@ from typing import Literal from . import llm -from .logmanager import LogManager +from .logmanager import LogManager, prepare_messages from .message import ( Message, len_tokens, @@ -68,7 +68,7 @@ def execute_cmd(msg: Message, log: LogManager) -> bool: def handle_cmd( - cmd: str, log: LogManager, no_confirm: bool + cmd: str, manager: LogManager, no_confirm: bool ) -> Generator[Message, None, None]: """Handles a command.""" cmd = cmd.lstrip("/") @@ -77,44 +77,44 @@ def handle_cmd( full_args = cmd.split(" ", 1)[1] if " " in cmd else "" match name: case "log": - log.undo(1, quiet=True) - log.print(show_hidden="--hidden" in args) + manager.undo(1, quiet=True) + manager.log.print(show_hidden="--hidden" in args) case "rename": - log.undo(1, quiet=True) - log.write() + manager.undo(1, quiet=True) + manager.write() # rename the conversation print("Renaming conversation (enter empty name to auto-generate)") new_name = args[0] if args else input("New name: ") - rename(log, new_name, ask=not no_confirm) + rename(manager, new_name, ask=not no_confirm) case "fork": # fork the conversation new_name = args[0] if args else input("New name: ") - log.fork(new_name) + manager.fork(new_name) case "summarize": - msgs = log.prepare_messages() + msgs = prepare_messages(manager.log.messages) msgs = [m for m in msgs if not m.hide] summary = llm.summarize(msgs) print(f"Summary: {summary}") case "edit": # edit previous messages # first undo the '/edit' command itself - log.undo(1, quiet=True) - yield from edit(log) + manager.undo(1, quiet=True) + yield from edit(manager) case "undo": # undo the '/undo' command itself - log.undo(1, quiet=True) + manager.undo(1, quiet=True) # if int, undo n messages n = int(args[0]) if args and args[0].isdigit() else 1 - log.undo(n) + manager.undo(n) case "exit": - log.undo(1, quiet=True) - log.write() + manager.undo(1, quiet=True) + manager.write() sys.exit(0) case "replay": - log.undo(1, quiet=True) - log.write() + manager.undo(1, quiet=True) + manager.write() print("Replaying conversation...") - for msg in log.log: + for msg in manager.log: if msg.role == "assistant": for reply_msg in execute_msg(msg, ask=True): print_msg(reply_msg, oneline=False) @@ -124,8 +124,8 @@ def handle_cmd( yield msg yield from execute_msg(msg, ask=not no_confirm) case "tokens": - log.undo(1, quiet=True) - n_tokens = len_tokens(log.log) + manager.undo(1, quiet=True) + n_tokens = len_tokens(manager.log.messages) print(f"Tokens used: {n_tokens}") model = get_model() if model: @@ -133,7 +133,7 @@ def handle_cmd( if model.price_input: print(f"Cost (input): ${n_tokens * model.price_input / 1_000_000}") case "tools": - log.undo(1, quiet=True) + manager.undo(1, quiet=True) print("Available tools:") for tool in loaded_tools: print( @@ -148,18 +148,18 @@ def handle_cmd( if tooluse.is_runnable: yield from tooluse.execute(ask=not no_confirm) else: - if log.log[-1].content.strip() == "/help": + if manager.log[-1].content.strip() == "/help": # undo the '/help' command itself - log.undo(1, quiet=True) - log.write() + manager.undo(1, quiet=True) + manager.write() help() else: print("Unknown command") -def edit(log: LogManager) -> Generator[Message, None, None]: # pragma: no cover +def edit(manager: LogManager) -> Generator[Message, None, None]: # pragma: no cover # generate editable toml of all messages - t = msgs_to_toml(reversed(log.log)) # type: ignore + t = msgs_to_toml(reversed(manager.log)) # type: ignore res = None while not res: t = edit_text_with_editor(t, "toml") @@ -172,15 +172,13 @@ def edit(log: LogManager) -> Generator[Message, None, None]: # pragma: no cover except KeyboardInterrupt: yield Message("system", "Interrupted") return - log.edit(list(reversed(res))) - # now we need to redraw the log so the user isn't seeing stale messages in their buffer - # log.print() + manager.edit(list(reversed(res))) print("Applied edited messages, write /log to see the result") -def rename(log: LogManager, new_name: str, ask: bool = True): +def rename(manager: LogManager, new_name: str, ask: bool = True): if new_name in ["", "auto"]: - new_name = llm.generate_name(log.prepare_messages()) + new_name = llm.generate_name(prepare_messages(manager.log.messages)) assert " " not in new_name print(f"Generated name: {new_name}") if ask: @@ -188,10 +186,10 @@ def rename(log: LogManager, new_name: str, ask: bool = True): if not confirm: print("Aborting") return - log.rename(new_name, keep_date=True) + manager.rename(new_name, keep_date=True) else: - log.rename(new_name, keep_date=False) - print(f"Renamed conversation to {log.logfile.parent}") + manager.rename(new_name, keep_date=False) + print(f"Renamed conversation to {manager.logfile.parent}") def _gen_help(incl_langtags: bool = True) -> Generator[str, None, None]: diff --git a/gptme/logmanager.py b/gptme/logmanager.py index 770225b6..e2455c2a 100644 --- a/gptme/logmanager.py +++ b/gptme/logmanager.py @@ -3,8 +3,7 @@ import shutil import textwrap from collections.abc import Generator -from copy import copy -from dataclasses import dataclass +from dataclasses import dataclass, field, replace from datetime import datetime from itertools import islice, zip_longest from pathlib import Path @@ -25,6 +24,44 @@ RoleLiteral = Literal["user", "assistant", "system"] +@dataclass(frozen=True) +class Log: + messages: list[Message] = field(default_factory=list) + + def __getitem__(self, key): + return self.messages[key] + + def __len__(self) -> int: + return len(self.messages) + + def __iter__(self) -> Generator[Message, None, None]: + yield from self.messages + + def replace(self, **kwargs) -> "Log": + return replace(self, **kwargs) + + def append(self, msg: Message) -> "Log": + return self.replace(messages=self.messages + [msg]) + + def pop(self) -> "Log": + return self.replace(messages=self.messages[:-1]) + + @classmethod + def read_jsonl(cls, path: PathLike, limit=None) -> "Log": + gen = _gen_read_jsonl(path) + if limit: + gen = islice(gen, limit) # type: ignore + return Log(list(gen)) + + def write_jsonl(self, path: PathLike) -> None: + with open(path, "w") as file: + for msg in self.messages: + file.write(json.dumps(msg.to_dict()) + "\n") + + def print(self, show_hidden: bool = False): + print_msg(self.messages, oneline=False, show_hidden=show_hidden) + + class LogManager: """Manages a conversation log.""" @@ -33,7 +70,6 @@ def __init__( log: list[Message] | None = None, logdir: PathLike | None = None, branch: str | None = None, - show_hidden=False, ): self.current_branch = branch or "main" @@ -47,11 +83,11 @@ def __init__( self.name = self.logdir.name # load branches from adjacent files - self._branches = {self.current_branch: log or []} + self._branches = {self.current_branch: Log(log or [])} if self.logdir / "conversation.jsonl": _branch = "main" if _branch not in self._branches: - self._branches[_branch] = _read_jsonl( + self._branches[_branch] = Log.read_jsonl( self.logdir / "conversation.jsonl" ) for file in self.logdir.glob("branches/*.jsonl"): @@ -59,36 +95,29 @@ def __init__( continue _branch = file.stem if _branch not in self._branches: - self._branches[_branch] = _read_jsonl(file) + self._branches[_branch] = Log.read_jsonl(file) - self.show_hidden = show_hidden # TODO: Check if logfile has contents, then maybe load, or should it overwrite? @property - def log(self) -> list[Message]: + def log(self) -> Log: return self._branches[self.current_branch] + @log.setter + def log(self, value: Log | list[Message]) -> None: + if isinstance(value, list): + value = Log(value) + self._branches[self.current_branch] = value + @property def logfile(self) -> Path: if self.current_branch == "main": return get_logs_dir() / self.name / "conversation.jsonl" return self.logdir / "branches" / f"{self.current_branch}.jsonl" - def __getitem__(self, key): - return self.log[key] - - def __len__(self): - return len(self.log) - - def __iter__(self): - return iter(self.log) - - def __bool__(self): - return bool(self.log) - def append(self, msg: Message) -> None: """Appends a message to the log, writes the log, prints the message.""" - self.log.append(msg) + self.log = self.log.append(msg) self.write() if not msg.quiet: print_msg(msg, oneline=False) @@ -101,48 +130,47 @@ def write(self, branches=True) -> None: Path(self.logfile).parent.mkdir(parents=True, exist_ok=True) # write current branch - _write_jsonl(self.logfile, self.log) + self.log.write_jsonl(self.logfile) # write other branches # FIXME: wont write main branch if on a different branch if branches: branches_dir = self.logdir / "branches" branches_dir.mkdir(parents=True, exist_ok=True) - for branch, msgs in self._branches.items(): + for branch, log in self._branches.items(): if branch == "main": continue branch_path = branches_dir / f"{branch}.jsonl" - _write_jsonl(branch_path, msgs) - - def print(self, show_hidden: bool | None = None): - print_msg(self.log, oneline=False, show_hidden=show_hidden or self.show_hidden) + log.write_jsonl(branch_path) def _save_backup_branch(self, type="edit") -> None: """backup the current log to a new branch, usually before editing/undoing""" branch_prefix = f"{self.current_branch}-{type}-" n = len([b for b in self._branches.keys() if b.startswith(branch_prefix)]) - self._branches[f"{branch_prefix}{n}"] = copy(self.log) + self._branches[f"{branch_prefix}{n}"] = self.log self.write() - def edit(self, new_log: list[Message]) -> None: + def edit(self, new_log: Log | list[Message]) -> None: """Edits the log.""" + if isinstance(new_log, list): + new_log = Log(new_log) self._save_backup_branch(type="edit") - self._branches[self.current_branch] = new_log + self.log = new_log self.write() def undo(self, n: int = 1, quiet=False) -> None: """Removes the last message from the log.""" - undid = self[-1] if self.log else None + undid = self.log[-1] if self.log else None if undid and undid.content.startswith("/undo"): - self.log.pop() + self.log = self.log.pop() # don't save backup branch if undoing a command - if not self[-1].content.startswith("/"): + if self.log and not self.log[-1].content.startswith("/"): self._save_backup_branch(type="undo") # Doesn't work for multiple undos in a row, but useful in testing # assert undid.content == ".undo" # assert that the last message is an undo - peek = self[-1] if self.log else None + peek = self.log[-1] if self.log else None if not peek: print("[yellow]Nothing to undo.[/]") return @@ -150,28 +178,13 @@ def undo(self, n: int = 1, quiet=False) -> None: if not quiet: print("[yellow]Undoing messages:[/yellow]") for _ in range(n): - undid = self.log.pop() + undid = self.log[-1] + self.log = self.log.pop() if not quiet: print( f"[red] {undid.role}: {textwrap.shorten(undid.content.strip(), width=50, placeholder='...')}[/]", ) - peek = self[-1] if self.log else None - - def prepare_messages(self) -> list[Message]: - """Prepares the log into messages before sending it to the LLM.""" - msgs = self.log - msgs_reduced = list(reduce_log(msgs)) - - if len_tokens(msgs) != len_tokens(msgs_reduced): - logger.info( - f"Reduced log from {len_tokens(msgs)//1} to {len_tokens(msgs_reduced)//1} tokens" - ) - msgs_limited = limit_log(msgs_reduced) - if len(msgs_reduced) != len(msgs_limited): - logger.info( - f"Limited log from {len(msgs_reduced)} to {len(msgs_limited)} messages" - ) - return msgs_limited + peek = self.log[-1] if self.log else None @classmethod def load( @@ -202,13 +215,12 @@ def load( if create: logger.debug(f"Creating new logfile {logfile}") Path(logfile).parent.mkdir(parents=True, exist_ok=True) - _write_jsonl(logfile, []) + Log([]).write_jsonl(logfile) else: raise FileNotFoundError(f"Could not find logfile {logfile}") - msgs = _read_jsonl(logfile) - if not msgs: - msgs = initial_msgs or [get_prompt()] + log = Log.read_jsonl(logfile) + msgs = log.messages or initial_msgs or [get_prompt()] return cls(msgs, logdir=logdir, branch=branch, **kwargs) def branch(self, name: str) -> None: @@ -216,7 +228,7 @@ def branch(self, name: str) -> None: self.write() if name not in self._branches: logger.info(f"Creating a new branch '{name}'") - self._branches[name] = copy(self.log) + self._branches[name] = self.log self.current_branch = name def diff(self, branch: str) -> str | None: @@ -292,6 +304,22 @@ def to_dict(self, branches=False) -> dict: return d +def prepare_messages(msgs: list[Message]) -> list[Message]: + """Prepares the messages before sending to the LLM.""" + msgs_reduced = list(reduce_log(msgs)) + + if len_tokens(msgs) != len_tokens(msgs_reduced): + logger.info( + f"Reduced log from {len_tokens(msgs)//1} to {len_tokens(msgs_reduced)//1} tokens" + ) + msgs_limited = limit_log(msgs_reduced) + if len(msgs_reduced) != len(msgs_limited): + logger.info( + f"Limited log from {len(msgs_reduced)} to {len(msgs_limited)} messages" + ) + return msgs_limited + + def _conversation_files() -> list[Path]: # NOTE: only returns the main conversation, not branches (to avoid duplicates) # returns the conversation files sorted by modified time (newest first) @@ -302,7 +330,7 @@ def _conversation_files() -> list[Path]: @dataclass(frozen=True) -class Conversation: +class ConversationMeta: name: str path: str created: float @@ -311,16 +339,16 @@ class Conversation: branches: int -def get_conversations() -> Generator[Conversation, None, None]: +def get_conversations() -> Generator[ConversationMeta, None, None]: """Returns all conversations, excluding ones used for testing, evals, etc.""" for conv_fn in _conversation_files(): - msgs = _read_jsonl(conv_fn, limit=1) + log = Log.read_jsonl(conv_fn, limit=1) # TODO: can we avoid reading the entire file? maybe wont even be used, due to user convo filtering len_msgs = conv_fn.read_text().count("}\n{") - assert len(msgs) <= 1 + assert len(log) <= 1 modified = conv_fn.stat().st_mtime - first_timestamp = msgs[0].timestamp.timestamp() if msgs else modified - yield Conversation( + first_timestamp = log[0].timestamp.timestamp() if log else modified + yield ConversationMeta( name=f"{conv_fn.parent.name}", path=str(conv_fn), created=first_timestamp, @@ -330,7 +358,7 @@ def get_conversations() -> Generator[Conversation, None, None]: ) -def get_user_conversations() -> Generator[Conversation, None, None]: +def get_user_conversations() -> Generator[ConversationMeta, None, None]: """Returns all user conversations, excluding ones used for testing, evals, etc.""" for conv in get_conversations(): if any(conv.name.startswith(prefix) for prefix in ["tmp", "test-"]) or any( @@ -348,16 +376,3 @@ def _gen_read_jsonl(path: PathLike) -> Generator[Message, None, None]: if "timestamp" in json_data: json_data["timestamp"] = datetime.fromisoformat(json_data["timestamp"]) yield Message(**json_data, files=files) - - -def _read_jsonl(path: PathLike, limit=None) -> list[Message]: - gen = _gen_read_jsonl(path) - if limit: - gen = islice(gen, limit) # type: ignore - return list(gen) - - -def _write_jsonl(path: PathLike, msgs: list[Message]) -> None: - with open(path, "w") as file: - for msg in msgs: - file.write(json.dumps(msg.to_dict()) + "\n") diff --git a/gptme/message.py b/gptme/message.py index 2b6f98cd..1246d9fb 100644 --- a/gptme/message.py +++ b/gptme/message.py @@ -172,6 +172,11 @@ def to_dict(self, keys=None, openai=False, anthropic=False) -> dict: return {k: d[k] for k in keys} return d + def to_xml(self) -> str: + """Converts a message to an XML string.""" + attrs = f"role='{self.role}'" + return f"\n{self.content}\n" + def format(self, oneline: bool = False, highlight: bool = False) -> str: return format_msgs([self], oneline=oneline, highlight=highlight)[0] diff --git a/gptme/server/api.py b/gptme/server/api.py index 4061f51a..e0c6450c 100644 --- a/gptme/server/api.py +++ b/gptme/server/api.py @@ -18,7 +18,7 @@ from ..commands import execute_cmd from ..dirs import get_logs_dir from ..llm import reply -from ..logmanager import LogManager, get_user_conversations +from ..logmanager import LogManager, get_user_conversations, prepare_messages from ..message import Message from ..models import get_model from ..tools import execute_msg @@ -91,26 +91,26 @@ def api_conversation_generate(logfile: str): model = req_json.get("model", get_model().model) # load conversation - log = LogManager.load(logfile, branch=req_json.get("branch", "main")) + manager = LogManager.load(logfile, branch=req_json.get("branch", "main")) # if prompt is a user-command, execute it - if log[-1].role == "user": + if manager.log[-1].role == "user": # TODO: capture output of command and return it f = io.StringIO() print("Begin capturing stdout, to pass along command output.") with redirect_stdout(f): - resp = execute_cmd(log[-1], log) + resp = execute_cmd(manager.log[-1], manager) print("Done capturing stdout.") if resp: - log.write() + manager.write() output = f.getvalue() return flask.jsonify( [{"role": "system", "content": output, "stored": False}] ) # performs reduction/context trimming, if necessary - msgs = log.prepare_messages() + msgs = prepare_messages(manager.log.messages) # generate response # TODO: add support for streaming @@ -119,10 +119,10 @@ def api_conversation_generate(logfile: str): # log response and run tools resp_msgs = [] - log.append(msg) + manager.append(msg) resp_msgs.append(msg) for reply_msg in execute_msg(msg, ask=False): - log.append(reply_msg) + manager.append(reply_msg) resp_msgs.append(reply_msg) return flask.jsonify( diff --git a/gptme/tools/chats.py b/gptme/tools/chats.py index 991b3f02..a009fda5 100644 --- a/gptme/tools/chats.py +++ b/gptme/tools/chats.py @@ -47,7 +47,7 @@ def _summarize_conversation( summary_lines = [] if include_summary: - summary = llm_summarize(log_manager.log) + summary = llm_summarize(log_manager.log.messages) summary_lines.append(indent(f"Summary: {summary.content}", " ")) else: non_system_messages = [msg for msg in log_manager.log if msg.role != "system"] diff --git a/scripts/list_user_messages.py b/scripts/list_user_messages.py index 66a49c38..4459b2fb 100644 --- a/scripts/list_user_messages.py +++ b/scripts/list_user_messages.py @@ -1,21 +1,20 @@ import logging from datetime import datetime -from gptme.logmanager import Conversation, _read_jsonl, get_user_conversations +from gptme.logmanager import ConversationMeta, Log, get_user_conversations # Set up logging logging.basicConfig(level=logging.ERROR) -def print_user_messages(conv: Conversation): +def print_user_messages(conv: ConversationMeta): """ Print all user messages from a single conversation. :param conversation: A dictionary containing conversation details """ lines = [] - msgs = _read_jsonl(conv.path) - for message in msgs: + for message in Log.read_jsonl(conv.path): if message.role == "user": first_line = message.content.split("\n")[0] if first_line.startswith(""): diff --git a/scripts/treeofthoughts.py b/scripts/treeofthoughts.py new file mode 100644 index 00000000..fdb01233 --- /dev/null +++ b/scripts/treeofthoughts.py @@ -0,0 +1,93 @@ +""" +Tree-branching conversations for gptme with branch evaluation/prediction. + +The idea is to evaluate if we are on the right track by checking if the current branch is "good"/making progress, and otherwise backtracking to the last good branch and trying a different prompt/approach. +""" + +import sys +from typing import Literal + +from gptme.chat import step as _step +from gptme.init import init +from gptme.logmanager import Log +from gptme.message import Message +from gptme.prompts import get_prompt +from lxml import etree + +EvalAction = Literal["continue", "undo", "done"] + + +def step(log: Log) -> Log: + # Steps the conversation forward + for msg in _step(log, no_confirm=True): + log = log.append(msg) + return log + + +def recommendation(log: Log) -> EvalAction: + # Returns a LLM-guided recommendation for the next action + # Can be: undo (backtrack), restart, continue, + system_msg = Message( + "system", + """ +We are evaluating the agent in the following conversation to determine the next action. + +Please evaluate the current state of the conversation, +if the agent is making progress or if we should undo, +and then recommend the next action within tags. + +For example: +continue to let the agent continue (making progress) +undo to backtrack until last user prompt (made a mistake) +done if the agent has completed the task (e.g. answered the question) +""", + ) + log_xml = "Here is the conversation to evaluate:\n" + for msg in log: + log_xml += msg.to_xml() + "\n" + log = Log( + [system_msg] + + [Message("system", log_xml)] + + [Message("user", "evaluate the agent")] + ) + log = step(log) # TODO: use faster model for this + parser = etree.HTMLParser() + tree = etree.fromstring(log[-1].content, parser) + return tree.xpath("//action")[0].text + + +print("Init...") +init( + model="openai/gpt-4o", + interactive=False, + tool_allowlist=["python", "shell", "save", "patch"], +) + +# Set up the conversation +prompt = sys.argv[1] if len(sys.argv) > 1 else "What is fib(10)?" +prompts = [Message("user", prompt)] +initial_msgs = [get_prompt("full", interactive=False)] +log = Log(initial_msgs + prompts) + +while True: + # Step it forward + print("Stepping...") + log = step(log) + print("Done with step") + + # Evaluate the conversation + action = recommendation(log) + print(f"Recommendation: {action}") + + # Take the recommended action + if action == "continue": + continue + elif action == "undo": + log = log.pop() + elif action == "done": + break + else: + raise ValueError(f"Invalid action: {action}") + +# Print the final conversation +log.print()