From a50de8b57a8673b6100c83ea313b4a9c554ac9b9 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 13 May 2025 09:41:46 -0700 Subject: [PATCH] ChainResponse.log_to_db() method and test, refs #1017, #1003 --- llm/cli.py | 24 ++++++++++-------------- llm/models.py | 6 ++++++ tests/test_tools.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index c8e29e7..c22ef1c 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -875,29 +875,25 @@ def prompt( raise raise click.ClickException(str(ex)) - if isinstance(response, ChainResponse): - responses = response._responses - else: - responses = [response] - - for response in responses: - if isinstance(response, AsyncResponse): - response = asyncio.run(response.to_sync_response()) - - if usage: + if usage: + if isinstance(response, ChainResponse): + responses = response._responses + else: + responses = [response] + for response_object in responses: # Show token usage to stderr in yellow click.echo( click.style( - "Token usage: {}".format(response.token_usage()), + "Token usage: {}".format(response_object.token_usage()), fg="yellow", bold=True, ), err=True, ) - # Log to the database - if (logs_on() or log) and not no_log: - response.log_to_db(db) + # Log to the database + if (logs_on() or log) and not no_log: + response.log_to_db(db) @cli.command() diff --git a/llm/models.py b/llm/models.py index 41f009b..c90db6b 100644 --- a/llm/models.py +++ b/llm/models.py @@ -1134,6 +1134,12 @@ class _BaseChainResponse: def text(self) -> str: return "".join(self) + def log_to_db(self, db): + for response in self._responses: + if isinstance(response, AsyncResponse): + response = asyncio.run(response.to_sync_response()) + response.log_to_db(db) + class ChainResponse(_BaseChainResponse): "Know how to chain multiple responses e.g. for tool calls" diff --git a/tests/test_tools.py b/tests/test_tools.py index a23f471..0038fab 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,6 +1,8 @@ import llm +from llm.migrations import migrate import os import pytest +import sqlite_utils API_KEY = os.environ.get("PYTEST_OPENAI_API_KEY", None) or "badkey" @@ -28,3 +30,33 @@ def test_tool_use_basic(vcr): assert len(second.prompt.tool_results) == 1 assert second.prompt.tool_results[0].name == "multiply" assert second.prompt.tool_results[0].output == "2869461" + + # Test writing to the database + db = sqlite_utils.Database(memory=True) + migrate(db) + chain_response.log_to_db(db) + assert set(db.table_names()).issuperset( + {"tools", "tool_responses", "tool_calls", "tool_results"} + ) + + responses = list(db["responses"].rows) + assert len(responses) == 2 + first_response, second_response = responses + + tools = list(db["tools"].rows) + assert len(tools) == 1 + assert tools[0]["name"] == "multiply" + assert tools[0]["description"] == "Multiply two numbers." + + tool_results = list(db["tool_results"].rows) + tool_calls = list(db["tool_calls"].rows) + + assert len(tool_calls) == 1 + assert tool_calls[0]["response_id"] == first_response["id"] + assert tool_calls[0]["name"] == "multiply" + assert tool_calls[0]["arguments"] == '{"a": 1231, "b": 2331}' + + assert len(tool_results) == 1 + assert tool_results[0]["response_id"] == second_response["id"] + assert tool_results[0]["output"] == "2869461" + assert tool_results[0]["tool_call_id"] == tool_calls[0]["tool_call_id"]