feat: added list_chats to chats tool, and cleaned up/refactored non-ToolSpec-using tools (#110)
ErikBjare authored Sep 7, 2024
1 parent 4e36109 commit 5cb3936
Showing 13 changed files with 216 additions and 310 deletions.
33 changes: 4 additions & 29 deletions docs/tools.rst
@@ -22,10 +22,7 @@ The main tools can be grouped in the following categories:

- chat management

- :ref:`Edit`
- :ref:`Reduce`
- :ref:`Context`
- :ref:`Summarize`
- :ref:`Chats`

Shell
-----
@@ -69,31 +66,9 @@ Browser
:members:
:noindex:

Edit
----

.. automodule:: gptme.tools.useredit
:members:
:noindex:

Reduce
------

.. automodule:: gptme.tools.reduce
:members:
:noindex:

Context
-------

.. automodule:: gptme.tools.context
:members:
:noindex:

Summarize
---------
Chats
-----

.. automodule:: gptme.tools.summarize
.. automodule:: gptme.tools.chats
:members:
:noindex:

10 changes: 2 additions & 8 deletions gptme/commands.py
@@ -23,15 +23,12 @@
execute_shell,
loaded_tools,
)
from .tools.context import gen_context_msg
from .tools.summarize import summarize
from .tools.useredit import edit_text_with_editor
from .useredit import edit_text_with_editor
from .util import ask_execute

logger = logging.getLogger(__name__)

Actions = Literal[
"summarize",
"log",
"edit",
"rename",
@@ -114,16 +111,13 @@ def handle_cmd(
case "summarize":
msgs = log.prepare_messages()
msgs = [m for m in msgs if not m.hide]
summary = summarize(msgs)
summary = llm.summarize(msgs)
print(f"Summary: {summary}")
case "edit":
# edit previous messages
# first undo the '/edit' command itself
log.undo(1, quiet=True)
yield from edit(log)
case "context":
# print context msg
yield gen_context_msg()
case "undo":
# undo the '/undo' command itself
log.undo(1, quiet=True)
39 changes: 37 additions & 2 deletions gptme/llm.py
@@ -2,6 +2,7 @@
import shutil
import sys
from collections.abc import Iterator
from functools import lru_cache
from typing import Literal

from rich import print
@@ -16,7 +17,7 @@
from .llm_openai import get_client as get_openai_client
from .llm_openai import init as init_openai
from .llm_openai import stream as stream_openai
from .message import Message, len_tokens
from .message import Message, format_msgs, len_tokens
from .models import MODELS, get_summary_model
from .util import extract_codeblocks

@@ -125,7 +126,7 @@ def _client_to_provider() -> Provider:
raise ValueError("Unknown client type")


def summarize(content: str) -> str:
def _summarize_str(content: str) -> str:
"""
Summarizes a long text using a LLM.
@@ -186,3 +187,37 @@ def generate_name(msgs: list[Message]) -> str:
)
name = _chat_complete(msgs, model=get_summary_model(_client_to_provider())).strip()
return name


def summarize(msg: str | Message | list[Message]) -> Message:
"""Uses a cheap LLM to summarize long outputs."""
# construct plaintext from message(s)
if isinstance(msg, str):
content = msg
elif isinstance(msg, Message):
content = msg.content
else:
content = "\n".join(format_msgs(msg))

logger.info(f"{content[:200]=}")
summary = _summarize_helper(content)
logger.info(f"{summary[:200]=}")

# construct message from summary
content = f"Here's a summary of the conversation:\n{summary}"
return Message(role="system", content=content)


@lru_cache(maxsize=128)
def _summarize_helper(s: str, tok_max_start=400, tok_max_end=400) -> str:
"""
Helper function for summarizing long outputs.
Truncates long outputs, then summarizes.
"""
if len_tokens(s) > tok_max_start + tok_max_end:
beginning = " ".join(s.split()[:tok_max_start])
end = " ".join(s.split()[-tok_max_end:])
summary = _summarize_str(beginning + "\n...\n" + end)
else:
summary = _summarize_str(s)
return summary
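
As a rough usage sketch (not part of the diff), the refactored `summarize` accepts a plain string, a single `Message`, or a list of `Message`s and returns a system-role `Message`; it assumes an LLM provider has already been initialized elsewhere in the application.

```python
# Hypothetical usage of the new gptme.llm.summarize (assumes an LLM
# provider has already been initialized elsewhere in the application).
from gptme.llm import summarize
from gptme.message import Message

msgs = [
    Message(role="user", content="How do I undo the last git commit?"),
    Message(role="assistant", content="Use `git reset --soft HEAD~1` ..."),
]

summary_msg = summarize(msgs)   # also accepts a str or a single Message
print(summary_msg.role)         # -> "system"
print(summary_msg.content)      # starts with "Here's a summary of the conversation:"
```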
2 changes: 1 addition & 1 deletion gptme/logmanager.py
@@ -15,7 +15,7 @@
from .dirs import get_logs_dir
from .message import Message, len_tokens, print_msg
from .prompts import get_prompt
from .tools.reduce import limit_log, reduce_log
from .reduce import limit_log, reduce_log

PathLike: TypeAlias = str | Path

4 changes: 2 additions & 2 deletions gptme/tools/reduce.py → gptme/reduce.py
@@ -8,8 +8,8 @@
from collections.abc import Generator
from copy import copy

from ..message import Message, len_tokens
from ..models import get_model
from .message import Message, len_tokens
from .models import get_model

logger = logging.getLogger(__name__)

6 changes: 2 additions & 4 deletions gptme/tools/__init__.py
@@ -7,18 +7,17 @@
from ..util import extract_codeblocks
from .base import ToolSpec
from .browser import tool as browser_tool
from .chats import tool as chats_tool
from .gh import tool as gh_tool
from .patch import tool as patch_tool
from .python import execute_python
from .python import get_tool as get_python_tool
from .python import register_function
from .read import tool as tool_read
from .save import execute_save, tool_append, tool_save
from .search_chats import tool as search_chats_tool
from .shell import execute_shell
from .shell import tool as shell_tool
from .subagent import tool as subagent_tool
from .summarize import summarize
from .tmux import tool as tmux_tool

logger = logging.getLogger(__name__)
@@ -29,7 +28,6 @@
"execute_python",
"execute_shell",
"execute_save",
"summarize",
"ToolSpec",
"ToolUse",
"all_tools",
@@ -45,7 +43,7 @@
tmux_tool,
browser_tool,
gh_tool,
search_chats_tool,
chats_tool,
# python tool is loaded last to ensure all functions are registered
get_python_tool,
]
167 changes: 167 additions & 0 deletions gptme/tools/chats.py
@@ -0,0 +1,167 @@
"""
List, search, and summarize past conversation logs.
"""

import logging
from pathlib import Path
from textwrap import indent

from ..llm import summarize as llm_summarize
from ..message import Message
from ..util import transform_examples_to_chat_directives
from .base import ToolSpec

logger = logging.getLogger(__name__)


def _format_message_snippet(msg: Message, max_length: int = 100) -> str:
"""Format a message snippet for display."""
first_newline = msg.content.find("\n")
max_length = min(max_length, first_newline) if first_newline != -1 else max_length
content = msg.content[:max_length]
return f"{msg.role.capitalize()}: {content}" + (
"..." if len(content) <= len(msg.content) else ""
)


def _get_matching_messages(log_manager, query: str) -> list[Message]:
"""Get messages matching the query."""
return [msg for msg in log_manager.log if query.lower() in msg.content.lower()]


def _summarize_conversation(log_manager, include_summary: bool) -> list[str]:
"""Summarize a conversation."""
summary_lines = []
if include_summary:
summary = llm_summarize(log_manager.log)
summary_lines.append(indent(f"Summary: {summary.content}", " "))
else:
non_system_messages = [msg for msg in log_manager.log if msg.role != "system"]
if non_system_messages:
first_msg = non_system_messages[0]
last_msg = non_system_messages[-1]

summary_lines.append(
f" First message: {_format_message_snippet(first_msg)}"
)
if last_msg != first_msg:
summary_lines.append(
f" Last message: {_format_message_snippet(last_msg)}"
)

summary_lines.append(f" Total messages: {len(log_manager.log)}")
return summary_lines


def list_chats(max_results: int = 5, include_summary: bool = True) -> None:
"""
List recent chat conversations and optionally summarize them.
Args:
max_results (int): Maximum number of conversations to display.
include_summary (bool): Whether to include a summary of each conversation.
If True, uses an LLM to generate a comprehensive summary.
If False, uses a simple strategy showing snippets of the first and last messages.
"""
# noreorder
from ..logmanager import LogManager, get_conversations # fmt: skip

conversations = list(get_conversations())[:max_results]

if not conversations:
print("No conversations found.")
return

print(f"Recent conversations (showing up to {max_results}):")
for i, conv in enumerate(conversations, 1):
print(f"\n{i}. {conv['name']}")
if "created_at" in conv:
print(f" Created: {conv['created_at']}")

log_path = Path(conv["path"])
log_manager = LogManager.load(log_path)

summary_lines = _summarize_conversation(log_manager, include_summary)
print("\n".join(summary_lines))


def search_chats(query: str, max_results: int = 5) -> None:
"""
Search past conversation logs for the given query and print a summary of the results.
Args:
query (str): The search query.
max_results (int): Maximum number of conversations to display.
"""
# noreorder
from ..logmanager import LogManager, get_conversations # fmt: skip

conversations = list(get_conversations())
results = []

for conv in conversations:
log_path = Path(conv["path"])
log_manager = LogManager.load(log_path)

matching_messages = _get_matching_messages(log_manager, query)

if matching_messages:
results.append(
{
"conversation": conv["name"],
"log_manager": log_manager,
"matching_messages": matching_messages,
}
)

# Sort results by the number of matching messages, in descending order
results.sort(key=lambda x: len(x["matching_messages"]), reverse=True)
results = results[:max_results]

if not results:
print(f"No results found for query: '{query}'")
return

print(f"Search results for query: '{query}'")
print(f"Found matches in {len(results)} conversation(s):")

for i, result in enumerate(results, 1):
print(f"\n{i}. Conversation: {result['conversation']}")
print(f" Number of matching messages: {len(result['matching_messages'])}")

summary_lines = _summarize_conversation(
result["log_manager"], include_summary=False
)
print("\n".join(summary_lines))

print(" Sample matches:")
for j, msg in enumerate(result["matching_messages"][:3], 1):
print(f" {j}. {_format_message_snippet(msg)}")
if len(result["matching_messages"]) > 3:
print(
f" ... and {len(result['matching_messages']) - 3} more matching message(s)"
)


instructions = """
The chats tool allows you to list, search, and summarize past conversation logs.
"""

examples = """
### Search for a specific topic in past conversations
User: Can you find any mentions of "python" in our past conversations?
Assistant: Certainly! I'll search our past conversations for mentions of "python" using the search_chats function.
```python
search_chats("python")
```
"""

__doc__ += transform_examples_to_chat_directives(examples)

tool = ToolSpec(
name="chats",
desc="List, search, and summarize past conversation logs",
instructions=instructions,
examples=examples,
functions=[list_chats, search_chats],
)
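
As a minimal sketch (not part of the diff), the new tool functions can also be called directly; both print their output rather than returning values, per the implementations above.

```python
# Hypothetical direct invocation of the chats tool functions shown above;
# both print to stdout and return None.
from gptme.tools.chats import list_chats, search_chats

list_chats(max_results=3, include_summary=False)  # snippets of first/last messages
search_chats("python", max_results=5)             # ranked by number of matching messages
```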