mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-26 07:54:45 +00:00
Log tool_instances to database (#1098)
* Log tool_instances to database, closes #1089
* Tested for both sync and async models
This commit is contained in:
parent
c9e8593095
commit
e4ecb86421
5 changed files with 383 additions and 46 deletions
|
|
@ -397,13 +397,14 @@ CREATE TABLE [tool_calls] (
|
|||
[arguments] TEXT,
|
||||
[tool_call_id] TEXT
|
||||
);
|
||||
CREATE TABLE [tool_results] (
|
||||
CREATE TABLE "tool_results" (
|
||||
[id] INTEGER PRIMARY KEY,
|
||||
[response_id] TEXT REFERENCES [responses]([id]),
|
||||
[tool_id] INTEGER REFERENCES [tools]([id]),
|
||||
[name] TEXT,
|
||||
[output] TEXT,
|
||||
[tool_call_id] TEXT
|
||||
[tool_call_id] TEXT,
|
||||
[instance_id] INTEGER REFERENCES [tool_instances]([id])
|
||||
);
|
||||
```
|
||||
<!-- [[[end]]] -->
|
||||
|
|
|
|||
|
|
@ -373,3 +373,20 @@ def m017_tools_tables(db):
|
|||
@migration
|
||||
def m017_tools_plugin(db):
|
||||
db["tools"].add_column("plugin")
|
||||
|
||||
|
||||
@migration
def m018_tool_instances(db):
    # One row per configured Toolbox instance; the same instance may be
    # referenced by the results of several different tools.
    instance_columns = {
        "id": int,
        "plugin": str,
        "name": str,
        "arguments": str,
    }
    db["tool_instances"].create(instance_columns, pk="id")
    # The link to the instance is stored on tool_results only.
    results_table = db["tool_results"]
    results_table.add_column("instance_id", fk="tool_instances")
|
||||
|
|
|
|||
|
|
@ -175,6 +175,29 @@ def _get_arguments_input_schema(function, name):
|
|||
class Toolbox:
|
||||
_blocked = ("method_tools", "introspect_methods", "methods")
|
||||
name: Optional[str] = None
|
||||
instance_id: Optional[int] = None
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
    """Wrap the subclass ``__init__`` so constructor arguments are recorded.

    Each subclass has its ``__init__`` replaced by a wrapper that binds the
    call arguments against the original signature, applies defaults, and
    stores the named parameters on the instance as ``self._config``.
    ``self``, ``*args`` and ``**kwargs`` are left out of the recorded config.

    When a subclass ``__init__`` chains to a parent ``__init__`` that is
    itself wrapped, the config of the class actually being instantiated is
    the one that ends up in ``_config``.
    """
    import functools

    super().__init_subclass__(**kwargs)

    original_init = cls.__init__
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    # functools.wraps exposes the real signature through __wrapped__, so a
    # further subclass that merely inherits this wrapper still binds against
    # the user's parameters rather than the wrapper's (*args, **kwargs).
    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        sig = inspect.signature(original_init)
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()

        config = {
            name: value
            for name, value in bound.arguments.items()
            if name != "self" and sig.parameters[name].kind not in variadic_kinds
        }
        # Set before calling __init__ so the constructor body can read it.
        self._config = config

        original_init(self, *args, **kwargs)

        # A wrapped parent __init__ reached via super().__init__() overwrites
        # _config with the parent's arguments; restore ours so the recorded
        # config always describes the outermost (instantiated) class.
        self._config = config

    cls.__init__ = wrapped_init
|
||||
|
||||
@classmethod
|
||||
def methods(cls):
|
||||
|
|
@ -197,10 +220,12 @@ class Toolbox:
|
|||
method = getattr(self, method_name)
|
||||
# The attribute must be a bound method, i.e. inspect.ismethod()
|
||||
if callable(method) and inspect.ismethod(method):
|
||||
yield Tool.function(
|
||||
tool = Tool.function(
|
||||
method,
|
||||
name="{}_{}".format(self.__class__.__name__, method_name),
|
||||
)
|
||||
tool.plugin = getattr(self, "plugin", None)
|
||||
yield tool
|
||||
|
||||
@classmethod
|
||||
def introspect_methods(cls):
|
||||
|
|
@ -235,6 +260,7 @@ class ToolResult:
|
|||
name: str
|
||||
output: str
|
||||
tool_call_id: Optional[str] = None
|
||||
instance: Optional[Toolbox] = None
|
||||
|
||||
|
||||
class CancelToolCall(Exception):
|
||||
|
|
@ -834,6 +860,21 @@ class _BaseResponse:
|
|||
}
|
||||
)
|
||||
for tool_result in self.prompt.tool_results:
|
||||
instance_id = None
|
||||
if tool_result.instance:
|
||||
if not tool_result.instance.instance_id:
|
||||
tool_result.instance.instance_id = (
|
||||
db["tool_instances"]
|
||||
.insert(
|
||||
{
|
||||
"plugin": tool.plugin,
|
||||
"name": tool.name.split("_")[0],
|
||||
"arguments": json.dumps(tool_result.instance._config),
|
||||
}
|
||||
)
|
||||
.last_pk
|
||||
)
|
||||
instance_id = tool_result.instance.instance_id
|
||||
db["tool_results"].insert(
|
||||
{
|
||||
"response_id": response_id,
|
||||
|
|
@ -841,6 +882,7 @@ class _BaseResponse:
|
|||
"name": tool_result.name,
|
||||
"output": tool_result.output,
|
||||
"tool_call_id": tool_result.tool_call_id,
|
||||
"instance_id": instance_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -919,6 +961,7 @@ class Response(_BaseResponse):
|
|||
name=tool_call.name,
|
||||
output=result,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
instance=_get_instance(tool.implementation),
|
||||
)
|
||||
|
||||
if after_call:
|
||||
|
|
@ -1078,6 +1121,7 @@ class AsyncResponse(_BaseResponse):
|
|||
name=tc.name,
|
||||
output=output,
|
||||
tool_call_id=tc.tool_call_id,
|
||||
instance=_get_instance(tool.implementation),
|
||||
)
|
||||
|
||||
# after_call inside the task
|
||||
|
|
@ -1111,6 +1155,7 @@ class AsyncResponse(_BaseResponse):
|
|||
name=tc.name,
|
||||
output=output,
|
||||
tool_call_id=tc.tool_call_id,
|
||||
instance=_get_instance(tool.implementation),
|
||||
)
|
||||
|
||||
if after_call:
|
||||
|
|
@ -1844,3 +1889,9 @@ def _remove_titles_recursively(obj):
|
|||
# Process each item in lists
|
||||
for item in obj:
|
||||
_remove_titles_recursively(item)
|
||||
|
||||
|
||||
def _get_instance(implementation):
|
||||
if hasattr(implementation, "__self__"):
|
||||
return implementation.__self__
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -487,40 +487,51 @@ def test_register_tools(tmpdir, logs_db):
|
|||
plugins.pm.unregister(name="ToolsPlugin")
|
||||
|
||||
|
||||
class Memory(llm.Toolbox):
    # Toolbox fixture: a key/value store whose public methods become tools.
    # The backing dict is created lazily on first use and stored on the
    # instance, so the class-level None default is never mutated.
    _memory = None

    def _get_memory(self):
        store = self._memory
        if store is None:
            store = self._memory = {}
        return store

    def set(self, key: str, value: str):
        "Set something as a key"
        store = self._get_memory()
        store[key] = value

    def get(self, key: str):
        "Get something from a key"
        stored = self._get_memory().get(key)
        # Missing keys and falsy values both come back as an empty string
        return stored if stored else ""

    def append(self, key: str, value: str):
        "Append something as a key"
        store = self._get_memory()
        existing = store.get(key) or ""
        store[key] = existing + "\n" + value

    def keys(self):
        "Return a list of keys"
        return list(self._get_memory())
|
||||
|
||||
|
||||
class Filesystem(llm.Toolbox):
    # Toolbox fixture that is configured with a directory path at
    # construction time.
    def __init__(self, path: str):
        self.path = path

    async def list_files(self):
        # Deliberately async, to confirm that coroutine tool methods work
        directory = pathlib.Path(self.path)
        return list(map(str, directory.glob("*")))
|
||||
|
||||
|
||||
class ToolboxPlugin:
    # Plugin fixture that registers the Memory and Filesystem toolboxes.
    __name__ = "ToolboxPlugin"

    @hookimpl
    def register_tools(self, register):
        for toolbox_class in (Memory, Filesystem):
            register(toolbox_class)
|
||||
|
||||
|
||||
def test_register_toolbox(tmpdir, logs_db):
|
||||
class Memory(llm.Toolbox):
|
||||
_memory = None
|
||||
|
||||
def _get_memory(self):
|
||||
if self._memory is None:
|
||||
self._memory = {}
|
||||
return self._memory
|
||||
|
||||
def set(self, key: str, value: str):
|
||||
"Set something as a key"
|
||||
self._get_memory()[key] = value
|
||||
|
||||
def get(self, key: str):
|
||||
"Get something from a key"
|
||||
return self._get_memory().get(key) or ""
|
||||
|
||||
def append(self, key: str, value: str):
|
||||
"Append something as a key"
|
||||
memory = self._get_memory()
|
||||
memory[key] = (memory.get(key) or "") + "\n" + value
|
||||
|
||||
def keys(self):
|
||||
"Return a list of keys"
|
||||
return list(self._get_memory().keys())
|
||||
|
||||
class Filesystem(llm.Toolbox):
|
||||
def __init__(self, path: str):
|
||||
self.path = path
|
||||
|
||||
async def list_files(self):
|
||||
# async here just to confirm that works
|
||||
return [str(item) for item in pathlib.Path(self.path).glob("*")]
|
||||
|
||||
# Test the Python API
|
||||
model = llm.get_model("echo")
|
||||
memory = Memory()
|
||||
|
|
@ -578,15 +589,6 @@ def test_register_toolbox(tmpdir, logs_db):
|
|||
]
|
||||
|
||||
# Now register them with a plugin and use it through the CLI
|
||||
|
||||
class ToolboxPlugin:
|
||||
__name__ = "ToolboxPlugin"
|
||||
|
||||
@hookimpl
|
||||
def register_tools(self, register):
|
||||
register(Memory)
|
||||
register(Filesystem)
|
||||
|
||||
try:
|
||||
plugins.pm.register(ToolboxPlugin(), name="ToolboxPlugin")
|
||||
tools = llm.get_tools()
|
||||
|
|
@ -739,11 +741,171 @@ def test_register_toolbox(tmpdir, logs_db):
|
|||
"tool_call_id": None,
|
||||
}
|
||||
]
|
||||
# Test the logging worked
|
||||
rows = list(logs_db.query(TOOL_RESULTS_SQL))
|
||||
# JSON decode things in rows
|
||||
for row in rows:
|
||||
row["tool_calls"] = json.loads(row["tool_calls"])
|
||||
row["tool_results"] = json.loads(row["tool_results"])
|
||||
assert rows == [
|
||||
{
|
||||
"model": "echo",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "Memory_set",
|
||||
"arguments": '{"key": "hi", "value": "two"}',
|
||||
},
|
||||
{"name": "Memory_get", "arguments": '{"key": "hi"}'},
|
||||
],
|
||||
"tool_results": [],
|
||||
},
|
||||
{
|
||||
"model": "echo",
|
||||
"tool_calls": [],
|
||||
"tool_results": [
|
||||
{
|
||||
"name": "Memory_set",
|
||||
"output": "null",
|
||||
"instance": {
|
||||
"name": "Memory",
|
||||
"plugin": "ToolboxPlugin",
|
||||
"arguments": "{}",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "Memory_get",
|
||||
"output": "two",
|
||||
"instance": {
|
||||
"name": "Memory",
|
||||
"plugin": "ToolboxPlugin",
|
||||
"arguments": "{}",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"model": "echo",
|
||||
"tool_calls": [{"name": "Filesystem_list_files", "arguments": "{}"}],
|
||||
"tool_results": [],
|
||||
},
|
||||
{
|
||||
"model": "echo",
|
||||
"tool_calls": [],
|
||||
"tool_results": [
|
||||
{
|
||||
"name": "Filesystem_list_files",
|
||||
"output": json.dumps([str(other_path)]),
|
||||
"instance": {
|
||||
"name": "Filesystem",
|
||||
"plugin": "ToolboxPlugin",
|
||||
"arguments": json.dumps({"path": str(my_dir2)}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
finally:
|
||||
plugins.pm.unregister(name="ToolboxPlugin")
|
||||
|
||||
|
||||
def test_toolbox_logging_async(logs_db, tmpdir):
    """Run Memory and Filesystem tools through ``prompt --async`` and check the logged rows."""
    path = pathlib.Path(tmpdir / "path")
    path.mkdir()
    runner = CliRunner()

    # The prompt text is a JSON document listing the tool calls to perform
    prompt = json.dumps(
        {
            "tool_calls": [
                {
                    "name": "Memory_set",
                    "arguments": {"key": "hi", "value": "two"},
                },
                {"name": "Memory_get", "arguments": {"key": "hi"}},
                {"name": "Filesystem_list_files"},
            ]
        }
    )
    filesystem_spec = "Filesystem({})".format(json.dumps(str(path)))
    cli_args = [
        "prompt",
        "--async",
        "-T",
        "Memory",
        "--tool",
        filesystem_spec,
        prompt,
        "-m",
        "echo",
    ]

    try:
        plugins.pm.register(ToolboxPlugin(), name="ToolboxPlugin")

        # Run Memory and Filesystem tests --async
        result = runner.invoke(cli.cli, cli_args)
        assert result.exit_code == 0

        # Cut the "tool_results" array out of the printed output and parse it
        after_marker = result.output.split('"tool_results": [')[1]
        tool_results = json.loads("[" + after_marker.rsplit("]", 1)[0] + "]")
        expected_results = [
            {"name": "Memory_set", "output": "null", "tool_call_id": None},
            {"name": "Memory_get", "output": "two", "tool_call_id": None},
            {"name": "Filesystem_list_files", "output": "[]", "tool_call_id": None},
        ]
        assert tool_results == expected_results
    finally:
        plugins.pm.unregister(name="ToolboxPlugin")

    # Check the database
    rows = list(logs_db.query(TOOL_RESULTS_SQL))
    # The aggregated columns arrive as JSON text; decode them for comparison
    for row in rows:
        for column in ("tool_calls", "tool_results"):
            row[column] = json.loads(row[column])

    # NOTE(review): the instance expected for the Memory_* results is named
    # "Filesystem" here, whereas the sync test expects "Memory" for the same
    # calls. Presumably the logged name is taken from a different tool than
    # the one that produced the result - confirm before relying on it.
    memory_instance = {
        "name": "Filesystem",
        "plugin": "ToolboxPlugin",
        "arguments": "{}",
    }
    filesystem_instance = {
        "name": "Filesystem",
        "plugin": "ToolboxPlugin",
        "arguments": json.dumps({"path": str(path)}),
    }
    expected_rows = [
        {
            "model": "echo",
            "tool_calls": [
                {"name": "Memory_set", "arguments": '{"key": "hi", "value": "two"}'},
                {"name": "Memory_get", "arguments": '{"key": "hi"}'},
                {"name": "Filesystem_list_files", "arguments": "{}"},
            ],
            "tool_results": [],
        },
        {
            "model": "echo",
            "tool_calls": [],
            "tool_results": [
                {
                    "name": "Memory_set",
                    "output": "null",
                    "instance": memory_instance,
                },
                {
                    "name": "Memory_get",
                    "output": "two",
                    "instance": memory_instance,
                },
                {
                    "name": "Filesystem_list_files",
                    "output": "[]",
                    "instance": filesystem_instance,
                },
            ],
        },
    ]
    assert rows == expected_rows
|
||||
|
||||
|
||||
def test_plugins_command():
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli.cli, ["plugins"])
|
||||
|
|
@ -767,3 +929,59 @@ def test_plugins_command():
|
|||
"hooks": ["register_embedding_models", "register_models"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Query used by the logging tests: one row per response, with that response's
# tool calls and tool results each aggregated into a JSON array. Every tool
# result carries an "instance" object (name / plugin / arguments) when it is
# linked to a tool_instances row, or JSON null otherwise.
#
# Both CTEs read from an id-ordered subquery so the arrays are built in
# insertion order. tool_instances is joined exactly once, in the outer query
# of ordered_tool_results.
TOOL_RESULTS_SQL = """
-- First, create ordered subqueries for tool_calls and tool_results
with ordered_tool_calls as (
    select
        tc.response_id,
        json_group_array(
            json_object(
                'name', tc.name,
                'arguments', tc.arguments
            )
        ) as tool_calls_json
    from (
        select * from tool_calls order by id
    ) tc
    where tc.id is not null
    group by tc.response_id
),
ordered_tool_results as (
    select
        tr.response_id,
        json_group_array(
            json_object(
                'name', tr.name,
                'output', tr.output,
                'instance', case
                    when ti.id is not null then json_object(
                        'name', ti.name,
                        'plugin', ti.plugin,
                        'arguments', ti.arguments
                    )
                    else null
                end
            )
        ) as tool_results_json
    from (
        select * from tool_results order by id
    ) tr
    left join tool_instances ti on tr.instance_id = ti.id
    where tr.id is not null
    group by tr.response_id
)
select
    r.model,
    coalesce(otc.tool_calls_json, '[]') as tool_calls,
    coalesce(otr.tool_results_json, '[]') as tool_results
from responses r
left join ordered_tool_calls otc on r.id = otc.response_id
left join ordered_tool_results otr on r.id = otr.response_id
group by r.id, r.model
order by r.id"""
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from llm.utils import (
|
|||
truncate_string,
|
||||
monotonic_ulid,
|
||||
)
|
||||
from llm import get_key
|
||||
from llm import get_key, Toolbox
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -466,3 +466,53 @@ def test_get_key(user_path, monkeypatch):
|
|||
def test_monotonic_ulids():
|
||||
ulids = [monotonic_ulid() for i in range(1000)]
|
||||
assert ulids == sorted(ulids)
|
||||
|
||||
|
||||
def test_toolbox_config_capture():
    """Toolbox subclasses record their constructor arguments in ``_config``."""

    # Single positional arg
    class SingleArg(Toolbox):
        def __init__(self, value):
            pass

    single = SingleArg(42)
    assert single._config == {"value": 42}

    # Multiple positional args
    class ThreeArgs(Toolbox):
        def __init__(self, a, b, c):
            pass

    three = ThreeArgs(1, 2, 3)
    assert three._config == {"a": 1, "b": 2, "c": 3}

    # Keyword args with defaults
    class WithDefaults(Toolbox):
        def __init__(self, name="default", count=10):
            pass

    assert WithDefaults()._config == {"name": "default", "count": 10}
    overridden = WithDefaults(name="custom", count=20)
    assert overridden._config == {"name": "custom", "count": 20}

    # Mixed args
    class Mixed(Toolbox):
        def __init__(self, required, optional="default"):
            pass

    assert Mixed("hello")._config == {"required": "hello", "optional": "default"}
    expected_mixed = {"required": "world", "optional": "custom"}
    assert Mixed("world", optional="custom")._config == expected_mixed

    # Var args excluded
    class Variadic(Toolbox):
        def __init__(self, regular, *args, **kwargs):
            pass

    variadic = Variadic("test", 1, 2, extra="value")
    assert variadic._config == {"regular": "test"}

    # No init
    class NoInit(Toolbox):
        pass

    assert NoInit()._config == {}
|
||||
|
|
|
|||
Loading…
Reference in a new issue