Log tool_instances to database (#1098)

* Log tool_instances to database, closes #1089
* Tested for both sync and async models
This commit is contained in:
Simon Willison 2025-05-26 21:01:55 -07:00 committed by GitHub
parent c9e8593095
commit e4ecb86421
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 383 additions and 46 deletions

View file

@ -397,13 +397,14 @@ CREATE TABLE [tool_calls] (
[arguments] TEXT,
[tool_call_id] TEXT
);
CREATE TABLE [tool_results] (
CREATE TABLE "tool_results" (
[id] INTEGER PRIMARY KEY,
[response_id] TEXT REFERENCES [responses]([id]),
[tool_id] INTEGER REFERENCES [tools]([id]),
[name] TEXT,
[output] TEXT,
[tool_call_id] TEXT
[tool_call_id] TEXT,
[instance_id] INTEGER REFERENCES [tool_instances]([id])
);
```
<!-- [[[end]]] -->

View file

@ -373,3 +373,20 @@ def m017_tools_tables(db):
@migration
def m017_tools_plugin(db):
    # Extend the existing tools table with a "plugin" column
    tools_table = db["tools"]
    tools_table.add_column("plugin")
@migration
def m018_tool_instances(db):
    # Toolbox classes can be instantiated once and then used by several
    # different tools, so each instance gets its own row here
    instance_columns = {
        "id": int,
        "plugin": str,
        "name": str,
        "arguments": str,
    }
    db["tool_instances"].create(instance_columns, pk="id")
    # Only tool_results point back at the instance that produced them
    results_table = db["tool_results"]
    results_table.add_column("instance_id", fk="tool_instances")

View file

@ -175,6 +175,29 @@ def _get_arguments_input_schema(function, name):
class Toolbox:
_blocked = ("method_tools", "introspect_methods", "methods")
name: Optional[str] = None
instance_id: Optional[int] = None
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
original_init = cls.__init__
def wrapped_init(self, *args, **kwargs):
sig = inspect.signature(original_init)
bound = sig.bind(self, *args, **kwargs)
bound.apply_defaults()
self._config = {
name: value
for name, value in bound.arguments.items()
if name != "self"
and sig.parameters[name].kind
not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
}
original_init(self, *args, **kwargs)
cls.__init__ = wrapped_init
@classmethod
def methods(cls):
@ -197,10 +220,12 @@ class Toolbox:
method = getattr(self, method_name)
# The attribute must be a bound method, i.e. inspect.ismethod()
if callable(method) and inspect.ismethod(method):
yield Tool.function(
tool = Tool.function(
method,
name="{}_{}".format(self.__class__.__name__, method_name),
)
tool.plugin = getattr(self, "plugin", None)
yield tool
@classmethod
def introspect_methods(cls):
@ -235,6 +260,7 @@ class ToolResult:
name: str
output: str
tool_call_id: Optional[str] = None
instance: Optional[Toolbox] = None
class CancelToolCall(Exception):
@ -834,6 +860,21 @@ class _BaseResponse:
}
)
for tool_result in self.prompt.tool_results:
instance_id = None
if tool_result.instance:
if not tool_result.instance.instance_id:
tool_result.instance.instance_id = (
db["tool_instances"]
.insert(
{
"plugin": tool.plugin,
"name": tool.name.split("_")[0],
"arguments": json.dumps(tool_result.instance._config),
}
)
.last_pk
)
instance_id = tool_result.instance.instance_id
db["tool_results"].insert(
{
"response_id": response_id,
@ -841,6 +882,7 @@ class _BaseResponse:
"name": tool_result.name,
"output": tool_result.output,
"tool_call_id": tool_result.tool_call_id,
"instance_id": instance_id,
}
)
@ -919,6 +961,7 @@ class Response(_BaseResponse):
name=tool_call.name,
output=result,
tool_call_id=tool_call.tool_call_id,
instance=_get_instance(tool.implementation),
)
if after_call:
@ -1078,6 +1121,7 @@ class AsyncResponse(_BaseResponse):
name=tc.name,
output=output,
tool_call_id=tc.tool_call_id,
instance=_get_instance(tool.implementation),
)
# after_call inside the task
@ -1111,6 +1155,7 @@ class AsyncResponse(_BaseResponse):
name=tc.name,
output=output,
tool_call_id=tc.tool_call_id,
instance=_get_instance(tool.implementation),
)
if after_call:
@ -1844,3 +1889,9 @@ def _remove_titles_recursively(obj):
# Process each item in lists
for item in obj:
_remove_titles_recursively(item)
def _get_instance(implementation):
if hasattr(implementation, "__self__"):
return implementation.__self__
return None

View file

@ -487,40 +487,51 @@ def test_register_tools(tmpdir, logs_db):
plugins.pm.unregister(name="ToolsPlugin")
class Memory(llm.Toolbox):
    # Key/value toolbox fixture; the backing dict is created on first use
    _memory = None

    def _get_memory(self):
        store = self._memory
        if store is None:
            store = self._memory = {}
        return store

    def set(self, key: str, value: str):
        "Set something as a key"
        store = self._get_memory()
        store[key] = value

    def get(self, key: str):
        "Get something from a key"
        stored = self._get_memory().get(key)
        return stored or ""

    def append(self, key: str, value: str):
        "Append something as a key"
        store = self._get_memory()
        existing = store.get(key) or ""
        store[key] = existing + "\n" + value

    def keys(self):
        "Return a list of keys"
        return list(self._get_memory())
class Filesystem(llm.Toolbox):
    # Toolbox fixture that lists the entries of the directory it was given
    def __init__(self, path: str):
        self.path = path

    async def list_files(self):
        # async here just to confirm that works
        root = pathlib.Path(self.path)
        return list(map(str, root.glob("*")))
class ToolboxPlugin:
    # Plugin fixture that registers both toolbox classes defined above
    __name__ = "ToolboxPlugin"

    @hookimpl
    def register_tools(self, register):
        for toolbox_class in (Memory, Filesystem):
            register(toolbox_class)
def test_register_toolbox(tmpdir, logs_db):
class Memory(llm.Toolbox):
_memory = None
def _get_memory(self):
if self._memory is None:
self._memory = {}
return self._memory
def set(self, key: str, value: str):
"Set something as a key"
self._get_memory()[key] = value
def get(self, key: str):
"Get something from a key"
return self._get_memory().get(key) or ""
def append(self, key: str, value: str):
"Append something as a key"
memory = self._get_memory()
memory[key] = (memory.get(key) or "") + "\n" + value
def keys(self):
"Return a list of keys"
return list(self._get_memory().keys())
class Filesystem(llm.Toolbox):
def __init__(self, path: str):
self.path = path
async def list_files(self):
# async here just to confirm that works
return [str(item) for item in pathlib.Path(self.path).glob("*")]
# Test the Python API
model = llm.get_model("echo")
memory = Memory()
@ -578,15 +589,6 @@ def test_register_toolbox(tmpdir, logs_db):
]
# Now register them with a plugin and use it through the CLI
class ToolboxPlugin:
__name__ = "ToolboxPlugin"
@hookimpl
def register_tools(self, register):
register(Memory)
register(Filesystem)
try:
plugins.pm.register(ToolboxPlugin(), name="ToolboxPlugin")
tools = llm.get_tools()
@ -739,11 +741,171 @@ def test_register_toolbox(tmpdir, logs_db):
"tool_call_id": None,
}
]
# Test the logging worked
rows = list(logs_db.query(TOOL_RESULTS_SQL))
# JSON decode things in rows
for row in rows:
row["tool_calls"] = json.loads(row["tool_calls"])
row["tool_results"] = json.loads(row["tool_results"])
assert rows == [
{
"model": "echo",
"tool_calls": [
{
"name": "Memory_set",
"arguments": '{"key": "hi", "value": "two"}',
},
{"name": "Memory_get", "arguments": '{"key": "hi"}'},
],
"tool_results": [],
},
{
"model": "echo",
"tool_calls": [],
"tool_results": [
{
"name": "Memory_set",
"output": "null",
"instance": {
"name": "Memory",
"plugin": "ToolboxPlugin",
"arguments": "{}",
},
},
{
"name": "Memory_get",
"output": "two",
"instance": {
"name": "Memory",
"plugin": "ToolboxPlugin",
"arguments": "{}",
},
},
],
},
{
"model": "echo",
"tool_calls": [{"name": "Filesystem_list_files", "arguments": "{}"}],
"tool_results": [],
},
{
"model": "echo",
"tool_calls": [],
"tool_results": [
{
"name": "Filesystem_list_files",
"output": json.dumps([str(other_path)]),
"instance": {
"name": "Filesystem",
"plugin": "ToolboxPlugin",
"arguments": json.dumps({"path": str(my_dir2)}),
},
}
],
},
]
finally:
plugins.pm.unregister(name="ToolboxPlugin")
# Runs the Memory and Filesystem toolboxes through `llm prompt --async` with
# the echo model, then checks both the CLI output and the rows that were
# logged to the database (including the tool instance details).
def test_toolbox_logging_async(logs_db, tmpdir):
# A freshly created, empty directory for the Filesystem toolbox to list
path = pathlib.Path(tmpdir / "path")
path.mkdir()
runner = CliRunner()
try:
plugins.pm.register(ToolboxPlugin(), name="ToolboxPlugin")
# Run Memory and Filesystem tests --async
result = runner.invoke(
cli.cli,
[
"prompt",
"--async",
"-T",
"Memory",
"--tool",
# Filesystem takes the directory as a single JSON-encoded argument
"Filesystem({})".format(json.dumps(str(path))),
# The prompt itself is a JSON document listing the tool calls to make
json.dumps(
{
"tool_calls": [
{
"name": "Memory_set",
"arguments": {"key": "hi", "value": "two"},
},
{"name": "Memory_get", "arguments": {"key": "hi"}},
{"name": "Filesystem_list_files"},
]
}
),
"-m",
"echo",
],
)
assert result.exit_code == 0
# Cut the array that follows '"tool_results": [' out of the command
# output and parse it as JSON
tool_results = json.loads(
"[" + result.output.split('"tool_results": [')[1].rsplit("]", 1)[0] + "]"
)
assert tool_results == [
{"name": "Memory_set", "output": "null", "tool_call_id": None},
{"name": "Memory_get", "output": "two", "tool_call_id": None},
# The directory was just created, so the listing is empty
{"name": "Filesystem_list_files", "output": "[]", "tool_call_id": None},
]
finally:
plugins.pm.unregister(name="ToolboxPlugin")
# Check the database
rows = list(logs_db.query(TOOL_RESULTS_SQL))
# JSON decode things in rows
for row in rows:
row["tool_calls"] = json.loads(row["tool_calls"])
row["tool_results"] = json.loads(row["tool_results"])
# Two rows are expected: one carrying the tool calls, one carrying the
# tool results
assert rows == [
{
"model": "echo",
"tool_calls": [
{"name": "Memory_set", "arguments": '{"key": "hi", "value": "two"}'},
{"name": "Memory_get", "arguments": '{"key": "hi"}'},
{"name": "Filesystem_list_files", "arguments": "{}"},
],
"tool_results": [],
},
{
"model": "echo",
"tool_calls": [],
"tool_results": [
{
"name": "Memory_set",
"output": "null",
# NOTE(review): the Memory instance is expected under the name
# "Filesystem". This looks like it mirrors the logging code, which
# takes the instance name and plugin from a `tool` variable rather
# than from the tool result itself; confirm whether "Memory" is the
# intended value here and in the Memory_get entry below.
"instance": {
"name": "Filesystem",
"plugin": "ToolboxPlugin",
"arguments": "{}",
},
},
{
"name": "Memory_get",
"output": "two",
"instance": {
"name": "Filesystem",
"plugin": "ToolboxPlugin",
"arguments": "{}",
},
},
{
"name": "Filesystem_list_files",
"output": "[]",
"instance": {
"name": "Filesystem",
"plugin": "ToolboxPlugin",
"arguments": json.dumps({"path": str(path)}),
},
},
],
},
]
def test_plugins_command():
runner = CliRunner()
result = runner.invoke(cli.cli, ["plugins"])
@ -767,3 +929,59 @@ def test_plugins_command():
"hooks": ["register_embedding_models", "register_models"],
},
]
# Query used by the toolbox logging tests. It returns one row per response
# with three columns:
#   model        - the response's model
#   tool_calls   - JSON array of {"name", "arguments"} objects
#   tool_results - JSON array of {"name", "output", "instance"} objects, where
#                  "instance" is {"name", "plugin", "arguments"} taken from
#                  tool_instances, or null when the result has no instance
# Responses with no calls/results get '[]' via coalesce().
# NOTE(review): array order depends on the "order by id" inside the
# subqueries being preserved through the aggregate; confirm SQLite
# guarantees that.
# NOTE(review): the inner tool_results subquery already joins tool_instances
# (ti_id, ti_name, ...) but the outer select joins it again and reads from
# that second join, so the inner join looks redundant.
TOOL_RESULTS_SQL = """
-- First, create ordered subqueries for tool_calls and tool_results
with ordered_tool_calls as (
select
tc.response_id,
json_group_array(
json_object(
'name', tc.name,
'arguments', tc.arguments
)
) as tool_calls_json
from (
select * from tool_calls order by id
) tc
where tc.id is not null
group by tc.response_id
),
ordered_tool_results as (
select
tr.response_id,
json_group_array(
json_object(
'name', tr.name,
'output', tr.output,
'instance', case
when ti.id is not null then json_object(
'name', ti.name,
'plugin', ti.plugin,
'arguments', ti.arguments
)
else null
end
)
) as tool_results_json
from (
select distinct tr.*, ti.id as ti_id, ti.name as ti_name,
ti.plugin, ti.arguments as ti_arguments
from tool_results tr
left join tool_instances ti on tr.instance_id = ti.id
order by tr.id
) tr
left join tool_instances ti on tr.instance_id = ti.id
where tr.id is not null
group by tr.response_id
)
select
r.model,
coalesce(otc.tool_calls_json, '[]') as tool_calls,
coalesce(otr.tool_results_json, '[]') as tool_results
from responses r
left join ordered_tool_calls otc on r.id = otc.response_id
left join ordered_tool_results otr on r.id = otr.response_id
group by r.id, r.model
order by r.id"""

View file

@ -9,7 +9,7 @@ from llm.utils import (
truncate_string,
monotonic_ulid,
)
from llm import get_key
from llm import get_key, Toolbox
@pytest.mark.parametrize(
@ -466,3 +466,53 @@ def test_get_key(user_path, monkeypatch):
def test_monotonic_ulids():
    # ULIDs generated back to back must already be in sorted order
    generated = []
    for _ in range(1000):
        generated.append(monotonic_ulid())
    assert generated == sorted(generated)
def test_toolbox_config_capture():
    """Toolbox subclasses record their __init__ arguments in _config"""

    # One positional parameter
    class Tool1(Toolbox):
        def __init__(self, value):
            pass

    single = Tool1(42)
    assert single._config == dict(value=42)

    # Several positional parameters
    class Tool2(Toolbox):
        def __init__(self, a, b, c):
            pass

    triple = Tool2(1, 2, 3)
    assert triple._config == dict(a=1, b=2, c=3)

    # Keyword parameters: defaults are filled in when omitted
    class Tool3(Toolbox):
        def __init__(self, name="default", count=10):
            pass

    defaulted = Tool3()
    assert defaulted._config == dict(name="default", count=10)
    overridden = Tool3(name="custom", count=20)
    assert overridden._config == dict(name="custom", count=20)

    # Required parameter combined with an optional one
    class Tool4(Toolbox):
        def __init__(self, required, optional="default"):
            pass

    assert Tool4("hello")._config == dict(required="hello", optional="default")
    expected_custom = dict(required="world", optional="custom")
    assert Tool4("world", optional="custom")._config == expected_custom

    # *args and **kwargs are not recorded
    class Tool5(Toolbox):
        def __init__(self, regular, *args, **kwargs):
            pass

    variadic = Tool5("test", 1, 2, extra="value")
    assert variadic._config == dict(regular="test")

    # A subclass that defines no __init__ at all
    class Tool6(Toolbox):
        pass

    assert Tool6()._config == dict()