Toolbox.add_tool(), prepare() and prepare_async() methods

Closes #1111
This commit is contained in:
Simon Willison 2025-08-11 13:19:31 -07:00 committed by GitHub
parent ef3192b44d
commit 08094082f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 234 additions and 72 deletions

View file

@ -640,7 +640,7 @@ Commands:
(help-tools-list)=
#### llm tools list --help
```
Usage: llm tools list [OPTIONS]
Usage: llm tools list [OPTIONS] [TOOL_DEFS]...
List available tools that have been provided by plugins

View file

@ -108,13 +108,13 @@ def register_tools(register):
Tools can also be implemented as classes, as described in {ref}`Toolbox classes <python-api-toolbox>` in the Python API documentation.
You can register classes like the `Memory` example from there by passing the class (_not_ an instance of the class) to `register()`:
You can register classes like the `Memory` example {ref}`from here <python-api-toolbox>` by passing the class (_not_ an instance of the class) to `register()`:
```python
import llm
class Memory(llm.Toolbox):
...
# Copy implementation from the Python API documentation
@llm.hookimpl
def register_tools(register):

View file

@ -277,6 +277,31 @@ print(conversation.chain("Print current name", after_call=print).text())
See the {ref}`register_tools() plugin hook documentation <plugin-hooks-register-tools>` for an example of this tool in action as a CLI plugin.
(python-api-tools-dynamic)=
#### Dynamic toolboxes
Sometimes you may need to register additional tools against a toolbox after it has been created - for example if you are implementing an MCP plugin where the toolbox needs to consult the MCP server to discover what tools are available.
You can use the `toolbox.add_tool(function_or_tool)` method to add a new tool to an existing toolbox.
This can be passed a `llm.Tool` instance or a function that will be converted into a tool automatically.
If you want your function to be able to access the toolbox instance itself as a `self` parameter, pass that function to `add_tool()` with the `pass_self=True` parameter:
```python
def my_function(self, arg1: str, arg2: int) -> str:
return f"Received {arg1} and {arg2} in {self}"
toolbox.add_tool(my_function, pass_self=True)
```
Without `pass_self=True` the function will be called with only its declared arguments, with no `self` parameter.
If your toolbox needs to run an additional command to figure out what it should register using `.add_tool()` you can implement a `prepare()` method on your toolbox class. This will be called once automatically when the toolbox is first used.
In asynchronous contexts the alternative `await toolbox.prepare_async()` method will be called before the toolbox is used. You can implement this method on your subclass and use it to run asynchronous operations that discover tools to be registered using `self.add_tool()`.
If you want a toolbox that is prepared this way to work in both synchronous and asynchronous contexts, implement both the `prepare()` and `prepare_async()` methods.
(python-api-schemas)=
### Schemas

View file

