From 57c4348ea03e9a796a97458b90821da4b857967b Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 11 Jul 2023 07:17:46 -0700 Subject: [PATCH] llm logs now uses new DB schema, refs #91 --- docs/logging.md | 9 +++--- llm/cli.py | 41 ++++++++++++++++++++++---- llm/models.py | 4 +-- tests/test_llm.py | 20 ++++--------- tests/test_migrate.py | 65 ++++++++++++++++++++++++++--------------- tests/test_templates.py | 3 +- 6 files changed, 90 insertions(+), 52 deletions(-) diff --git a/docs/logging.md b/docs/logging.md index b6382e1f..75a7efb9 100644 --- a/docs/logging.md +++ b/docs/logging.md @@ -62,7 +62,7 @@ import sqlite_utils import re db = sqlite_utils.Database(memory=True) migrate(db) -schema = db["logs"].schema +schema = db["responses"].schema def cleanup_sql(sql): first_line = sql.split('(')[0] @@ -75,8 +75,8 @@ cog.out( ) ]]] --> ```sql -CREATE TABLE "logs" ( - [id] INTEGER PRIMARY KEY, +CREATE TABLE [responses] ( + [id] TEXT PRIMARY KEY, [model] TEXT, [prompt] TEXT, [system] TEXT, @@ -84,8 +84,7 @@ CREATE TABLE "logs" ( [options_json] TEXT, [response] TEXT, [response_json] TEXT, - [reply_to_id] INTEGER REFERENCES [logs]([id]), - [chat_id] INTEGER REFERENCES [logs]([id]), + [conversation_id] TEXT REFERENCES [conversations]([id]), [duration_ms] INTEGER, [datetime_utc] TEXT ); diff --git a/llm/cli.py b/llm/cli.py index 8b5cac5c..66b0ba6d 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -22,7 +22,7 @@ import sqlite_utils import sys import textwrap -from typing import Optional +from typing import cast, Optional import warnings import yaml @@ -265,7 +265,7 @@ def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]: else: return None try: - row = db["conversations"].get(conversation_id) + row = cast(sqlite_utils.db.Table, db["conversations"]).get(conversation_id) except sqlite_utils.db.NotFoundError: raise click.ClickException( "No conversation found with id={}".format(conversation_id) @@ -368,15 +368,44 @@ def logs_list(count, path, truncate): raise click.ClickException("No log database found at {}".format(path)) db = sqlite_utils.Database(path) migrate(db) - rows = list(db["logs"].rows_where(order_by="-id", limit=count or None)) + rows = list( + db.query( + """ + select + responses.id, + responses.model, + responses.prompt, + responses.system, + responses.prompt_json, + responses.options_json, + responses.response, + responses.response_json, + responses.conversation_id, + responses.duration_ms, + responses.datetime_utc, + conversations.name as conversation_name, + conversations.model as conversation_model + from + responses + left join conversations on responses.conversation_id = conversations.id + order by responses.id desc{} + """.format( + " limit {}".format(count) if count else "" + ) + ) + ) for row in rows: if truncate: row["prompt"] = _truncate_string(row["prompt"]) row["response"] = _truncate_string(row["response"]) - # Decode all JSON keys - for key in row: + # Either decode or remove all JSON keys + keys = list(row.keys()) + for key in keys: if key.endswith("_json") and row[key] is not None: - row[key] = json.loads(row[key]) + if truncate: + del row[key] + else: + row[key] = json.loads(row[key]) click.echo(json.dumps(list(rows), indent=2)) diff --git a/llm/models.py b/llm/models.py index 43c99660..69132c6b 100644 --- a/llm/models.py +++ b/llm/models.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass, asdict, field +from dataclasses import dataclass, field import datetime from .errors import NeedsKeyException import time -from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Union +from typing import Any, Dict, Iterator, List, Optional, Set from abc import ABC, abstractmethod import os import json diff --git a/tests/test_llm.py b/tests/test_llm.py index b300839d..0d7a3810 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -5,6 +5,7 @@ import os import pytest import sqlite_utils +from ulid import ULID from unittest import mock @@ -18,11 +19,13 @@ def test_version(): @pytest.mark.parametrize("n", (None, 0, 2)) def test_logs(n, user_path): + "Test that logs command correctly returns requested -n records" log_path = str(user_path / "logs.db") db = sqlite_utils.Database(log_path) migrate(db) - db["logs"].insert_all( + db["responses"].insert_all( { + "id": str(ULID()).lower(), "system": "system", "prompt": "prompt", "response": "response", @@ -66,7 +69,7 @@ def test_llm_default_prompt(mocked_openai, use_stdin, user_path): # Reset the log_path database log_path = user_path / "logs.db" log_db = sqlite_utils.Database(str(log_path)) - log_db["logs"].delete_where() + log_db["responses"].delete_where() runner = CliRunner() prompt = "three names for a pet pelican" input = None @@ -81,23 +84,14 @@ def test_llm_default_prompt(mocked_openai, use_stdin, user_path): assert mocked_openai.last_request.headers["Authorization"] == "Bearer X" # Was it logged? - rows = list(log_db["logs"].rows) + rows = list(log_db["responses"].rows) assert len(rows) == 1 - expected = { - "model": "gpt-3.5-turbo", - "prompt": "three names for a pet pelican", - "system": None, - "response": "Bob, Alice, Eve", - "chat_id": None, - } expected = { "model": "gpt-3.5-turbo", "prompt": "three names for a pet pelican", "system": None, "options_json": "{}", "response": "Bob, Alice, Eve", - "reply_to_id": None, - "chat_id": None, } row = rows[0] assert expected.items() <= row.items() @@ -133,7 +127,5 @@ def test_llm_default_prompt(mocked_openai, use_stdin, user_path): "usage": {}, "choices": [{"message": {"content": "Bob, Alice, Eve"}}], }, - "reply_to_id": None, - "chat_id": None, }.items() ) diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 64ffa396..5fe3c3cb 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -1,9 +1,10 @@ from llm.migrations import migrate +import pytest import sqlite_utils EXPECTED = { - "id": int, + "id": str, "model": str, "prompt": str, "system": str, @@ -11,8 +12,7 @@ "options_json": str, "response": str, "response_json": str, - "reply_to_id": int, - "chat_id": int, + "conversation_id": str, "duration_ms": int, "datetime_utc": str, } @@ -21,34 +21,51 @@ def test_migrate_blank(): db = sqlite_utils.Database(memory=True) migrate(db) - assert set(db.table_names()) == {"_llm_migrations", "logs"} - assert db["logs"].columns_dict == EXPECTED + assert set(db.table_names()) == {"_llm_migrations", "conversations", "responses"} + assert db["responses"].columns_dict == EXPECTED - foreign_keys = db["logs"].foreign_keys + foreign_keys = db["responses"].foreign_keys for expected_fk in ( sqlite_utils.db.ForeignKey( - table="logs", column="reply_to_id", other_table="logs", other_column="id" - ), - sqlite_utils.db.ForeignKey( - table="logs", column="chat_id", other_table="logs", other_column="id" + table="responses", + column="conversation_id", + other_table="conversations", + other_column="id", ), ): assert expected_fk in foreign_keys -def test_migrate_from_original_schema(): +@pytest.mark.parametrize("has_record", [True, False]) +def test_migrate_from_original_schema(has_record): db = sqlite_utils.Database(memory=True) - db["log"].insert( - { - "provider": "provider", - "system": "system", - "prompt": "prompt", - "chat_id": None, - "response": "response", - "model": "model", - "timestamp": "timestamp", - }, - ) + if has_record: + db["log"].insert( + { + "provider": "provider", + "system": "system", + "prompt": "prompt", + "chat_id": None, + "response": "response", + "model": "model", + "timestamp": "timestamp", + }, + ) + else: + # Create empty logs table + db["log"].create( + { + "provider": str, + "system": str, + "prompt": str, + "chat_id": str, + "response": str, + "model": str, + "timestamp": str, + } + ) migrate(db) - assert set(db.table_names()) == {"_llm_migrations", "logs"} - assert db["logs"].columns_dict == EXPECTED + expected_tables = {"_llm_migrations", "conversations", "responses"} + if has_record: + expected_tables.add("logs") + assert set(db.table_names()) == expected_tables diff --git a/tests/test_templates.py b/tests/test_templates.py index 08e6c424..21d2e018 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -82,7 +82,8 @@ def test_templates_list(templates_path): (["--system", "system"], {"system": "system"}, None), (["-t", "template"], None, "--save cannot be used with --template"), (["--continue"], None, "--save cannot be used with --continue"), - (["--chat", "123"], None, "--save cannot be used with --chat"), + (["--cid", "123"], None, "--save cannot be used with --cid"), + (["--conversation", "123"], None, "--save cannot be used with --cid"), ( ["Say hello as $name", "-p", "name", "default-name"], {"prompt": "Say hello as $name", "defaults": {"name": "default-name"}},