From 32cab986ea28c2c36bd344763dba2bf7a5f4253e Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Mon, 12 May 2025 10:02:19 -0700 Subject: [PATCH] Log tool stuff to the database, refs #1003 --- llm/cli.py | 40 ++++++++++++++++---------------- llm/migrations.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++ llm/models.py | 47 ++++++++++++++++++++++++++++++++++++++ llm/utils.py | 27 +++++++++++++++++++--- 4 files changed, 150 insertions(+), 22 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index a5bc9e2..8210072 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -743,9 +743,7 @@ def prompt( extra_tools = _tools_from_code(python_tools) if tools or python_tools: - prompt_method = lambda *args, **kwargs: conversation.chain( - *args, **kwargs - ).details() + prompt_method = conversation.chain # prompt_method = conversation.chain # Look up all those tools registered_tools: dict = get_tools() @@ -825,25 +823,29 @@ def prompt( raise raise click.ClickException(str(ex)) - if not isinstance(response, (Response, AsyncResponse)): - # Terminate early, logging for tool streaming mechanism not implemented yet - return + if isinstance(response, ChainResponse): + responses = response._responses + else: + responses = [response] - if isinstance(response, AsyncResponse): - response = asyncio.run(response.to_sync_response()) + for response in responses: + if isinstance(response, AsyncResponse): + response = asyncio.run(response.to_sync_response()) - if usage: - # Show token usage to stderr in yellow - click.echo( - click.style( - "Token usage: {}".format(response.token_usage()), fg="yellow", bold=True - ), - err=True, - ) + if usage: + # Show token usage to stderr in yellow + click.echo( + click.style( + "Token usage: {}".format(response.token_usage()), + fg="yellow", + bold=True, + ), + err=True, + ) - # Log to the database - if (logs_on() or log) and not no_log: - response.log_to_db(db) + # Log to the database + if (logs_on() or log) and not no_log: + response.log_to_db(db) @cli.command() diff --git a/llm/migrations.py b/llm/migrations.py index a9c53c9..c6f9311 100644 --- a/llm/migrations.py +++ b/llm/migrations.py @@ -310,3 +310,61 @@ def m016_fragments_table_pks(db): # https://github.com/simonw/llm/issues/863#issuecomment-2781720064 db["prompt_fragments"].transform(pk=("response_id", "fragment_id", "order")) db["system_fragments"].transform(pk=("response_id", "fragment_id", "order")) + + +@migration +def m017_tools_tables(db): + db["tools"].create( + { + "id": int, + "hash": str, + "name": str, + "description": str, + "input_schema": str, + }, + pk="id", + ) + db["tools"].create_index(["hash"], unique=True) + # Many-to-many relationship between tools and responses + db["tool_responses"].create( + { + "tool_id": int, + "response_id": str, + }, + foreign_keys=( + ("tool_id", "tools", "id"), + ("response_id", "responses", "id"), + ), + pk=("tool_id", "response_id"), + ) + # tool_calls and tool_results are one-to-many against responses + db["tool_calls"].create( + { + "id": int, + "response_id": str, + "tool_id": int, + "name": str, + "arguments": str, + "tool_call_id": str, + }, + pk="id", + foreign_keys=( + ("response_id", "responses", "id"), + ("tool_id", "tools", "id"), + ), + ) + db["tool_results"].create( + { + "id": int, + "response_id": str, + "tool_id": int, + "name": str, + "output": str, + "tool_call_id": str, + }, + pk="id", + foreign_keys=( + ("response_id", "responses", "id"), + ("tool_id", "tools", "id"), + ), + ) diff --git a/llm/models.py b/llm/models.py index d33310b..b165113 100644 --- a/llm/models.py +++ b/llm/models.py @@ -24,6 +24,7 @@ from typing import ( ) from .utils import ( ensure_fragment, + ensure_tool, make_schema_id, mimetype_from_path, mimetype_from_string, @@ -127,6 +128,15 @@ class Tool: return schema_dict return schema + def hash(self): + """Hash for tool based on its name, description and input schema (preserving key order)""" + to_hash = { + "name": self.name, + "description": self.description, + "input_schema": self.input_schema, + } + return hashlib.sha256(json.dumps(to_hash).encode("utf-8")).hexdigest() + @classmethod def function(cls, function, name=None): """ @@ -597,6 +607,10 @@ class _BaseResponse: response_text = self.text_or_raise() replacements[f"r:{response_id}"] = response_text json_data = self.json() + + # Temporary workraound TODO remove this + json_data = {} + response = { "id": response_id, "model": self.model.model_id, @@ -643,6 +657,38 @@ class _BaseResponse: }, ) + # Persist any tools, tool calls and tool results + tool_ids_by_name = {} + for tool in self.prompt.tools: + tool_id = ensure_tool(db, tool) + tool_ids_by_name[tool.name] = tool_id + db["tool_responses"].insert( + { + "tool_id": tool_id, + "response_id": response_id, + } + ) + for tool_call in self.tool_calls(): # TODO Should be _or_raise() + db["tool_calls"].insert( + { + "response_id": response_id, + "tool_id": tool_ids_by_name.get(tool_call.name) or None, + "name": tool_call.name, + "arguments": json.dumps(tool_call.arguments), + "tool_call_id": tool_call.tool_call_id, + } + ) + for tool_result in self.prompt.tool_results: + db["tool_results"].insert( + { + "response_id": response_id, + "tool_id": tool_ids_by_name.get(tool_result.name) or None, + "name": tool_result.name, + "output": tool_result.output, + "tool_call_id": tool_result.tool_call_id, + } + ) + class Response(_BaseResponse): model: "Model" @@ -942,6 +988,7 @@ class _BaseChainResponse: while response: count += 1 yield response + self._responses.append(response) if count > self.chain_limit: raise ValueError(f"Chain limit of {self.chain_limit} exceeded. ") tool_calls = response.tool_calls() diff --git a/llm/utils.py b/llm/utils.py index fb55e95..f5ac58d 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -457,14 +457,35 @@ def ensure_fragment(db, content): values (:hash, :content, datetime('now'), :source) on conflict(hash) do nothing """ - hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + hash_id = hashlib.sha256(content.encode("utf-8")).hexdigest() source = None if isinstance(content, Fragment): source = content.source with db.conn: - db.execute(sql, {"hash": hash, "content": content, "source": source}) + db.execute(sql, {"hash": hash_id, "content": content, "source": source}) return list( - db.query("select id from fragments where hash = :hash", {"hash": hash}) + db.query("select id from fragments where hash = :hash", {"hash": hash_id}) + )[0]["id"] + + +def ensure_tool(db, tool): + sql = """ + insert into tools (hash, name, description, input_schema) + values (:hash, :name, :description, :input_schema) + on conflict(hash) do nothing + """ + with db.conn: + db.execute( + sql, + { + "hash": tool.hash(), + "name": tool.name, + "description": tool.description, + "input_schema": json.dumps(tool.input_schema), + }, + ) + return list( + db.query("select id from tools where hash = :hash", {"hash": tool.hash()}) )[0]["id"]