@ -31,7 +31,7 @@ Applications built on top of LLMs suffer from a class of attacks called [prompt
Be very careful about which tools you enable when you potentially might be exposed to untrusted sources of content - web pages, GitHub issues posted by other people, email and messages that have been sent to you that could come from an attacker.
Watch out for the **lethal trifecta** of prompt injection exfiltration attacks. If your tool-enabled LLM has the following:
Watch out for [the lethal trifecta](https://simonwillison.net/2025/Jun/16/the-lethal-trifecta/) of prompt injection exfiltration attacks. If your tool-enabled LLM has the following:
- access to private data
- exposure to malicious instructions

View file

@ -2549,6 +2549,7 @@ def tools():
@tools.command(name="list")
@click.argument("tool_defs", nargs=-1)
@click.option("json_", "--json", is_flag=True, help="Output as JSON")
@click.option(
"python_tools",
@ -2556,13 +2557,35 @@ def tools():
help="Python code block or file path defining functions to register as tools",
multiple=True,
)
def tools_list(json_, python_tools):
def tools_list(tool_defs, json_, python_tools):
"List available tools that have been provided by plugins"
tools = get_tools()
if python_tools:
for code_or_path in python_tools:
for tool in _tools_from_code(code_or_path):
def introspect_tools(toolbox_class):
    """
    Describe every tool returned by ``toolbox_class.method_tools()``.

    Returns a list with one dict per tool, holding its ``name``,
    ``description``, ``arguments`` (taken from the tool's ``input_schema``)
    and ``implementation``.
    """
    return [
        {
            "name": entry.name,
            "description": entry.description,
            "arguments": entry.input_schema,
            "implementation": entry.implementation,
        }
        for entry in toolbox_class.method_tools()
    ]
if tool_defs:
tools = {}
for tool in _gather_tools(tool_defs, python_tools):
if hasattr(tool, "name"):
tools[tool.name] = tool
else:
tools[tool.__class__.__name__] = tool
else:
tools = get_tools()
if python_tools:
for code_or_path in python_tools:
for tool in _tools_from_code(code_or_path):
tools[tool.name] = tool
output_tools = []
output_toolboxes = []
@ -2586,11 +2609,11 @@ def tools_list(json_, python_tools):
"name": name,
"tools": [
{
"name": method["name"],
"description": method["description"],
"arguments": method["arguments"],
"name": tool["name"],
"description": tool["description"],
"arguments": tool["arguments"],
}
for method in tool.introspect_methods()
for tool in introspect_tools(tool)
],
}
)
@ -2617,22 +2640,20 @@ def tools_list(json_, python_tools):
click.echo(textwrap.indent(tool.description.strip(), " ") + "\n")
for toolbox in toolbox_objects:
click.echo(toolbox.name + ":\n")
for method in toolbox.introspect_methods():
for tool in toolbox.method_tools():
sig = (
str(inspect.signature(method["implementation"]))
str(inspect.signature(tool.implementation))
.replace("(self, ", "(")
.replace("(self)", "()")
)
click.echo(
" {}{}\n".format(
method["name"],
tool.name,
sig,
)
)
if method["description"]:
click.echo(
textwrap.indent(method["description"].strip(), " ") + "\n"
)
if tool.description:
click.echo(textwrap.indent(tool.description.strip(), " ") + "\n")
@cli.group(

View file

@ -9,6 +9,7 @@ import httpx
from itertools import islice
import re
import time
from types import MethodType
from typing import (
Any,
AsyncGenerator,
@ -144,7 +145,7 @@ class Tool:
return hashlib.sha256(json.dumps(to_hash).encode("utf-8")).hexdigest()
@classmethod
def function(cls, function, name=None):
def function(cls, function, name=None, description=None):
"""
Turn a Python function into a Tool object by:
- Extracting the function name
@ -159,7 +160,7 @@ class Tool:
return cls(
name=name or function.__name__,
description=function.__doc__ or None,
description=description or function.__doc__ or None,
input_schema=_get_arguments_input_schema(function, name),
implementation=function,
)
@ -185,9 +186,20 @@ def _get_arguments_input_schema(function, name):
class Toolbox:
_blocked = ("method_tools", "introspect_methods", "methods")
name: Optional[str] = None
instance_id: Optional[int] = None
_blocked = (
"tools",
"add_tool",
"method_tools",
"__init_subclass__",
"prepare",
"prepare_async",
)
_extra_tools: List[Tool] = []
_config: Dict[str, Any] = {}
_prepared: bool = False
_async_prepared: bool = False
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
@ -208,55 +220,69 @@ class Toolbox:
and sig.parameters[name].kind
not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
}
self._extra_tools = []
original_init(self, *args, **kwargs)
cls.__init__ = wrapped_init
@classmethod
def methods(cls):
gathered = []
for name in dir(cls):
if name.startswith("_"):
def method_tools(cls) -> List[Tool]:
tools = []
for method_name in dir(cls):
if method_name.startswith("_") or method_name in cls._blocked:
continue
if name in cls._blocked:
continue
method = getattr(cls, name)
method = getattr(cls, method_name)
if callable(method):
gathered.append(method)
return gathered
def method_tools(self):
"Returns a list of llm.Tool() for each method"
for method_name in dir(self):
if method_name.startswith("_") or method_name in self._blocked:
continue
method = getattr(self, method_name)
# The attribute must be a bound method, i.e. inspect.ismethod()
if callable(method) and inspect.ismethod(method):
tool = Tool.function(
method,
name="{}_{}".format(self.__class__.__name__, method_name),
name="{}_{}".format(cls.__name__, method_name),
)
tools.append(tool)
return tools
def tools(self) -> Iterable[Tool]:
    """
    Yield an llm.Tool for each public callable attribute of this instance,
    followed by any extra tools stored in ``self._extra_tools``.

    Attributes whose names start with an underscore or appear in
    ``self._blocked`` are ignored.
    """
    class_name = self.__class__.__name__
    # Attributes are looked up on the instance (not the class) so that the
    # wrapped callables are bound methods.
    for attr_name in dir(self):
        is_public = not attr_name.startswith("_")
        if is_public and attr_name not in self._blocked:
            candidate = getattr(self, attr_name)
            if callable(candidate):
                wrapped = Tool.function(
                    candidate, name="{}_{}".format(class_name, attr_name)
                )
                wrapped.plugin = getattr(self, "plugin", None)
                yield wrapped
    for extra in self._extra_tools:
        yield extra
@classmethod
def introspect_methods(cls):
methods = []
for method in cls.methods():
arguments = _get_arguments_input_schema(method, method.__name__)
methods.append(
{
"name": method.__name__,
"description": (
method.__doc__.strip() if method.__doc__ is not None else None
),
"arguments": _ensure_dict_schema(arguments),
"implementation": method,
}
)
return methods
def add_tool(
    self, tool_or_function: Union[Tool, Callable[..., Any]], pass_self: bool = False
):
    """
    Add a tool to this toolbox.

    tool_or_function: an existing Tool instance, which is stored as-is, or a
        callable, which is converted with Tool.function() (the tool name is
        taken from the callable's __name__).
    pass_self: when True and a callable is given, bind the callable to this
        toolbox instance so it receives the toolbox as its first argument.
        Ignored when a Tool instance is passed.

    Raises ValueError if tool_or_function is neither a Tool nor callable.
    """
    if isinstance(tool_or_function, Tool):
        tool = tool_or_function
    elif callable(tool_or_function):
        implementation = (
            MethodType(tool_or_function, self) if pass_self else tool_or_function
        )
        tool = Tool.function(implementation)
    else:
        raise ValueError("Tool must be an instance of Tool or a callable function")

    # _extra_tools has a class-level list as its default. If this instance
    # never received its own list, appending would mutate that shared list
    # and leak the tool into every other toolbox. Give the instance a
    # private copy first, keeping anything already visible through the class.
    if "_extra_tools" not in vars(self):
        self._extra_tools = list(self._extra_tools)
    self._extra_tools.append(tool)
def prepare(self):
    """
    Hook for synchronous setup, such as calling .add_tool(), that should
    happen before the toolbox is used. Override it in a subclass; the
    default implementation does nothing.

    Implement prepare_async() as well if the toolbox needs async setup.
    """
    return None
async def prepare_async(self):
    """
    Hook for asynchronous setup, such as calling .add_tool(), that should
    happen before the toolbox is used. Override it in a subclass; the
    default implementation does nothing.
    """
    return None
@dataclass
@ -358,7 +384,7 @@ def _wrap_tools(tools: List[ToolDef]) -> List[Tool]:
if isinstance(tool, Tool):
wrapped_tools.append(tool)
elif isinstance(tool, Toolbox):
wrapped_tools.extend(tool.method_tools())
wrapped_tools.extend(tool.tools())
elif callable(tool):
wrapped_tools.append(Tool.function(tool))
else:
@ -1008,8 +1034,20 @@ class Response(_BaseResponse):
) -> List[ToolResult]:
tool_results = []
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
# Run prepare() on all Toolbox instances that need it
instances_to_prepare: list[Toolbox] = []
for tool_to_prep in tools_by_name.values():
inst = _get_instance(tool_to_prep.implementation)
if isinstance(inst, Toolbox) and not getattr(inst, "_prepared", False):
instances_to_prepare.append(inst)
for inst in instances_to_prepare:
inst.prepare()
inst._prepared = True
for tool_call in self.tool_calls():
tool = tools_by_name.get(tool_call.name)
tool: Optional[Tool] = tools_by_name.get(tool_call.name)
# Tool could be None if the tool was not found in the prompt tools,
# but we still call the before_call method:
if before_call:
@ -1195,11 +1233,24 @@ class AsyncResponse(_BaseResponse):
tool_calls_list = await self.tool_calls()
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
# Run async prepare_async() on all Toolbox instances that need it
instances_to_prepare: list[Toolbox] = []
for tool_to_prep in tools_by_name.values():
inst = _get_instance(tool_to_prep.implementation)
if isinstance(inst, Toolbox) and not getattr(
inst, "_async_prepared", False
):
instances_to_prepare.append(inst)
for inst in instances_to_prepare:
await inst.prepare_async()
inst._async_prepared = True
indexed_results: List[tuple[int, ToolResult]] = []
async_tasks: List[asyncio.Task] = []
for idx, tc in enumerate(tool_calls_list):
tool = tools_by_name.get(tc.name)
tool: Optional[Tool] = tools_by_name.get(tc.name)
exception: Optional[Exception] = None
if tool is None:

View file

@ -641,7 +641,7 @@ def test_register_toolbox(tmpdir, logs_db):
"name": "Filesystem",
"tools": [
{
"name": "list_files",
"name": "Filesystem_list_files",
"description": None,
"arguments": {"properties": {}, "type": "object"},
}
@ -651,7 +651,7 @@ def test_register_toolbox(tmpdir, logs_db):
"name": "Memory",
"tools": [
{
"name": "append",
"name": "Memory_append",
"description": "Append something as a key",
"arguments": {
"properties": {
@ -663,7 +663,7 @@ def test_register_toolbox(tmpdir, logs_db):
},
},
{
"name": "get",
"name": "Memory_get",
"description": "Get something from a key",
"arguments": {
"properties": {"key": {"type": "string"}},
@ -672,12 +672,12 @@ def test_register_toolbox(tmpdir, logs_db):
},
},
{
"name": "keys",
"name": "Memory_keys",
"description": "Return a list of keys",
"arguments": {"properties": {}, "type": "object"},
},
{
"name": "set",
"name": "Memory_set",
"description": "Set something as a key",
"arguments": {
"properties": {
@ -702,15 +702,15 @@ def test_register_toolbox(tmpdir, logs_db):
"llm_version() -> str (plugin: llm.default_plugins.default_tools)\n\n"
" Return the installed version of llm\n\n"
"Filesystem:\n\n"
" list_files()\n\n"
" Filesystem_list_files()\n\n"
"Memory:\n\n"
" append(key: str, value: str)\n\n"
" Memory_append(key: str, value: str)\n\n"
" Append something as a key\n\n"
" get(key: str)\n\n"
" Memory_get(key: str)\n\n"
" Get something from a key\n\n"
" keys()\n\n"
" Memory_keys()\n\n"
" Return a list of keys\n\n"
" set(key: str, value: str)\n\n"
" Memory_set(key: str, value: str)\n\n"
" Set something as a key\n\n"
)

View file

@ -180,16 +180,81 @@ async def test_async_tools_run_tools_in_parallel():
@pytest.mark.asyncio
async def test_async_toolbox():
class Tools(llm.Toolbox):
def __init__(self):
self.prepared = False
async def go(self):
await asyncio.sleep(0)
return "This was async"
async def prepare_async(self):
await asyncio.sleep(0)
self.prepared = True
instance = Tools()
assert instance.prepared is False
model = llm.get_async_model("echo")
chain_response = model.chain(
json.dumps({"tool_calls": [{"name": "Tools_go"}]}),
tools=[Tools()],
tools=[instance],
)
output = await chain_response.text()
assert '"output": "This was async"' in output
assert instance.prepared is True
def test_toolbox_add_tool():
    """
    A plain function registered with add_tool() can be called by name through
    a chain, and the toolbox's prepare() has run by the time the chain
    finishes.
    """
    # NOTE(review): presumably "echo" is a test model that turns the JSON
    # prompt into tool calls - confirm against the test fixtures.
    model = llm.get_model("echo")

    class Tools(llm.Toolbox):
        def __init__(self):
            self.prepared = False

        def original(self):
            return "Original method"

        def prepare(self):
            self.prepared = True

    # Local function (takes no self); it is added without pass_self
    def new_method():
        return "New method"

    tools = Tools()
    tools.add_tool(new_method)
    # Adding a tool alone must not trigger prepare()
    assert not tools.prepared
    chain_response = model.chain(
        json.dumps({"tool_calls": [{"name": "new_method"}]}),
        tools=[tools],
    )
    output = chain_response.text()
    assert '"output": "New method"' in output
    assert tools.prepared
def test_toolbox_add_tool_with_pass_self():
    """
    A function registered with add_tool(..., pass_self=True) receives the
    toolbox instance as ``self`` and can read its attributes.
    """
    # NOTE(review): presumably "echo" is a test model that turns the JSON
    # prompt into tool calls - confirm against the test fixtures.
    model = llm.get_model("echo")

    class Tools(llm.Toolbox):
        def __init__(self, hotdog):
            self.hotdog = hotdog

        def original(self):
            return "Original method"

    # Local function, not a method of Tools; `self` is supplied by pass_self=True
    def new_method(self):
        return self.hotdog

    tools = Tools("doghot")
    tools.add_tool(new_method, pass_self=True)
    chain_response = model.chain(
        json.dumps({"tool_calls": [{"name": "new_method"}]}),
        tools=[tools],
    )
    output = chain_response.text()
    # The output carries the instance attribute, so the function saw the toolbox
    assert '"output": "doghot"' in output
@pytest.mark.vcr