diff --git a/gptme/chat.py b/gptme/chat.py
index aa752886..e4008d26 100644
--- a/gptme/chat.py
+++ b/gptme/chat.py
@@ -15,7 +15,7 @@
from .init import init
from .interrupt import clear_interruptible, set_interruptible
from .llm import reply
-from .logmanager import LogManager
+from .logmanager import Log, LogManager, prepare_messages
from .message import Message
from .models import get_model
from .tools import ToolUse, execute_msg, has_tool
@@ -59,9 +59,7 @@ def chat(
stream = False
console.log(f"Using logdir {path_with_tilde(logdir)}")
- log = LogManager.load(
- logdir, initial_msgs=initial_msgs, show_hidden=show_hidden, create=True
- )
+ manager = LogManager.load(logdir, initial_msgs=initial_msgs, create=True)
# change to workspace directory
# use if exists, create if @log, or use given path
@@ -82,13 +80,13 @@ def chat(
# check if message is already in log, such as upon resume
if (
workspace_prompt
- and workspace_prompt not in [m.content for m in log]
- and "user" not in [m.role for m in log]
+ and workspace_prompt not in [m.content for m in manager.log]
+ and "user" not in [m.role for m in manager.log]
):
- log.append(Message("system", workspace_prompt, hide=True, quiet=True))
+ manager.append(Message("system", workspace_prompt, hide=True, quiet=True))
# print log
- log.print()
+ manager.log.print(show_hidden=show_hidden)
console.print("--- ^^^ past messages ^^^ ---")
# main loop
@@ -99,34 +97,39 @@ def chat(
msg = prompt_msgs.pop(0)
if not msg.content.startswith("/"):
msg = _include_paths(msg)
- log.append(msg)
+ manager.append(msg)
# if prompt is a user-command, execute it
- if execute_cmd(msg, log):
+ if execute_cmd(msg, manager):
continue
# Generate and execute response for this prompt
while True:
set_interruptible()
try:
- response_msgs = list(step(log, no_confirm, stream=stream))
+ response_msgs = list(step(manager, no_confirm, stream=stream))
except KeyboardInterrupt:
console.log("Interrupted. Stopping current execution.")
- log.append(Message("system", "Interrupted"))
+ manager.append(Message("system", "Interrupted"))
break
finally:
clear_interruptible()
for response_msg in response_msgs:
- log.append(response_msg)
+ manager.append(response_msg)
# run any user-commands, if msg is from user
if response_msg.role == "user" and execute_cmd(
- response_msg, log
+ response_msg, manager
):
break
# Check if there are any runnable tools left
last_content = next(
- (m.content for m in reversed(log) if m.role == "assistant"), ""
+ (
+ m.content
+ for m in reversed(manager.log)
+ if m.role == "assistant"
+ ),
+ "",
)
if not any(
tooluse.is_runnable
@@ -148,19 +151,22 @@ def chat(
# ask for input if no prompt, generate reply, and run tools
clear_interruptible() # Ensure we're not interruptible during user input
- for msg in step(log, no_confirm, stream=stream): # pragma: no cover
- log.append(msg)
+ for msg in step(manager, no_confirm, stream=stream): # pragma: no cover
+ manager.append(msg)
# run any user-commands, if msg is from user
- if msg.role == "user" and execute_cmd(msg, log):
+ if msg.role == "user" and execute_cmd(msg, manager):
break
def step(
- log: LogManager,
+ log: Log | LogManager,
no_confirm: bool,
stream: bool = True,
) -> Generator[Message, None, None]:
"""Runs a single pass of the chat."""
+ if isinstance(log, LogManager):
+ log = log.log
+
# If last message was a response, ask for input.
# If last message was from the user (such as from crash/edited log),
# then skip asking for input and generate response
@@ -184,7 +190,7 @@ def step(
set_interruptible()
try:
# performs reduction/context trimming, if necessary
- msgs = log.prepare_messages()
+ msgs = prepare_messages(log.messages)
for m in msgs:
logger.debug(f"Prepared message: {m}")
diff --git a/gptme/cli.py b/gptme/cli.py
index 38b253cf..6cb74a9d 100644
--- a/gptme/cli.py
+++ b/gptme/cli.py
@@ -17,7 +17,7 @@
from .dirs import get_logs_dir
from .init import init_logging
from .interrupt import handle_keyboard_interrupt
-from .logmanager import Conversation, get_user_conversations
+from .logmanager import ConversationMeta, get_user_conversations
from .message import Message
from .prompts import get_prompt
from .tools import all_tools, init_tools
@@ -273,7 +273,7 @@ def pick_log(limit=20) -> Path: # pragma: no cover
NEW_CONV = "New conversation"
LOAD_MORE = "Load more"
gen_convs = get_user_conversations()
- convs: list[Conversation] = []
+ convs: list[ConversationMeta] = []
# load conversations
convs.extend(islice(gen_convs, limit))
diff --git a/gptme/commands.py b/gptme/commands.py
index 47b7fcd1..b25f3df1 100644
--- a/gptme/commands.py
+++ b/gptme/commands.py
@@ -6,7 +6,7 @@
from typing import Literal
from . import llm
-from .logmanager import LogManager
+from .logmanager import LogManager, prepare_messages
from .message import (
Message,
len_tokens,
@@ -68,7 +68,7 @@ def execute_cmd(msg: Message, log: LogManager) -> bool:
def handle_cmd(
- cmd: str, log: LogManager, no_confirm: bool
+ cmd: str, manager: LogManager, no_confirm: bool
) -> Generator[Message, None, None]:
"""Handles a command."""
cmd = cmd.lstrip("/")
@@ -77,44 +77,44 @@ def handle_cmd(
full_args = cmd.split(" ", 1)[1] if " " in cmd else ""
match name:
case "log":
- log.undo(1, quiet=True)
- log.print(show_hidden="--hidden" in args)
+ manager.undo(1, quiet=True)
+ manager.log.print(show_hidden="--hidden" in args)
case "rename":
- log.undo(1, quiet=True)
- log.write()
+ manager.undo(1, quiet=True)
+ manager.write()
# rename the conversation
print("Renaming conversation (enter empty name to auto-generate)")
new_name = args[0] if args else input("New name: ")
- rename(log, new_name, ask=not no_confirm)
+ rename(manager, new_name, ask=not no_confirm)
case "fork":
# fork the conversation
new_name = args[0] if args else input("New name: ")
- log.fork(new_name)
+ manager.fork(new_name)
case "summarize":
- msgs = log.prepare_messages()
+ msgs = prepare_messages(manager.log.messages)
msgs = [m for m in msgs if not m.hide]
summary = llm.summarize(msgs)
print(f"Summary: {summary}")
case "edit":
# edit previous messages
# first undo the '/edit' command itself
- log.undo(1, quiet=True)
- yield from edit(log)
+ manager.undo(1, quiet=True)
+ yield from edit(manager)
case "undo":
# undo the '/undo' command itself
- log.undo(1, quiet=True)
+ manager.undo(1, quiet=True)
# if int, undo n messages
n = int(args[0]) if args and args[0].isdigit() else 1
- log.undo(n)
+ manager.undo(n)
case "exit":
- log.undo(1, quiet=True)
- log.write()
+ manager.undo(1, quiet=True)
+ manager.write()
sys.exit(0)
case "replay":
- log.undo(1, quiet=True)
- log.write()
+ manager.undo(1, quiet=True)
+ manager.write()
print("Replaying conversation...")
- for msg in log.log:
+ for msg in manager.log:
if msg.role == "assistant":
for reply_msg in execute_msg(msg, ask=True):
print_msg(reply_msg, oneline=False)
@@ -124,8 +124,8 @@ def handle_cmd(
yield msg
yield from execute_msg(msg, ask=not no_confirm)
case "tokens":
- log.undo(1, quiet=True)
- n_tokens = len_tokens(log.log)
+ manager.undo(1, quiet=True)
+ n_tokens = len_tokens(manager.log.messages)
print(f"Tokens used: {n_tokens}")
model = get_model()
if model:
@@ -133,7 +133,7 @@ def handle_cmd(
if model.price_input:
print(f"Cost (input): ${n_tokens * model.price_input / 1_000_000}")
case "tools":
- log.undo(1, quiet=True)
+ manager.undo(1, quiet=True)
print("Available tools:")
for tool in loaded_tools:
print(
@@ -148,18 +148,18 @@ def handle_cmd(
if tooluse.is_runnable:
yield from tooluse.execute(ask=not no_confirm)
else:
- if log.log[-1].content.strip() == "/help":
+ if manager.log[-1].content.strip() == "/help":
# undo the '/help' command itself
- log.undo(1, quiet=True)
- log.write()
+ manager.undo(1, quiet=True)
+ manager.write()
help()
else:
print("Unknown command")
-def edit(log: LogManager) -> Generator[Message, None, None]: # pragma: no cover
+def edit(manager: LogManager) -> Generator[Message, None, None]: # pragma: no cover
# generate editable toml of all messages
- t = msgs_to_toml(reversed(log.log)) # type: ignore
+ t = msgs_to_toml(reversed(manager.log)) # type: ignore
res = None
while not res:
t = edit_text_with_editor(t, "toml")
@@ -172,15 +172,13 @@ def edit(log: LogManager) -> Generator[Message, None, None]: # pragma: no cover
except KeyboardInterrupt:
yield Message("system", "Interrupted")
return
- log.edit(list(reversed(res)))
- # now we need to redraw the log so the user isn't seeing stale messages in their buffer
- # log.print()
+ manager.edit(list(reversed(res)))
print("Applied edited messages, write /log to see the result")
-def rename(log: LogManager, new_name: str, ask: bool = True):
+def rename(manager: LogManager, new_name: str, ask: bool = True):
if new_name in ["", "auto"]:
- new_name = llm.generate_name(log.prepare_messages())
+ new_name = llm.generate_name(prepare_messages(manager.log.messages))
assert " " not in new_name
print(f"Generated name: {new_name}")
if ask:
@@ -188,10 +186,10 @@ def rename(log: LogManager, new_name: str, ask: bool = True):
if not confirm:
print("Aborting")
return
- log.rename(new_name, keep_date=True)
+ manager.rename(new_name, keep_date=True)
else:
- log.rename(new_name, keep_date=False)
- print(f"Renamed conversation to {log.logfile.parent}")
+ manager.rename(new_name, keep_date=False)
+ print(f"Renamed conversation to {manager.logfile.parent}")
def _gen_help(incl_langtags: bool = True) -> Generator[str, None, None]:
diff --git a/gptme/logmanager.py b/gptme/logmanager.py
index 770225b6..e2455c2a 100644
--- a/gptme/logmanager.py
+++ b/gptme/logmanager.py
@@ -3,8 +3,7 @@
import shutil
import textwrap
from collections.abc import Generator
-from copy import copy
-from dataclasses import dataclass
+from dataclasses import dataclass, field, replace
from datetime import datetime
from itertools import islice, zip_longest
from pathlib import Path
@@ -25,6 +24,44 @@
RoleLiteral = Literal["user", "assistant", "system"]
+@dataclass(frozen=True)
+class Log:
+    """Immutable, ordered collection of conversation messages.
+
+    `append`, `pop` and `replace` never modify the instance they are called
+    on; each returns a new `Log`.
+    """
+
+    # NOTE: the dataclass is frozen but the list object itself is not;
+    # treat it as read-only and go through append()/pop() instead.
+    messages: list[Message] = field(default_factory=list)
+
+    def __getitem__(self, key):
+        """Return the message at `key` (a slice returns a plain list)."""
+        return self.messages[key]
+
+    def __len__(self) -> int:
+        return len(self.messages)
+
+    def __iter__(self) -> Generator[Message, None, None]:
+        yield from self.messages
+
+    def replace(self, **kwargs) -> "Log":
+        """Return a copy of this log with the given fields replaced."""
+        return replace(self, **kwargs)
+
+    def append(self, msg: Message) -> "Log":
+        """Return a new log with `msg` added at the end."""
+        return self.replace(messages=self.messages + [msg])
+
+    def pop(self) -> "Log":
+        """Return a new log without the last message (empty stays empty)."""
+        return self.replace(messages=self.messages[:-1])
+
+    @classmethod
+    def read_jsonl(cls, path: PathLike, limit=None) -> "Log":
+        """Load a log from the JSONL file at `path`.
+
+        `limit=None` reads every message; any integer (including 0) caps
+        the number of messages read.
+        """
+        gen = _gen_read_jsonl(path)
+        # compare against None so that limit=0 really means "no messages"
+        if limit is not None:
+            gen = islice(gen, limit)  # type: ignore
+        # build via cls so subclasses get an instance of their own type
+        return cls(list(gen))
+
+    def write_jsonl(self, path: PathLike) -> None:
+        """Overwrite `path` with the messages, one JSON object per line."""
+        with open(path, "w") as file:
+            file.writelines(json.dumps(msg.to_dict()) + "\n" for msg in self.messages)
+
+    def print(self, show_hidden: bool = False):
+        """Print all messages via print_msg, forwarding `show_hidden`."""
+        print_msg(self.messages, oneline=False, show_hidden=show_hidden)
+
class LogManager:
"""Manages a conversation log."""
@@ -33,7 +70,6 @@ def __init__(
log: list[Message] | None = None,
logdir: PathLike | None = None,
branch: str | None = None,
- show_hidden=False,
):
self.current_branch = branch or "main"
@@ -47,11 +83,11 @@ def __init__(
self.name = self.logdir.name
# load branches from adjacent files
- self._branches = {self.current_branch: log or []}
+ self._branches = {self.current_branch: Log(log or [])}
if self.logdir / "conversation.jsonl":
_branch = "main"
if _branch not in self._branches:
- self._branches[_branch] = _read_jsonl(
+ self._branches[_branch] = Log.read_jsonl(
self.logdir / "conversation.jsonl"
)
for file in self.logdir.glob("branches/*.jsonl"):
@@ -59,36 +95,29 @@ def __init__(
continue
_branch = file.stem
if _branch not in self._branches:
- self._branches[_branch] = _read_jsonl(file)
+ self._branches[_branch] = Log.read_jsonl(file)
- self.show_hidden = show_hidden
# TODO: Check if logfile has contents, then maybe load, or should it overwrite?
@property
- def log(self) -> list[Message]:
+ def log(self) -> Log:
return self._branches[self.current_branch]
+    @log.setter
+    def log(self, value: Log | list[Message]) -> None:
+        """Point the current branch at `value`, wrapping a plain list in a Log."""
+        wrapped = Log(value) if isinstance(value, list) else value
+        self._branches[self.current_branch] = wrapped
+
@property
def logfile(self) -> Path:
if self.current_branch == "main":
return get_logs_dir() / self.name / "conversation.jsonl"
return self.logdir / "branches" / f"{self.current_branch}.jsonl"
- def __getitem__(self, key):
- return self.log[key]
-
- def __len__(self):
- return len(self.log)
-
- def __iter__(self):
- return iter(self.log)
-
- def __bool__(self):
- return bool(self.log)
-
def append(self, msg: Message) -> None:
"""Appends a message to the log, writes the log, prints the message."""
- self.log.append(msg)
+ self.log = self.log.append(msg)
self.write()
if not msg.quiet:
print_msg(msg, oneline=False)
@@ -101,48 +130,47 @@ def write(self, branches=True) -> None:
Path(self.logfile).parent.mkdir(parents=True, exist_ok=True)
# write current branch
- _write_jsonl(self.logfile, self.log)
+ self.log.write_jsonl(self.logfile)
# write other branches
# FIXME: wont write main branch if on a different branch
if branches:
branches_dir = self.logdir / "branches"
branches_dir.mkdir(parents=True, exist_ok=True)
- for branch, msgs in self._branches.items():
+ for branch, log in self._branches.items():
if branch == "main":
continue
branch_path = branches_dir / f"{branch}.jsonl"
- _write_jsonl(branch_path, msgs)
-
- def print(self, show_hidden: bool | None = None):
- print_msg(self.log, oneline=False, show_hidden=show_hidden or self.show_hidden)
+ log.write_jsonl(branch_path)
def _save_backup_branch(self, type="edit") -> None:
"""backup the current log to a new branch, usually before editing/undoing"""
branch_prefix = f"{self.current_branch}-{type}-"
n = len([b for b in self._branches.keys() if b.startswith(branch_prefix)])
- self._branches[f"{branch_prefix}{n}"] = copy(self.log)
+ self._branches[f"{branch_prefix}{n}"] = self.log
self.write()
- def edit(self, new_log: list[Message]) -> None:
+ def edit(self, new_log: Log | list[Message]) -> None:
"""Edits the log."""
+ if isinstance(new_log, list):
+ new_log = Log(new_log)
self._save_backup_branch(type="edit")
- self._branches[self.current_branch] = new_log
+ self.log = new_log
self.write()
def undo(self, n: int = 1, quiet=False) -> None:
"""Removes the last message from the log."""
- undid = self[-1] if self.log else None
+ undid = self.log[-1] if self.log else None
if undid and undid.content.startswith("/undo"):
- self.log.pop()
+ self.log = self.log.pop()
# don't save backup branch if undoing a command
- if not self[-1].content.startswith("/"):
+ if self.log and not self.log[-1].content.startswith("/"):
self._save_backup_branch(type="undo")
# Doesn't work for multiple undos in a row, but useful in testing
# assert undid.content == ".undo" # assert that the last message is an undo
- peek = self[-1] if self.log else None
+ peek = self.log[-1] if self.log else None
if not peek:
print("[yellow]Nothing to undo.[/]")
return
@@ -150,28 +178,13 @@ def undo(self, n: int = 1, quiet=False) -> None:
if not quiet:
print("[yellow]Undoing messages:[/yellow]")
for _ in range(n):
- undid = self.log.pop()
+ undid = self.log[-1]
+ self.log = self.log.pop()
if not quiet:
print(
f"[red] {undid.role}: {textwrap.shorten(undid.content.strip(), width=50, placeholder='...')}[/]",
)
- peek = self[-1] if self.log else None
-
- def prepare_messages(self) -> list[Message]:
- """Prepares the log into messages before sending it to the LLM."""
- msgs = self.log
- msgs_reduced = list(reduce_log(msgs))
-
- if len_tokens(msgs) != len_tokens(msgs_reduced):
- logger.info(
- f"Reduced log from {len_tokens(msgs)//1} to {len_tokens(msgs_reduced)//1} tokens"
- )
- msgs_limited = limit_log(msgs_reduced)
- if len(msgs_reduced) != len(msgs_limited):
- logger.info(
- f"Limited log from {len(msgs_reduced)} to {len(msgs_limited)} messages"
- )
- return msgs_limited
+ peek = self.log[-1] if self.log else None
@classmethod
def load(
@@ -202,13 +215,12 @@ def load(
if create:
logger.debug(f"Creating new logfile {logfile}")
Path(logfile).parent.mkdir(parents=True, exist_ok=True)
- _write_jsonl(logfile, [])
+ Log([]).write_jsonl(logfile)
else:
raise FileNotFoundError(f"Could not find logfile {logfile}")
- msgs = _read_jsonl(logfile)
- if not msgs:
- msgs = initial_msgs or [get_prompt()]
+ log = Log.read_jsonl(logfile)
+ msgs = log.messages or initial_msgs or [get_prompt()]
return cls(msgs, logdir=logdir, branch=branch, **kwargs)
def branch(self, name: str) -> None:
@@ -216,7 +228,7 @@ def branch(self, name: str) -> None:
self.write()
if name not in self._branches:
logger.info(f"Creating a new branch '{name}'")
- self._branches[name] = copy(self.log)
+ self._branches[name] = self.log
self.current_branch = name
def diff(self, branch: str) -> str | None:
@@ -292,6 +304,22 @@ def to_dict(self, branches=False) -> dict:
return d
+def prepare_messages(msgs: list[Message]) -> list[Message]:
+    """Run `msgs` through reduce_log and limit_log before they go to the LLM."""
+    reduced = list(reduce_log(msgs))
+    # count tokens once per list and reuse the values for the check and the log line
+    tokens_before, tokens_after = len_tokens(msgs), len_tokens(reduced)
+    if tokens_before != tokens_after:
+        logger.info(
+            f"Reduced log from {tokens_before//1} to {tokens_after//1} tokens"
+        )
+
+    limited = limit_log(reduced)
+    if len(reduced) != len(limited):
+        logger.info(f"Limited log from {len(reduced)} to {len(limited)} messages")
+    return limited
+
+
def _conversation_files() -> list[Path]:
# NOTE: only returns the main conversation, not branches (to avoid duplicates)
# returns the conversation files sorted by modified time (newest first)
@@ -302,7 +330,7 @@ def _conversation_files() -> list[Path]:
@dataclass(frozen=True)
-class Conversation:
+class ConversationMeta:
name: str
path: str
created: float
@@ -311,16 +339,16 @@ class Conversation:
branches: int
-def get_conversations() -> Generator[Conversation, None, None]:
+def get_conversations() -> Generator[ConversationMeta, None, None]:
"""Returns all conversations, excluding ones used for testing, evals, etc."""
for conv_fn in _conversation_files():
- msgs = _read_jsonl(conv_fn, limit=1)
+ log = Log.read_jsonl(conv_fn, limit=1)
# TODO: can we avoid reading the entire file? maybe wont even be used, due to user convo filtering
len_msgs = conv_fn.read_text().count("}\n{")
- assert len(msgs) <= 1
+ assert len(log) <= 1
modified = conv_fn.stat().st_mtime
- first_timestamp = msgs[0].timestamp.timestamp() if msgs else modified
- yield Conversation(
+ first_timestamp = log[0].timestamp.timestamp() if log else modified
+ yield ConversationMeta(
name=f"{conv_fn.parent.name}",
path=str(conv_fn),
created=first_timestamp,
@@ -330,7 +358,7 @@ def get_conversations() -> Generator[Conversation, None, None]:
)
-def get_user_conversations() -> Generator[Conversation, None, None]:
+def get_user_conversations() -> Generator[ConversationMeta, None, None]:
"""Returns all user conversations, excluding ones used for testing, evals, etc."""
for conv in get_conversations():
if any(conv.name.startswith(prefix) for prefix in ["tmp", "test-"]) or any(
@@ -348,16 +376,3 @@ def _gen_read_jsonl(path: PathLike) -> Generator[Message, None, None]:
if "timestamp" in json_data:
json_data["timestamp"] = datetime.fromisoformat(json_data["timestamp"])
yield Message(**json_data, files=files)
-
-
-def _read_jsonl(path: PathLike, limit=None) -> list[Message]:
- gen = _gen_read_jsonl(path)
- if limit:
- gen = islice(gen, limit) # type: ignore
- return list(gen)
-
-
-def _write_jsonl(path: PathLike, msgs: list[Message]) -> None:
- with open(path, "w") as file:
- for msg in msgs:
- file.write(json.dumps(msg.to_dict()) + "\n")
diff --git a/gptme/message.py b/gptme/message.py
index 2b6f98cd..1246d9fb 100644
--- a/gptme/message.py
+++ b/gptme/message.py
@@ -172,6 +172,11 @@ def to_dict(self, keys=None, openai=False, anthropic=False) -> dict:
return {k: d[k] for k in keys}
return d
+    def to_xml(self) -> str:
+        """Converts a message to an XML string: <message role='…'> around the raw (unescaped) content."""
+        attrs = f"role='{self.role}'"
+        return f"<message {attrs}>\n{self.content}\n</message>"
+
def format(self, oneline: bool = False, highlight: bool = False) -> str:
return format_msgs([self], oneline=oneline, highlight=highlight)[0]
diff --git a/gptme/server/api.py b/gptme/server/api.py
index 4061f51a..e0c6450c 100644
--- a/gptme/server/api.py
+++ b/gptme/server/api.py
@@ -18,7 +18,7 @@
from ..commands import execute_cmd
from ..dirs import get_logs_dir
from ..llm import reply
-from ..logmanager import LogManager, get_user_conversations
+from ..logmanager import LogManager, get_user_conversations, prepare_messages
from ..message import Message
from ..models import get_model
from ..tools import execute_msg
@@ -91,26 +91,26 @@ def api_conversation_generate(logfile: str):
model = req_json.get("model", get_model().model)
# load conversation
- log = LogManager.load(logfile, branch=req_json.get("branch", "main"))
+ manager = LogManager.load(logfile, branch=req_json.get("branch", "main"))
# if prompt is a user-command, execute it
- if log[-1].role == "user":
+ if manager.log[-1].role == "user":
# TODO: capture output of command and return it
f = io.StringIO()
print("Begin capturing stdout, to pass along command output.")
with redirect_stdout(f):
- resp = execute_cmd(log[-1], log)
+ resp = execute_cmd(manager.log[-1], manager)
print("Done capturing stdout.")
if resp:
- log.write()
+ manager.write()
output = f.getvalue()
return flask.jsonify(
[{"role": "system", "content": output, "stored": False}]
)
# performs reduction/context trimming, if necessary
- msgs = log.prepare_messages()
+ msgs = prepare_messages(manager.log.messages)
# generate response
# TODO: add support for streaming
@@ -119,10 +119,10 @@ def api_conversation_generate(logfile: str):
# log response and run tools
resp_msgs = []
- log.append(msg)
+ manager.append(msg)
resp_msgs.append(msg)
for reply_msg in execute_msg(msg, ask=False):
- log.append(reply_msg)
+ manager.append(reply_msg)
resp_msgs.append(reply_msg)
return flask.jsonify(
diff --git a/gptme/tools/chats.py b/gptme/tools/chats.py
index 991b3f02..a009fda5 100644
--- a/gptme/tools/chats.py
+++ b/gptme/tools/chats.py
@@ -47,7 +47,7 @@ def _summarize_conversation(
summary_lines = []
if include_summary:
- summary = llm_summarize(log_manager.log)
+ summary = llm_summarize(log_manager.log.messages)
summary_lines.append(indent(f"Summary: {summary.content}", " "))
else:
non_system_messages = [msg for msg in log_manager.log if msg.role != "system"]
diff --git a/scripts/list_user_messages.py b/scripts/list_user_messages.py
index 66a49c38..4459b2fb 100644
--- a/scripts/list_user_messages.py
+++ b/scripts/list_user_messages.py
@@ -1,21 +1,20 @@
import logging
from datetime import datetime
-from gptme.logmanager import Conversation, _read_jsonl, get_user_conversations
+from gptme.logmanager import ConversationMeta, Log, get_user_conversations
# Set up logging
logging.basicConfig(level=logging.ERROR)
-def print_user_messages(conv: Conversation):
+def print_user_messages(conv: ConversationMeta):
"""
Print all user messages from a single conversation.
:param conversation: A dictionary containing conversation details
"""
lines = []
- msgs = _read_jsonl(conv.path)
- for message in msgs:
+ for message in Log.read_jsonl(conv.path):
if message.role == "user":
first_line = message.content.split("\n")[0]
if first_line.startswith(""):
diff --git a/scripts/treeofthoughts.py b/scripts/treeofthoughts.py
new file mode 100644
index 00000000..fdb01233
--- /dev/null
+++ b/scripts/treeofthoughts.py
@@ -0,0 +1,93 @@
+"""
+Tree-branching conversations for gptme with branch evaluation/prediction.
+
+The idea is to evaluate if we are on the right track by checking if the current branch is "good"/making progress, and otherwise backtracking to the last good branch and trying a different prompt/approach.
+"""
+
+import sys
+from typing import Literal
+
+from gptme.chat import step as _step
+from gptme.init import init
+from gptme.logmanager import Log
+from gptme.message import Message
+from gptme.prompts import get_prompt
+from lxml import etree
+
+EvalAction = Literal["continue", "undo", "done"]
+
+
+def step(log: Log) -> Log:
+    """Run one pass of the chat step and return the log extended with its messages."""
+    extended = log
+    for new_msg in _step(log, no_confirm=True):
+        # Log.append returns a new Log, so keep rebinding the result
+        extended = extended.append(new_msg)
+    return extended
+
+
+def recommendation(log: Log) -> EvalAction:
+    """Ask the LLM to judge the conversation and return the action it recommends.
+
+    The conversation is serialized with `Message.to_xml`, placed in a separate
+    evaluation log and stepped once. The last message of that log is expected
+    to contain an `<action>` tag holding one of: continue, undo, done.
+
+    Raises:
+        ValueError: if the reply contains no `<action>` tag.
+    """
+    system_msg = Message(
+        "system",
+        """
+We are evaluating the agent in the following conversation to determine the next action.
+
+Please evaluate the current state of the conversation,
+if the agent is making progress or if we should undo,
+and then recommend the next action within <action></action> tags.
+
+For example:
+<action>continue</action> to let the agent continue (making progress)
+<action>undo</action> to backtrack until last user prompt (made a mistake)
+<action>done</action> if the agent has completed the task (e.g. answered the question)
+""",
+    )
+    log_xml = "Here is the conversation to evaluate:\n" + "".join(
+        msg.to_xml() + "\n" for msg in log
+    )
+    # evaluate in a separate log so the conversation under evaluation is untouched
+    eval_log = Log(
+        [
+            system_msg,
+            Message("system", log_xml),
+            Message("user", "evaluate the agent"),
+        ]
+    )
+    eval_log = step(eval_log)  # TODO: use faster model for this
+    # presumably HTMLParser is used because the reply is free-form text, not strict XML
+    parser = etree.HTMLParser()
+    tree = etree.fromstring(eval_log[-1].content, parser)
+    actions = tree.xpath("//action")
+    if not actions:
+        raise ValueError("No <action> tag found in evaluation response")
+    # .text is None for an empty tag; the value itself is validated by the caller
+    return (actions[0].text or "").strip()
+
+
+def main() -> None:
+    """Run the step/evaluate loop on the prompt given as the first CLI argument.
+
+    Falls back to a built-in example prompt when no argument is given, and
+    prints the final conversation once the evaluator answers "done".
+
+    Raises:
+        ValueError: if the evaluator recommends an unknown action.
+    """
+    print("Init...")
+    init(
+        model="openai/gpt-4o",
+        interactive=False,
+        tool_allowlist=["python", "shell", "save", "patch"],
+    )
+
+    # Set up the conversation
+    prompt = sys.argv[1] if len(sys.argv) > 1 else "What is fib(10)?"
+    prompts = [Message("user", prompt)]
+    initial_msgs = [get_prompt("full", interactive=False)]
+    log = Log(initial_msgs + prompts)
+
+    while True:
+        # Step it forward
+        print("Stepping...")
+        log = step(log)
+        print("Done with step")
+
+        # Evaluate the conversation
+        action = recommendation(log)
+        print(f"Recommendation: {action}")
+
+        # Take the recommended action
+        if action == "continue":
+            continue
+        elif action == "undo":
+            # backtrack until the last user prompt, as the evaluator prompt
+            # describes; popping a single message could leave a half-finished
+            # assistant/tool exchange at the end of the log
+            while log and log[-1].role != "user":
+                log = log.pop()
+        elif action == "done":
+            break
+        else:
+            raise ValueError(f"Invalid action: {action}")
+
+    # Print the final conversation
+    log.print()
+
+
+# keep the LLM loop from running when the file is merely imported
+if __name__ == "__main__":
+    main()