Record which plugin a tool came from, including in DB - refs #1020

This commit is contained in:
Simon Willison 2025-05-23 15:44:12 -07:00
parent cb94047b46
commit d2886d4692
6 changed files with 40 additions and 5 deletions

View file

@ -354,7 +354,8 @@ CREATE TABLE [tools] (
[hash] TEXT,
[name] TEXT,
[description] TEXT,
[input_schema] TEXT
[input_schema] TEXT,
[plugin] TEXT
);
CREATE TABLE [tool_responses] (
[tool_id] INTEGER REFERENCES [tools]([id]),

View file

@ -139,6 +139,9 @@ def get_tools() -> Dict[str, Tool]:
load_plugins()
tools: Dict[str, Tool] = {}
# Variable to track current plugin name
current_plugin_name = None
def register(
tool_or_function: Union[Tool, Callable[..., Any]],
name: Optional[str] = None,
@ -148,6 +151,9 @@ def get_tools() -> Dict[str, Tool]:
tool_or_function = Tool.function(tool_or_function, name=name)
tool = cast(Tool, tool_or_function)
if current_plugin_name:
tool.plugin = current_plugin_name
prefix = name or tool.name
suffix = 0
candidate = prefix
@ -159,7 +165,16 @@ def get_tools() -> Dict[str, Tool]:
tools[candidate] = tool
pm.hook.register_tools(register=register)
# Call each plugin's register_tools hook individually to track current_plugin_name
for plugin in pm.get_plugins():
current_plugin_name = pm.get_name(plugin)
hook_caller = pm.hook.register_tools
plugin_impls = [
impl for impl in hook_caller.get_hookimpls() if impl.plugin is plugin
]
for impl in plugin_impls:
impl.function(register=register)
return tools

View file

@ -2337,6 +2337,7 @@ def tools_list(json_, python_tools):
name: {
"description": tool.description,
"arguments": tool.input_schema,
"plugin": tool.plugin,
}
for name, tool in tools.items()
},
@ -2348,7 +2349,13 @@ def tools_list(json_, python_tools):
sig = "()"
if tool.implementation:
sig = str(inspect.signature(tool.implementation))
click.echo("{}{}".format(name, sig))
click.echo(
"{}{}{}".format(
name,
sig,
" (plugin: {})".format(tool.plugin) if tool.plugin else "",
)
)
if tool.description:
click.echo(textwrap.indent(tool.description, " "))

View file

@ -368,3 +368,8 @@ def m017_tools_tables(db):
("tool_id", "tools", "id"),
),
)
@migration
def m017_tools_plugin(db):
db["tools"].add_column("plugin")

View file

@ -114,6 +114,7 @@ class Tool:
description: Optional[str] = None
input_schema: Dict = field(default_factory=dict)
implementation: Optional[Callable] = None
plugin: Optional[str] = None # plugin tool came from, e.g. 'llm_tools_sqlite'
def __post_init__(self):
# Convert Pydantic model to JSON schema if needed
@ -137,6 +138,8 @@ class Tool:
"description": self.description,
"input_schema": self.input_schema,
}
if self.plugin:
to_hash["plugin"] = self.plugin
return hashlib.sha256(json.dumps(to_hash).encode("utf-8")).hexdigest()
@classmethod

View file

@ -226,6 +226,7 @@ def test_register_tools(tmpdir):
"type": "object",
},
implementation=upper,
plugin="ToolsPlugin",
),
"count_chars": llm.Tool(
name="count_chars",
@ -239,6 +240,7 @@ def test_register_tools(tmpdir):
"type": "object",
},
implementation=count_character_in_word,
plugin="ToolsPlugin",
),
}
# Test the CLI command
@ -246,9 +248,9 @@ def test_register_tools(tmpdir):
result = runner.invoke(cli.cli, ["tools", "list"])
assert result.exit_code == 0
assert result.output == (
"upper(text: str) -> str\n"
"upper(text: str) -> str (plugin: ToolsPlugin)\n"
" Convert text to uppercase.\n"
"count_chars(text: str, character: str) -> int\n"
"count_chars(text: str, character: str) -> int (plugin: ToolsPlugin)\n"
" Count the number of occurrences of a character in a word.\n"
)
# And --json
@ -262,6 +264,7 @@ def test_register_tools(tmpdir):
"required": ["text"],
"type": "object",
},
"plugin": "ToolsPlugin",
},
"count_chars": {
"description": "Count the number of occurrences of a character in a word.",
@ -273,6 +276,7 @@ def test_register_tools(tmpdir):
"required": ["text", "character"],
"type": "object",
},
"plugin": "ToolsPlugin",
},
}
# And test the --tools option