Toolbox class for class-based tool collections (#1086)

* Toolbox class for class-based tool collections

Refs #1059, #1058, #1057
This commit is contained in:
Simon Willison 2025-05-25 22:42:52 -07:00 committed by GitHub
parent abc4f473f4
commit bb336d33a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 603 additions and 104 deletions

View file

@ -87,6 +87,100 @@ def register_tools(register):
register(count_char, name="count_character_in_word")
```
Functions are useful for simple tools, but some tools may have more advanced needs. You can also define tools as a class (known as a "toolbox"), which provides the following advantages:
- A toolbox can bundle multiple related tools together
- A toolbox can be configured, e.g. to give filesystem tools access to a specific directory
- A toolbox instance can persist shared state between tool invocations
Toolboxes are classes that extend `llm.Toolbox`. Every method whose name does not begin with an underscore is exposed to the model as a tool, named `ClassName_method` (for example `Memory_set`).
This example sets up key/value memory storage that can be used by the model:
```python
import llm
class Memory(llm.Toolbox):
_memory = None
def _get_memory(self):
if self._memory is None:
self._memory = {}
return self._memory
def set(self, key: str, value: str):
"Set something as a key"
self._get_memory()[key] = value
def get(self, key: str):
"Get something from a key"
return self._get_memory().get(key) or ""
def append(self, key: str, value: str):
"Append something as a key"
memory = self._get_memory()
memory[key] = (memory.get(key) or "") + "\n" + value
def keys(self):
"Return a list of keys"
return list(self._get_memory().keys())
@llm.hookimpl
def register_tools(register):
register(Memory)
```
Once the plugin is installed, this toolbox can be used like so:
```bash
llm chat -T Memory
```
If a tool name starts with a capital letter it is assumed to be a toolbox class, not a regular tool function.
Here's an example session with the Memory tool:
```
Chatting with gpt-4.1-mini
Type 'exit' or 'quit' to exit
Type '!multi' to enter multiple lines, then '!end' to finish
Type '!edit' to open your default editor and modify the prompt
Type '!fragment <my_fragment> [<another_fragment> ...]' to insert one or more fragments
> Remember my name is Henry
Tool call: Memory_set({'key': 'user_name', 'value': 'Henry'})
null
Got it, Henry! I'll remember your name. How can I assist you today?
> what keys are there?
Tool call: Memory_keys({})
[
"user_name"
]
Currently, there is one key stored: "user_name". Would you like to add or retrieve any information?
> read it
Tool call: Memory_get({'key': 'user_name'})
Henry
The value stored under the key "user_name" is Henry. Is there anything else you'd like to do?
> add Barrett to it
Tool call: Memory_append({'key': 'user_name', 'value': 'Barrett'})
null
I have added "Barrett" to the key "user_name". If you want, I can now show you the updated value.
> show value
Tool call: Memory_get({'key': 'user_name'})
Henry
Barrett
The value stored under the key "user_name" is now:
Henry
Barrett
Is there anything else you would like to do?
```
(plugin-hooks-register-template-loaders)=
## register_template_loaders(register)

View file

@ -20,6 +20,7 @@ from .models import (
Prompt,
Response,
Tool,
Toolbox,
ToolCall,
)
from .utils import schema_dsl, Fragment
@ -27,7 +28,8 @@ from .embeddings import Collection
from .templates import Template
from .plugins import pm, load_plugins
import click
from typing import Any, Dict, List, Optional, Callable, Union, cast
from typing import Any, Dict, List, Optional, Callable, Type, Union
import inspect
import json
import os
import pathlib
@ -55,6 +57,7 @@ __all__ = [
"Response",
"Template",
"Tool",
"Toolbox",
"ToolCall",
"user_dir",
"schema_dsl",
@ -134,36 +137,58 @@ def get_fragment_loaders() -> Dict[
return _get_loaders(pm.hook.register_fragment_loaders)
def get_tools() -> Dict[str, Tool]:
"""Get tools registered by plugins."""
def get_tools() -> Dict[str, Union[Tool, Type[Toolbox]]]:
"""Return all tools (llm.Tool and llm.Toolbox) registered by plugins."""
load_plugins()
tools: Dict[str, Tool] = {}
tools: Dict[str, Union[Tool, Type[Toolbox]]] = {}
# Variable to track current plugin name
current_plugin_name = None
def register(
tool_or_function: Union[Tool, Callable[..., Any]],
tool_or_function: Union[Tool, Type[Toolbox], Callable[..., Any]],
name: Optional[str] = None,
) -> None:
# If they handed us a bare function, wrap it in a Tool
if not isinstance(tool_or_function, Tool):
tool_or_function = Tool.function(tool_or_function, name=name)
tool: Union[Tool, Type[Toolbox], None] = None
tool = cast(Tool, tool_or_function)
if current_plugin_name:
tool.plugin = current_plugin_name
# If it's a Toolbox class, set the plugin field on it
if inspect.isclass(tool_or_function) and issubclass(tool_or_function, Toolbox):
tool = tool_or_function
if current_plugin_name:
tool.plugin = current_plugin_name
tool.name = name or tool.__name__
prefix = name or tool.name
suffix = 0
candidate = prefix
# If it's already a Tool instance, use it directly
elif isinstance(tool_or_function, Tool):
tool = tool_or_function
if name:
tool.name = name
if current_plugin_name:
tool.plugin = current_plugin_name
# avoid name collisions
while candidate in tools:
suffix += 1
candidate = f"{prefix}_{suffix}"
# If it's a bare function, wrap it in a Tool
else:
tool = Tool.function(tool_or_function, name=name)
if current_plugin_name:
tool.plugin = current_plugin_name
tools[candidate] = tool
# Get the name for the tool/toolbox
if tool:
# For Toolbox classes, use their name attribute or class name
if inspect.isclass(tool) and issubclass(tool, Toolbox):
prefix = name or getattr(tool, "name", tool.__name__) or ""
else:
prefix = name or tool.name or ""
suffix = 0
candidate = prefix
# Avoid name collisions
while candidate in tools:
suffix += 1
candidate = f"{prefix}_{suffix}"
tools[candidate] = tool
# Call each plugin's register_tools hook individually to track current_plugin_name
for plugin in pm.get_plugins():

View file

@ -17,6 +17,7 @@ from llm import (
Response,
Template,
Tool,
Toolbox,
UnknownModelError,
KeyModel,
encode,
@ -48,6 +49,7 @@ from .utils import (
extract_fenced_code_block,
find_unused_key,
has_plugin_prefix,
instantiate_from_spec,
make_schema_id,
maybe_fenced_code,
mimetype_from_path,
@ -73,7 +75,7 @@ import sqlite_utils
from sqlite_utils.utils import rows_from_file, Format
import sys
import textwrap
from typing import cast, Dict, Optional, Iterable, List, Union, Tuple, Any
from typing import cast, Dict, Optional, Iterable, List, Union, Tuple, Type, Any
import warnings
import yaml
@ -802,16 +804,16 @@ def prompt(
if conversation:
prompt_method = conversation.prompt
tool_functions = _gather_tools(tools, python_tools)
tool_implementations = _gather_tools(tools, python_tools)
if tool_functions:
if tool_implementations:
prompt_method = conversation.chain
kwargs["chain_limit"] = chain_limit
if tools_debug:
kwargs["after_call"] = _debug_tool_call
if tools_approve:
kwargs["before_call"] = _approve_tool_call
kwargs["tools"] = tool_functions
kwargs["tools"] = tool_implementations
try:
if async_:
@ -2488,39 +2490,81 @@ def tools():
)
def tools_list(json_, python_tools):
"List available tools that have been provided by plugins"
tools = get_tools()
tools: Dict[str, Union[Tool, Type[Toolbox]]] = get_tools()
if python_tools:
for code_or_path in python_tools:
for tool in _tools_from_code(code_or_path):
tools[tool.name] = tool
output_tools = []
output_toolboxes = []
tool_objects = []
toolbox_objects = []
for name, tool in tools.items():
if isinstance(tool, Tool):
tool_objects.append(tool)
output_tools.append(
{
"name": name,
"description": tool.description,
"arguments": tool.input_schema,
"plugin": tool.plugin,
}
)
else:
toolbox_objects.append(tool)
output_toolboxes.append(
{
"name": name,
"tools": [
{
"name": method["name"],
"description": method["description"],
"arguments": method["arguments"],
}
for method in tool.introspect_methods()
],
}
)
if json_:
click.echo(
json.dumps(
{
name: {
"description": tool.description,
"arguments": tool.input_schema,
"plugin": tool.plugin,
}
for name, tool in tools.items()
},
{"tools": output_tools, "toolboxes": output_toolboxes},
indent=2,
)
)
else:
for name, tool in tools.items():
for tool in tool_objects:
sig = "()"
if tool.implementation:
sig = str(inspect.signature(tool.implementation))
click.echo(
"{}{}{}".format(
name,
"{}{}{}\n".format(
tool.name,
sig,
" (plugin: {})".format(tool.plugin) if tool.plugin else "",
)
)
if tool.description:
click.echo(textwrap.indent(tool.description, " "))
click.echo(textwrap.indent(tool.description.strip(), " ") + "\n")
for toolbox in toolbox_objects:
click.echo(toolbox.name + ":\n")
for method in toolbox.introspect_methods():
sig = (
str(inspect.signature(method["implementation"]))
.replace("(self, ", "(")
.replace("(self)", "()")
)
click.echo(
" {}{}\n".format(
method["name"],
sig,
)
)
if method["description"]:
click.echo(
textwrap.indent(method["description"].strip(), " ") + "\n"
)
@cli.group(
@ -3844,21 +3888,36 @@ def _approve_tool_call(_, tool_call):
raise CancelToolCall("User cancelled tool call")
def _gather_tools(tools, python_tools):
tool_functions = []
def _gather_tools(
tool_specs: List[str], python_tools: List[str]
) -> List[Union[Tool, Type[Toolbox]]]:
tools: List[Union[Tool, Type[Toolbox]]] = []
if python_tools:
for code_or_path in python_tools:
tool_functions = _tools_from_code(code_or_path)
tools.extend(_tools_from_code(code_or_path))
registered_tools = get_tools()
bad_tools = [tool for tool in tools if tool not in registered_tools]
registered_classes = dict(
(key, value)
for key, value in registered_tools.items()
if inspect.isclass(value)
)
bad_tools = [
tool for tool in tool_specs if tool.split("(")[0] not in registered_tools
]
if bad_tools:
raise click.ClickException(
"Tool(s) {} not found. Available tools: {}".format(
", ".join(bad_tools), ", ".join(registered_tools.keys())
)
)
tool_functions.extend(registered_tools[tool] for tool in tools)
return tool_functions
for tool_spec in tool_specs:
if not tool_spec[0].isupper():
# It's a function
tools.append(registered_tools[tool_spec])
else:
# It's a class
tools.append(instantiate_from_spec(registered_classes, tool_spec))
return tools
def _get_conversation_tools(conversation, tools):

View file

@ -118,18 +118,7 @@ class Tool:
def __post_init__(self):
# Convert Pydantic model to JSON schema if needed
self.input_schema = self._ensure_dict_schema(self.input_schema)
def _ensure_dict_schema(self, schema):
"""Convert a Pydantic model to a JSON schema dict if needed."""
if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel):
schema_dict = schema.model_json_schema()
# Strip annoying "title" fields which are just the "name" in title case
schema_dict.pop("title", None)
for value in schema_dict.get("properties", {}).values():
value.pop("title", None)
return schema_dict
return schema
self.input_schema = _ensure_dict_schema(self.input_schema)
def hash(self):
"""Hash for tool based on its name, description and input schema (preserving key order)"""
@ -151,36 +140,87 @@ class Tool:
- Building a Pydantic model for inputs by inspecting the function signature
- Building a Pydantic model for the return value by using the function's return annotation
"""
signature = inspect.signature(function)
type_hints = get_type_hints(function)
if not name and function.__name__ == "<lambda>":
raise ValueError(
"Cannot create a Tool from a lambda function without providing name="
)
fields = {}
for param_name, param in signature.parameters.items():
# Determine the type annotation (default to string if missing)
annotated_type = type_hints.get(param_name, str)
# Handle default value if present; if there's no default, use '...'
if param.default is inspect.Parameter.empty:
fields[param_name] = (annotated_type, ...)
else:
fields[param_name] = (annotated_type, param.default)
input_schema = create_model(f"{function.__name__}InputSchema", **fields)
return cls(
name=name or function.__name__,
description=function.__doc__ or None,
input_schema=input_schema,
input_schema=_get_arguments_input_schema(function, name),
implementation=function,
)
ToolDef = Union[Tool, Callable[..., Any]]
def _get_arguments_input_schema(function, name):
    """Build a Pydantic model describing the arguments ``function`` accepts.

    Each parameter in the function's signature becomes a model field:

    - a parameter literally named ``self`` is skipped, so unbound methods
      looked up on a class can be introspected directly
    - parameters without a type hint are treated as ``str``
    - parameters without a default are required (``...``); otherwise the
      declared default is used as the field default

    ``name`` is used as the prefix of the generated model's name. Callers may
    pass ``None`` (``Tool.function`` does so when no explicit name was given),
    in which case the function's own ``__name__`` is used instead, so the
    model is never called ``NoneInputSchema``.

    Returns the model class produced by ``create_model``.
    """
    signature = inspect.signature(function)
    type_hints = get_type_hints(function)
    fields = {}
    for param_name, param in signature.parameters.items():
        if param_name == "self":
            continue
        # Determine the type annotation (default to string if missing)
        annotated_type = type_hints.get(param_name, str)
        # Handle default value if present; if there's no default, use '...'
        if param.default is inspect.Parameter.empty:
            fields[param_name] = (annotated_type, ...)
        else:
            fields[param_name] = (annotated_type, param.default)
    # Fall back to the function's own name when no explicit name was supplied
    model_name = name or getattr(function, "__name__", "")
    return create_model(f"{model_name}InputSchema", **fields)
class Toolbox:
    """Base class for a collection of tools implemented as methods.

    Every attribute whose name does not start with an underscore, and which
    is not one of the helper names listed in ``_blocked``, is considered a
    candidate tool.
    """

    # Helper methods defined on this base class that must never be exposed
    _blocked = ("method_tools", "introspect_methods", "methods")
    name: Optional[str] = None

    @classmethod
    def methods(cls):
        """Return the public callables defined on the class, in ``dir()`` order."""
        candidates = (
            getattr(cls, attribute)
            for attribute in dir(cls)
            if not (attribute.startswith("_") or attribute in cls._blocked)
        )
        return [member for member in candidates if callable(member)]

    def method_tools(self):
        """Yield a ``Tool`` for each public bound method of this instance.

        Each tool is named ``<ClassName>_<method name>``.
        """
        class_name = self.__class__.__name__
        for method_name in dir(self):
            if method_name.startswith("_"):
                continue
            if method_name in self._blocked:
                continue
            candidate = getattr(self, method_name)
            # Only bound methods qualify, i.e. inspect.ismethod() must hold
            if not callable(candidate) or not inspect.ismethod(candidate):
                continue
            yield Tool.function(
                candidate,
                name=f"{class_name}_{method_name}",
            )

    @classmethod
    def introspect_methods(cls):
        """Describe each public method as a dict.

        Each dict has the keys ``name``, ``description`` (the stripped
        docstring, or ``None`` when there is none), ``arguments`` (a JSON
        schema dict) and ``implementation`` (the method itself).
        """
        described = []
        for method in cls.methods():
            docstring = method.__doc__
            schema_model = _get_arguments_input_schema(method, method.__name__)
            described.append(
                {
                    "name": method.__name__,
                    "description": None if docstring is None else docstring.strip(),
                    "arguments": _ensure_dict_schema(schema_model),
                    "implementation": method,
                }
            )
        return described
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
@dataclass
@ -263,6 +303,8 @@ def _wrap_tools(tools: List[ToolDef]) -> List[Tool]:
for tool in tools:
if isinstance(tool, Tool):
wrapped_tools.append(tool)
elif isinstance(tool, Toolbox):
wrapped_tools.extend(tool.method_tools())
elif callable(tool):
wrapped_tools.append(Tool.function(tool))
else:
@ -1788,3 +1830,27 @@ def _conversation_name(text):
if len(text) <= CONVERSATION_NAME_LENGTH:
return text
return text[: CONVERSATION_NAME_LENGTH - 1] + ""
def _ensure_dict_schema(schema):
"""Convert a Pydantic model to a JSON schema dict if needed."""
if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel):
schema_dict = schema.model_json_schema()
_remove_titles_recursively(schema_dict)
return schema_dict
return schema
def _remove_titles_recursively(obj):
"""Recursively remove all 'title' fields from a nested dictionary."""
if isinstance(obj, dict):
# Remove title if present
obj.pop("title", None)
# Recursively process all values
for value in obj.values():
_remove_titles_recursively(value)
elif isinstance(obj, list):
# Process each item in lists
for item in obj:
_remove_titles_recursively(item)

View file

@ -4,6 +4,7 @@ import importlib
import json
import llm
from llm import cli, hookimpl, plugins, get_template_loaders, get_fragment_loaders
import pathlib
import textwrap
@ -263,46 +264,52 @@ def test_register_tools(tmpdir, logs_db):
result = runner.invoke(cli.cli, ["tools", "list"])
assert result.exit_code == 0
assert result.output == (
"upper(text: str) -> str (plugin: ToolsPlugin)\n"
" Convert text to uppercase.\n"
"count_chars(text: str, character: str) -> int (plugin: ToolsPlugin)\n"
" Count the number of occurrences of a character in a word.\n"
"output_as_json(text: str) (plugin: ToolsPlugin)\n"
"upper(text: str) -> str (plugin: ToolsPlugin)\n\n"
" Convert text to uppercase.\n\n"
"count_chars(text: str, character: str) -> int (plugin: ToolsPlugin)\n\n"
" Count the number of occurrences of a character in a word.\n\n"
"output_as_json(text: str) (plugin: ToolsPlugin)\n\n"
)
# And --json
result2 = runner.invoke(cli.cli, ["tools", "list", "--json"])
assert result2.exit_code == 0
assert json.loads(result2.output) == {
"upper": {
"description": "Convert text to uppercase.",
"arguments": {
"properties": {"text": {"type": "string"}},
"required": ["text"],
"type": "object",
},
"plugin": "ToolsPlugin",
},
"count_chars": {
"description": "Count the number of occurrences of a character in a word.",
"arguments": {
"properties": {
"text": {"type": "string"},
"character": {"type": "string"},
"tools": [
{
"name": "upper",
"description": "Convert text to uppercase.",
"arguments": {
"properties": {"text": {"type": "string"}},
"required": ["text"],
"type": "object",
},
"required": ["text", "character"],
"type": "object",
"plugin": "ToolsPlugin",
},
"plugin": "ToolsPlugin",
},
"output_as_json": {
"description": None,
"arguments": {
"properties": {"text": {"type": "string"}},
"required": ["text"],
"type": "object",
{
"name": "count_chars",
"description": "Count the number of occurrences of a character in a word.",
"arguments": {
"properties": {
"text": {"type": "string"},
"character": {"type": "string"},
},
"required": ["text", "character"],
"type": "object",
},
"plugin": "ToolsPlugin",
},
"plugin": "ToolsPlugin",
},
{
"name": "output_as_json",
"description": None,
"arguments": {
"properties": {"text": {"type": "string"}},
"required": ["text"],
"type": "object",
},
"plugin": "ToolsPlugin",
},
],
"toolboxes": [],
}
# And test the --tools option
functions_path = str(tmpdir / "functions.py")
@ -333,6 +340,7 @@ def test_register_tools(tmpdir, logs_db):
{"tool_calls": [{"name": "upper", "arguments": {"text": "hi"}}]}
),
],
catch_exceptions=False,
)
assert result4.exit_code == 0
assert '"output": "HI"' in result4.output
@ -467,6 +475,253 @@ def test_register_tools(tmpdir, logs_db):
assert llm.get_tools() == {}
def test_register_toolbox(tmpdir, logs_db):
class Memory(llm.Toolbox):
_memory = None
def _get_memory(self):
if self._memory is None:
self._memory = {}
return self._memory
def set(self, key: str, value: str):
"Set something as a key"
self._get_memory()[key] = value
def get(self, key: str):
"Get something from a key"
return self._get_memory().get(key) or ""
def append(self, key: str, value: str):
"Append something as a key"
memory = self._get_memory()
memory[key] = (memory.get(key) or "") + "\n" + value
def keys(self):
"Return a list of keys"
return list(self._get_memory().keys())
class Filesystem(llm.Toolbox):
def __init__(self, path: str):
self.path = path
def list_files(self):
return [str(item) for item in pathlib.Path(self.path).glob("*")]
# Test the Python API
model = llm.get_model("echo")
memory = Memory()
conversation = model.conversation(tools=[memory])
accumulated = []
def after_call(tool, tool_call, tool_result):
accumulated.append((tool.name, tool_call.arguments, tool_result.output))
conversation.chain(
json.dumps(
{
"tool_calls": [
{
"name": "Memory_set",
"arguments": {"key": "hello", "value": "world"},
}
]
}
),
after_call=after_call,
).text()
conversation.chain(
json.dumps(
{"tool_calls": [{"name": "Memory_get", "arguments": {"key": "hello"}}]}
),
after_call=after_call,
).text()
assert accumulated == [
("Memory_set", {"key": "hello", "value": "world"}, "null"),
("Memory_get", {"key": "hello"}, "world"),
]
assert memory._memory == {"hello": "world"}
# And for the Filesystem with state
my_dir = pathlib.Path(tmpdir / "mine")
my_dir.mkdir()
(my_dir / "doc.txt").write_text("hi", "utf-8")
conversation = model.conversation(tools=[Filesystem(my_dir)])
accumulated.clear()
conversation.chain(
json.dumps(
{
"tool_calls": [
{
"name": "Filesystem_list_files",
}
]
}
),
after_call=after_call,
).text()
assert accumulated == [
("Filesystem_list_files", {}, json.dumps([str(my_dir / "doc.txt")]))
]
# Now register them with a plugin and use it through the CLI
class ToolboxPlugin:
__name__ = "ToolboxPlugin"
@hookimpl
def register_tools(self, register):
register(Memory)
register(Filesystem)
try:
plugins.pm.register(ToolboxPlugin(), name="ToolboxPlugin")
tools = llm.get_tools()
assert tools["Memory"] is Memory
runner = CliRunner()
# llm tools --json
result = runner.invoke(cli.cli, ["tools", "--json"])
assert result.exit_code == 0
assert json.loads(result.output) == {
"tools": [],
"toolboxes": [
{
"name": "Memory",
"tools": [
{
"name": "append",
"description": "Append something as a key",
"arguments": {
"properties": {
"key": {"type": "string"},
"value": {"type": "string"},
},
"required": ["key", "value"],
"type": "object",
},
},
{
"name": "get",
"description": "Get something from a key",
"arguments": {
"properties": {"key": {"type": "string"}},
"required": ["key"],
"type": "object",
},
},
{
"name": "keys",
"description": "Return a list of keys",
"arguments": {"properties": {}, "type": "object"},
},
{
"name": "set",
"description": "Set something as a key",
"arguments": {
"properties": {
"key": {"type": "string"},
"value": {"type": "string"},
},
"required": ["key", "value"],
"type": "object",
},
},
],
},
{
"name": "Filesystem",
"tools": [
{
"name": "list_files",
"description": None,
"arguments": {"properties": {}, "type": "object"},
}
],
},
],
}
# llm tools (no JSON)
result = runner.invoke(cli.cli, ["tools"])
assert result.exit_code == 0
assert result.output == (
"Memory:\n\n"
" append(key: str, value: str)\n\n"
" Append something as a key\n\n"
" get(key: str)\n\n"
" Get something from a key\n\n"
" keys()\n\n"
" Return a list of keys\n\n"
" set(key: str, value: str)\n\n"
" Set something as a key\n\n"
"Filesystem:\n\n"
" list_files()\n\n"
)
# Test the CLI running a toolbox prompt
result3 = runner.invoke(
cli.cli,
[
"prompt",
"-T",
"Memory",
json.dumps(
{
"tool_calls": [
{
"name": "Memory_set",
"arguments": {"key": "hi", "value": "two"},
},
{"name": "Memory_get", "arguments": {"key": "hi"}},
]
}
),
"-m",
"echo",
],
)
assert result3.exit_code == 0
tool_results = json.loads(
"[" + result3.output.split('"tool_results": [')[1].split("]")[0] + "]"
)
assert tool_results == [
{"name": "Memory_set", "output": "null", "tool_call_id": None},
{"name": "Memory_get", "output": "two", "tool_call_id": None},
]
# Test the CLI running a configured toolbox prompt
my_dir2 = pathlib.Path(tmpdir / "mine2")
my_dir2.mkdir()
other_path = my_dir2 / "other.txt"
other_path.write_text("hi", "utf-8")
result4 = runner.invoke(
cli.cli,
[
"prompt",
"-T",
"Filesystem({})".format(json.dumps(str(my_dir2))),
json.dumps({"tool_calls": [{"name": "Filesystem_list_files"}]}),
"-m",
"echo",
],
)
assert result4.exit_code == 0
tool_results = json.loads(
"[" + result4.output.split('"tool_results": [')[1].rsplit("]", 1)[0] + "]"
)
assert tool_results == [
{
"name": "Filesystem_list_files",
"output": json.dumps([str(other_path)]),
"tool_call_id": None,
}
]
finally:
plugins.pm.unregister(name="ToolboxPlugin")
def test_plugins_command():
runner = CliRunner()
result = runner.invoke(cli.cli, ["plugins"])