Log tool stuff to the database, refs #1003

This commit is contained in:
Simon Willison 2025-05-12 10:02:19 -07:00
parent 4abd6e0faf
commit 32cab986ea
4 changed files with 150 additions and 22 deletions

View file

@ -743,9 +743,7 @@ def prompt(
extra_tools = _tools_from_code(python_tools)
if tools or python_tools:
prompt_method = lambda *args, **kwargs: conversation.chain(
*args, **kwargs
).details()
prompt_method = conversation.chain
# prompt_method = conversation.chain
# Look up all those tools
registered_tools: dict = get_tools()
@ -825,25 +823,29 @@ def prompt(
raise
raise click.ClickException(str(ex))
if not isinstance(response, (Response, AsyncResponse)):
# Terminate early, logging for tool streaming mechanism not implemented yet
return
if isinstance(response, ChainResponse):
responses = response._responses
else:
responses = [response]
if isinstance(response, AsyncResponse):
response = asyncio.run(response.to_sync_response())
for response in responses:
if isinstance(response, AsyncResponse):
response = asyncio.run(response.to_sync_response())
if usage:
# Show token usage to stderr in yellow
click.echo(
click.style(
"Token usage: {}".format(response.token_usage()), fg="yellow", bold=True
),
err=True,
)
if usage:
# Show token usage to stderr in yellow
click.echo(
click.style(
"Token usage: {}".format(response.token_usage()),
fg="yellow",
bold=True,
),
err=True,
)
# Log to the database
if (logs_on() or log) and not no_log:
response.log_to_db(db)
# Log to the database
if (logs_on() or log) and not no_log:
response.log_to_db(db)
@cli.command()

View file

@ -310,3 +310,61 @@ def m016_fragments_table_pks(db):
# https://github.com/simonw/llm/issues/863#issuecomment-2781720064
db["prompt_fragments"].transform(pk=("response_id", "fragment_id", "order"))
db["system_fragments"].transform(pk=("response_id", "fragment_id", "order"))
@migration
def m017_tools_tables(db):
db["tools"].create(
{
"id": int,
"hash": str,
"name": str,
"description": str,
"input_schema": str,
},
pk="id",
)
db["tools"].create_index(["hash"], unique=True)
# Many-to-many relationship between tools and responses
db["tool_responses"].create(
{
"tool_id": int,
"response_id": str,
},
foreign_keys=(
("tool_id", "tools", "id"),
("response_id", "responses", "id"),
),
pk=("tool_id", "response_id"),
)
# tool_calls and tool_results are one-to-many against responses
db["tool_calls"].create(
{
"id": int,
"response_id": str,
"tool_id": int,
"name": str,
"arguments": str,
"tool_call_id": str,
},
pk="id",
foreign_keys=(
("response_id", "responses", "id"),
("tool_id", "tools", "id"),
),
)
db["tool_results"].create(
{
"id": int,
"response_id": str,
"tool_id": int,
"name": str,
"output": str,
"tool_call_id": str,
},
pk="id",
foreign_keys=(
("response_id", "responses", "id"),
("tool_id", "tools", "id"),
),
)

View file

@ -24,6 +24,7 @@ from typing import (
)
from .utils import (
ensure_fragment,
ensure_tool,
make_schema_id,
mimetype_from_path,
mimetype_from_string,
@ -127,6 +128,15 @@ class Tool:
return schema_dict
return schema
def hash(self):
"""Hash for tool based on its name, description and input schema (preserving key order)"""
to_hash = {
"name": self.name,
"description": self.description,
"input_schema": self.input_schema,
}
return hashlib.sha256(json.dumps(to_hash).encode("utf-8")).hexdigest()
@classmethod
def function(cls, function, name=None):
"""
@ -597,6 +607,10 @@ class _BaseResponse:
response_text = self.text_or_raise()
replacements[f"r:{response_id}"] = response_text
json_data = self.json()
# Temporary workraound TODO remove this
json_data = {}
response = {
"id": response_id,
"model": self.model.model_id,
@ -643,6 +657,38 @@ class _BaseResponse:
},
)
# Persist any tools, tool calls and tool results
tool_ids_by_name = {}
for tool in self.prompt.tools:
tool_id = ensure_tool(db, tool)
tool_ids_by_name[tool.name] = tool_id
db["tool_responses"].insert(
{
"tool_id": tool_id,
"response_id": response_id,
}
)
for tool_call in self.tool_calls(): # TODO Should be _or_raise()
db["tool_calls"].insert(
{
"response_id": response_id,
"tool_id": tool_ids_by_name.get(tool_call.name) or None,
"name": tool_call.name,
"arguments": json.dumps(tool_call.arguments),
"tool_call_id": tool_call.tool_call_id,
}
)
for tool_result in self.prompt.tool_results:
db["tool_results"].insert(
{
"response_id": response_id,
"tool_id": tool_ids_by_name.get(tool_result.name) or None,
"name": tool_result.name,
"output": tool_result.output,
"tool_call_id": tool_result.tool_call_id,
}
)
class Response(_BaseResponse):
model: "Model"
@ -942,6 +988,7 @@ class _BaseChainResponse:
while response:
count += 1
yield response
self._responses.append(response)
if count > self.chain_limit:
raise ValueError(f"Chain limit of {self.chain_limit} exceeded. ")
tool_calls = response.tool_calls()

View file

@ -457,14 +457,35 @@ def ensure_fragment(db, content):
values (:hash, :content, datetime('now'), :source)
on conflict(hash) do nothing
"""
hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
hash_id = hashlib.sha256(content.encode("utf-8")).hexdigest()
source = None
if isinstance(content, Fragment):
source = content.source
with db.conn:
db.execute(sql, {"hash": hash, "content": content, "source": source})
db.execute(sql, {"hash": hash_id, "content": content, "source": source})
return list(
db.query("select id from fragments where hash = :hash", {"hash": hash})
db.query("select id from fragments where hash = :hash", {"hash": hash_id})
)[0]["id"]
def ensure_tool(db, tool):
sql = """
insert into tools (hash, name, description, input_schema)
values (:hash, :name, :description, :input_schema)
on conflict(hash) do nothing
"""
with db.conn:
db.execute(
sql,
{
"hash": tool.hash(),
"name": tool.name,
"description": tool.description,
"input_schema": json.dumps(tool.input_schema),
},
)
return list(
db.query("select id from tools where hash = :hash", {"hash": tool.hash()})
)[0]["id"]