From 178af27d95801097485e4e28d50cd26a530eecfc Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Sat, 15 Jul 2023 16:20:28 -0700
Subject: [PATCH] llm logs list -q/--query option, closes #109

---
 docs/help.md      |  1 +
 docs/logging.md   |  4 +++
 llm/cli.py        | 77 ++++++++++++++++++++++++++++++-----------------
 tests/test_llm.py | 34 +++++++++++++++++++++
 4 files changed, 89 insertions(+), 27 deletions(-)

diff --git a/docs/help.md b/docs/help.md
index 2a3221df..e53d5ab6 100644
--- a/docs/help.md
+++ b/docs/help.md
@@ -187,6 +187,7 @@ Options:
   -n, --count INTEGER  Number of entries to show - 0 for all
   -p, --path FILE      Path to log database
   -m, --model TEXT     Filter by model or model alias
+  -q, --query TEXT     Search for logs matching this string
   -t, --truncate       Truncate long strings in output
   --help               Show this message and exit.
 ```
diff --git a/docs/logging.md b/docs/logging.md
index e2f85e6e..51572223 100644
--- a/docs/logging.md
+++ b/docs/logging.md
@@ -62,6 +62,10 @@ Or `-n 0` to see everything that has ever been logged:
 ```bash
 llm logs -n 0
 ```
+You can search the logs for a search term in the `prompt` or the `response` columns:
+```bash
+llm logs -q 'cheesecake'
+```
 You can filter to logs just for a specific model (or model alias) using `-m/--model`:
 ```bash
 llm logs -m chatgpt
diff --git a/llm/cli.py b/llm/cli.py
index b40731aa..ac3b3dc2 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -372,6 +372,40 @@ def logs_turn_off():
     path.touch()
 
 
+LOGS_COLUMNS = """    responses.id,
+    responses.model,
+    responses.prompt,
+    responses.system,
+    responses.prompt_json,
+    responses.options_json,
+    responses.response,
+    responses.response_json,
+    responses.conversation_id,
+    responses.duration_ms,
+    responses.datetime_utc,
+    conversations.name as conversation_name,
+    conversations.model as conversation_model"""
+
+LOGS_SQL = """
+select
+{columns}
+from
+    responses
+left join conversations on responses.conversation_id = conversations.id{where}
+order by responses.id desc{limit}
+"""
+LOGS_SQL_SEARCH = """
+select
+{columns}
+from
+    responses
+left join conversations on responses.conversation_id = conversations.id
+join responses_fts on responses_fts.rowid = responses.rowid
+where responses_fts match :query{extra_where}
+order by responses_fts.rank desc{limit}
+"""
+
+
 @logs.command(name="list")
 @click.option(
     "-n",
@@ -386,8 +420,9 @@ def logs_turn_off():
     help="Path to log database",
 )
 @click.option("-m", "--model", help="Filter by model or model alias")
+@click.option("-q", "--query", help="Search for logs matching this string")
 @click.option("-t", "--truncate", is_flag=True, help="Truncate long strings in output")
-def logs_list(count, path, model, truncate):
+def logs_list(count, path, model, query, truncate):
     "Show recent logged prompts and their responses"
     path = pathlib.Path(path or logs_db_path())
     if not path.exists():
@@ -404,33 +439,21 @@ def logs_list(count, path, model, truncate):
         # Maybe they uninstalled a model, use the -m option as-is
         model_id = model
 
-    rows = list(
-        db.query(
-            """
-        select
-            responses.id,
-            responses.model,
-            responses.prompt,
-            responses.system,
-            responses.prompt_json,
-            responses.options_json,
-            responses.response,
-            responses.response_json,
-            responses.conversation_id,
-            responses.duration_ms,
-            responses.datetime_utc,
-            conversations.name as conversation_name,
-            conversations.model as conversation_model
-        from
-            responses
-        left join conversations on responses.conversation_id = conversations.id{where}
-        order by responses.id desc{limit}
-        """.format(
-                where=" where responses.model = :model" if model_id else "",
-                limit=" limit {}".format(count) if count else "",
-            ),
-            {"model": model_id},
+    sql = LOGS_SQL
+    format_kwargs = {
+        "limit": " limit {}".format(count) if count else "",
+        "columns": LOGS_COLUMNS,
+    }
+    if query:
+        sql = LOGS_SQL_SEARCH
+        format_kwargs["extra_where"] = (
+            " and responses.model = :model" if model_id else ""
         )
+    else:
+        format_kwargs["where"] = " where responses.model = :model" if model_id else ""
+
+    rows = list(
+        db.query(sql.format(**format_kwargs), {"model": model_id, "query": query})
     )
     for row in rows:
         if truncate:
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 6fc06b5d..4cff670f 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -85,6 +85,40 @@ def test_logs_filtered(user_path, model):
     assert all(record["model"] == model for record in records)
 
 
+@pytest.mark.parametrize(
+    "query,expected",
+    (
+        ("", ["doc3", "doc2", "doc1"]),
+        ("llama", ["doc1", "doc3"]),
+        ("alpaca", ["doc2"]),
+    ),
+)
+def test_logs_search(user_path, query, expected):
+    log_path = str(user_path / "logs.db")
+    db = sqlite_utils.Database(log_path)
+    migrate(db)
+
+    def _insert(id, text):
+        db["responses"].insert(
+            {
+                "id": id,
+                "system": "system",
+                "prompt": text,
+                "response": "response",
+                "model": "davinci",
+            }
+        )
+
+    _insert("doc1", "llama")
+    _insert("doc2", "alpaca")
+    _insert("doc3", "llama llama")
+    runner = CliRunner()
+    result = runner.invoke(cli, ["logs", "list", "-q", query])
+    assert result.exit_code == 0
+    records = json.loads(result.output.strip())
+    assert [record["id"] for record in records] == expected
+
+
 @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
 @pytest.mark.parametrize("use_stdin", (True, False))
 @pytest.mark.parametrize(
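The `-q` path in this patch depends on a `responses_fts` full-text index that the log database migrations are expected to provide, joined back to `responses` by rowid. Below is a minimal standalone sketch of the query shape that `LOGS_SQL_SEARCH` formats, run with `sqlite-utils` directly; the in-memory table, its sample rows, and the `enable_fts()` call are illustrative stand-ins for the real schema and migrations, and the `conversations` left join is omitted for brevity:

```python
# Standalone sketch (not the project's code): the join/match/rank shape
# that LOGS_SQL_SEARCH produces, exercised against a stand-in table.
import sqlite_utils

db = sqlite_utils.Database(memory=True)
db["responses"].insert_all(
    [
        {"id": "doc1", "model": "davinci", "prompt": "llama", "response": "response"},
        {"id": "doc2", "model": "davinci", "prompt": "alpaca", "response": "response"},
    ],
    pk="id",
)
# Stand-in for the responses_fts table that llm's migrations create
db["responses"].enable_fts(["prompt", "response"])

sql = """
select
{columns}
from
    responses
join responses_fts on responses_fts.rowid = responses.rowid
where responses_fts match :query{extra_where}
order by responses_fts.rank desc{limit}
""".format(
    columns="    responses.id,\n    responses.model,\n    responses.prompt",
    extra_where=" and responses.model = :model",  # what -q combined with -m adds
    limit=" limit 10",  # what -n/--count adds
)
print(list(db.query(sql, {"query": "llama", "model": "davinci"})))
# [{'id': 'doc1', 'model': 'davinci', 'prompt': 'llama'}]
```

With the patch applied, a command such as `llm logs list -q 'cheesecake' -m chatgpt -n 10` formats and executes this same shape, with the real query also left joining `conversations` to pull in the conversation name and model columns.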