mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
Toolbox class for class-based tool collections (#1086)
* Toolbox class for class-based tool collections Refs #1059, #1058, #1057
This commit is contained in:
parent
abc4f473f4
commit
bb336d33a0
5 changed files with 603 additions and 104 deletions
|
|
@ -87,6 +87,100 @@ def register_tools(register):
|
|||
register(count_char, name="count_character_in_word")
|
||||
```
|
||||
|
||||
Functions are useful for simple tools, but some tools may have more advanced needs. You can also define tools as a class (known as a "toolbox"), which provides the following advantages:
|
||||
|
||||
- Toolbox tools can bundle multiple tools together
|
||||
- Toolbox tools can be configured, e.g. to give filesystem tools access to a specific directory
|
||||
- Toolbox instances can persist shared state in between tool invocations
|
||||
|
||||
Toolboxes are classes that extend `llm.Toolbox`. Any methods that do not begin with an underscore will be exposed as tool functions.
|
||||
|
||||
This example sets up key/value memory storage that can be used by the model:
|
||||
```python
|
||||
import llm
|
||||
|
||||
class Memory(llm.Toolbox):
|
||||
_memory = None
|
||||
|
||||
def _get_memory(self):
|
||||
if self._memory is None:
|
||||
self._memory = {}
|
||||
return self._memory
|
||||
|
||||
def set(self, key: str, value: str):
|
||||
"Set something as a key"
|
||||
self._get_memory()[key] = value
|
||||
|
||||
def get(self, key: str):
|
||||
"Get something from a key"
|
||||
return self._get_memory().get(key) or ""
|
||||
|
||||
def append(self, key: str, value: str):
|
||||
"Append something as a key"
|
||||
memory = self._get_memory()
|
||||
memory[key] = (memory.get(key) or "") + "\n" + value
|
||||
|
||||
def keys(self):
|
||||
"Return a list of keys"
|
||||
return list(self._get_memory().keys())
|
||||
|
||||
@llm.hookimpl
|
||||
def register_tools(register):
|
||||
register(Memory)
|
||||
```
|
||||
Once installed, this tool can be used like so:
|
||||
|
||||
```bash
|
||||
llm chat -T Memory
|
||||
```
|
||||
If a tool name starts with a capital letter it is assumed to be a toolbox class, not a regular tool function.
|
||||
|
||||
Here's an example session with the Memory tool:
|
||||
```
|
||||
Chatting with gpt-4.1-mini
|
||||
Type 'exit' or 'quit' to exit
|
||||
Type '!multi' to enter multiple lines, then '!end' to finish
|
||||
Type '!edit' to open your default editor and modify the prompt
|
||||
Type '!fragment <my_fragment> [<another_fragment> ...]' to insert one or more fragments
|
||||
> Remember my name is Henry
|
||||
|
||||
Tool call: Memory_set({'key': 'user_name', 'value': 'Henry'})
|
||||
null
|
||||
|
||||
Got it, Henry! I'll remember your name. How can I assist you today?
|
||||
> what keys are there?
|
||||
|
||||
Tool call: Memory_keys({})
|
||||
[
|
||||
"user_name"
|
||||
]
|
||||
|
||||
Currently, there is one key stored: "user_name". Would you like to add or retrieve any information?
|
||||
> read it
|
||||
|
||||
Tool call: Memory_get({'key': 'user_name'})
|
||||
Henry
|
||||
|
||||
The value stored under the key "user_name" is Henry. Is there anything else you'd like to do?
|
||||
> add Barrett to it
|
||||
|
||||
Tool call: Memory_append({'key': 'user_name', 'value': 'Barrett'})
|
||||
null
|
||||
|
||||
I have added "Barrett" to the key "user_name". If you want, I can now show you the updated value.
|
||||
> show value
|
||||
|
||||
Tool call: Memory_get({'key': 'user_name'})
|
||||
Henry
|
||||
Barrett
|
||||
|
||||
The value stored under the key "user_name" is now:
|
||||
Henry
|
||||
Barrett
|
||||
|
||||
Is there anything else you would like to do?
|
||||
```
|
||||
|
||||
(plugin-hooks-register-template-loaders)=
|
||||
## register_template_loaders(register)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from .models import (
|
|||
Prompt,
|
||||
Response,
|
||||
Tool,
|
||||
Toolbox,
|
||||
ToolCall,
|
||||
)
|
||||
from .utils import schema_dsl, Fragment
|
||||
|
|
@ -27,7 +28,8 @@ from .embeddings import Collection
|
|||
from .templates import Template
|
||||
from .plugins import pm, load_plugins
|
||||
import click
|
||||
from typing import Any, Dict, List, Optional, Callable, Union, cast
|
||||
from typing import Any, Dict, List, Optional, Callable, Type, Union
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
|
|
@ -55,6 +57,7 @@ __all__ = [
|
|||
"Response",
|
||||
"Template",
|
||||
"Tool",
|
||||
"Toolbox",
|
||||
"ToolCall",
|
||||
"user_dir",
|
||||
"schema_dsl",
|
||||
|
|
@ -134,36 +137,58 @@ def get_fragment_loaders() -> Dict[
|
|||
return _get_loaders(pm.hook.register_fragment_loaders)
|
||||
|
||||
|
||||
def get_tools() -> Dict[str, Tool]:
|
||||
"""Get tools registered by plugins."""
|
||||
def get_tools() -> Dict[str, Union[Tool, Type[Toolbox]]]:
|
||||
"""Return all tools (llm.Tool and llm.Toolbox) registered by plugins."""
|
||||
load_plugins()
|
||||
tools: Dict[str, Tool] = {}
|
||||
tools: Dict[str, Union[Tool, Type[Toolbox]]] = {}
|
||||
|
||||
# Variable to track current plugin name
|
||||
current_plugin_name = None
|
||||
|
||||
def register(
|
||||
tool_or_function: Union[Tool, Callable[..., Any]],
|
||||
tool_or_function: Union[Tool, Type[Toolbox], Callable[..., Any]],
|
||||
name: Optional[str] = None,
|
||||
) -> None:
|
||||
# If they handed us a bare function, wrap it in a Tool
|
||||
if not isinstance(tool_or_function, Tool):
|
||||
tool_or_function = Tool.function(tool_or_function, name=name)
|
||||
tool: Union[Tool, Type[Toolbox], None] = None
|
||||
|
||||
tool = cast(Tool, tool_or_function)
|
||||
if current_plugin_name:
|
||||
tool.plugin = current_plugin_name
|
||||
# If it's a Toolbox class, set the plugin field on it
|
||||
if inspect.isclass(tool_or_function) and issubclass(tool_or_function, Toolbox):
|
||||
tool = tool_or_function
|
||||
if current_plugin_name:
|
||||
tool.plugin = current_plugin_name
|
||||
tool.name = name or tool.__name__
|
||||
|
||||
prefix = name or tool.name
|
||||
suffix = 0
|
||||
candidate = prefix
|
||||
# If it's already a Tool instance, use it directly
|
||||
elif isinstance(tool_or_function, Tool):
|
||||
tool = tool_or_function
|
||||
if name:
|
||||
tool.name = name
|
||||
if current_plugin_name:
|
||||
tool.plugin = current_plugin_name
|
||||
|
||||
# avoid name collisions
|
||||
while candidate in tools:
|
||||
suffix += 1
|
||||
candidate = f"{prefix}_{suffix}"
|
||||
# If it's a bare function, wrap it in a Tool
|
||||
else:
|
||||
tool = Tool.function(tool_or_function, name=name)
|
||||
if current_plugin_name:
|
||||
tool.plugin = current_plugin_name
|
||||
|
||||
tools[candidate] = tool
|
||||
# Get the name for the tool/toolbox
|
||||
if tool:
|
||||
# For Toolbox classes, use their name attribute or class name
|
||||
if inspect.isclass(tool) and issubclass(tool, Toolbox):
|
||||
prefix = name or getattr(tool, "name", tool.__name__) or ""
|
||||
else:
|
||||
prefix = name or tool.name or ""
|
||||
|
||||
suffix = 0
|
||||
candidate = prefix
|
||||
|
||||
# Avoid name collisions
|
||||
while candidate in tools:
|
||||
suffix += 1
|
||||
candidate = f"{prefix}_{suffix}"
|
||||
|
||||
tools[candidate] = tool
|
||||
|
||||
# Call each plugin's register_tools hook individually to track current_plugin_name
|
||||
for plugin in pm.get_plugins():
|
||||
|
|
|
|||
105
llm/cli.py
105
llm/cli.py
|
|
@ -17,6 +17,7 @@ from llm import (
|
|||
Response,
|
||||
Template,
|
||||
Tool,
|
||||
Toolbox,
|
||||
UnknownModelError,
|
||||
KeyModel,
|
||||
encode,
|
||||
|
|
@ -48,6 +49,7 @@ from .utils import (
|
|||
extract_fenced_code_block,
|
||||
find_unused_key,
|
||||
has_plugin_prefix,
|
||||
instantiate_from_spec,
|
||||
make_schema_id,
|
||||
maybe_fenced_code,
|
||||
mimetype_from_path,
|
||||
|
|
@ -73,7 +75,7 @@ import sqlite_utils
|
|||
from sqlite_utils.utils import rows_from_file, Format
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import cast, Dict, Optional, Iterable, List, Union, Tuple, Any
|
||||
from typing import cast, Dict, Optional, Iterable, List, Union, Tuple, Type, Any
|
||||
import warnings
|
||||
import yaml
|
||||
|
||||
|
|
@ -802,16 +804,16 @@ def prompt(
|
|||
if conversation:
|
||||
prompt_method = conversation.prompt
|
||||
|
||||
tool_functions = _gather_tools(tools, python_tools)
|
||||
tool_implementations = _gather_tools(tools, python_tools)
|
||||
|
||||
if tool_functions:
|
||||
if tool_implementations:
|
||||
prompt_method = conversation.chain
|
||||
kwargs["chain_limit"] = chain_limit
|
||||
if tools_debug:
|
||||
kwargs["after_call"] = _debug_tool_call
|
||||
if tools_approve:
|
||||
kwargs["before_call"] = _approve_tool_call
|
||||
kwargs["tools"] = tool_functions
|
||||
kwargs["tools"] = tool_implementations
|
||||
|
||||
try:
|
||||
if async_:
|
||||
|
|
@ -2488,39 +2490,81 @@ def tools():
|
|||
)
|
||||
def tools_list(json_, python_tools):
|
||||
"List available tools that have been provided by plugins"
|
||||
tools = get_tools()
|
||||
tools: Dict[str, Union[Tool, Type[Toolbox]]] = get_tools()
|
||||
if python_tools:
|
||||
for code_or_path in python_tools:
|
||||
for tool in _tools_from_code(code_or_path):
|
||||
tools[tool.name] = tool
|
||||
|
||||
output_tools = []
|
||||
output_toolboxes = []
|
||||
tool_objects = []
|
||||
toolbox_objects = []
|
||||
for name, tool in tools.items():
|
||||
if isinstance(tool, Tool):
|
||||
tool_objects.append(tool)
|
||||
output_tools.append(
|
||||
{
|
||||
"name": name,
|
||||
"description": tool.description,
|
||||
"arguments": tool.input_schema,
|
||||
"plugin": tool.plugin,
|
||||
}
|
||||
)
|
||||
else:
|
||||
toolbox_objects.append(tool)
|
||||
output_toolboxes.append(
|
||||
{
|
||||
"name": name,
|
||||
"tools": [
|
||||
{
|
||||
"name": method["name"],
|
||||
"description": method["description"],
|
||||
"arguments": method["arguments"],
|
||||
}
|
||||
for method in tool.introspect_methods()
|
||||
],
|
||||
}
|
||||
)
|
||||
if json_:
|
||||
click.echo(
|
||||
json.dumps(
|
||||
{
|
||||
name: {
|
||||
"description": tool.description,
|
||||
"arguments": tool.input_schema,
|
||||
"plugin": tool.plugin,
|
||||
}
|
||||
for name, tool in tools.items()
|
||||
},
|
||||
{"tools": output_tools, "toolboxes": output_toolboxes},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
else:
|
||||
for name, tool in tools.items():
|
||||
for tool in tool_objects:
|
||||
sig = "()"
|
||||
if tool.implementation:
|
||||
sig = str(inspect.signature(tool.implementation))
|
||||
click.echo(
|
||||
"{}{}{}".format(
|
||||
name,
|
||||
"{}{}{}\n".format(
|
||||
tool.name,
|
||||
sig,
|
||||
" (plugin: {})".format(tool.plugin) if tool.plugin else "",
|
||||
)
|
||||
)
|
||||
if tool.description:
|
||||
click.echo(textwrap.indent(tool.description, " "))
|
||||
click.echo(textwrap.indent(tool.description.strip(), " ") + "\n")
|
||||
for toolbox in toolbox_objects:
|
||||
click.echo(toolbox.name + ":\n")
|
||||
for method in toolbox.introspect_methods():
|
||||
sig = (
|
||||
str(inspect.signature(method["implementation"]))
|
||||
.replace("(self, ", "(")
|
||||
.replace("(self)", "()")
|
||||
)
|
||||
click.echo(
|
||||
" {}{}\n".format(
|
||||
method["name"],
|
||||
sig,
|
||||
)
|
||||
)
|
||||
if method["description"]:
|
||||
click.echo(
|
||||
textwrap.indent(method["description"].strip(), " ") + "\n"
|
||||
)
|
||||
|
||||
|
||||
@cli.group(
|
||||
|
|
@ -3844,21 +3888,36 @@ def _approve_tool_call(_, tool_call):
|
|||
raise CancelToolCall("User cancelled tool call")
|
||||
|
||||
|
||||
def _gather_tools(tools, python_tools):
|
||||
tool_functions = []
|
||||
def _gather_tools(
|
||||
tool_specs: List[str], python_tools: List[str]
|
||||
) -> List[Union[Tool, Type[Toolbox]]]:
|
||||
tools: List[Union[Tool, Type[Toolbox]]] = []
|
||||
if python_tools:
|
||||
for code_or_path in python_tools:
|
||||
tool_functions = _tools_from_code(code_or_path)
|
||||
tools.extend(_tools_from_code(code_or_path))
|
||||
registered_tools = get_tools()
|
||||
bad_tools = [tool for tool in tools if tool not in registered_tools]
|
||||
registered_classes = dict(
|
||||
(key, value)
|
||||
for key, value in registered_tools.items()
|
||||
if inspect.isclass(value)
|
||||
)
|
||||
bad_tools = [
|
||||
tool for tool in tool_specs if tool.split("(")[0] not in registered_tools
|
||||
]
|
||||
if bad_tools:
|
||||
raise click.ClickException(
|
||||
"Tool(s) {} not found. Available tools: {}".format(
|
||||
", ".join(bad_tools), ", ".join(registered_tools.keys())
|
||||
)
|
||||
)
|
||||
tool_functions.extend(registered_tools[tool] for tool in tools)
|
||||
return tool_functions
|
||||
for tool_spec in tool_specs:
|
||||
if not tool_spec[0].isupper():
|
||||
# It's a function
|
||||
tools.append(registered_tools[tool_spec])
|
||||
else:
|
||||
# It's a class
|
||||
tools.append(instantiate_from_spec(registered_classes, tool_spec))
|
||||
return tools
|
||||
|
||||
|
||||
def _get_conversation_tools(conversation, tools):
|
||||
|
|
|
|||
126
llm/models.py
126
llm/models.py
|
|
@ -118,18 +118,7 @@ class Tool:
|
|||
|
||||
def __post_init__(self):
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
self.input_schema = self._ensure_dict_schema(self.input_schema)
|
||||
|
||||
def _ensure_dict_schema(self, schema):
|
||||
"""Convert a Pydantic model to a JSON schema dict if needed."""
|
||||
if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel):
|
||||
schema_dict = schema.model_json_schema()
|
||||
# Strip annoying "title" fields which are just the "name" in title case
|
||||
schema_dict.pop("title", None)
|
||||
for value in schema_dict.get("properties", {}).values():
|
||||
value.pop("title", None)
|
||||
return schema_dict
|
||||
return schema
|
||||
self.input_schema = _ensure_dict_schema(self.input_schema)
|
||||
|
||||
def hash(self):
|
||||
"""Hash for tool based on its name, description and input schema (preserving key order)"""
|
||||
|
|
@ -151,36 +140,87 @@ class Tool:
|
|||
- Building a Pydantic model for inputs by inspecting the function signature
|
||||
- Building a Pydantic model for the return value by using the function's return annotation
|
||||
"""
|
||||
signature = inspect.signature(function)
|
||||
type_hints = get_type_hints(function)
|
||||
|
||||
if not name and function.__name__ == "<lambda>":
|
||||
raise ValueError(
|
||||
"Cannot create a Tool from a lambda function without providing name="
|
||||
)
|
||||
|
||||
fields = {}
|
||||
for param_name, param in signature.parameters.items():
|
||||
# Determine the type annotation (default to string if missing)
|
||||
annotated_type = type_hints.get(param_name, str)
|
||||
|
||||
# Handle default value if present; if there's no default, use '...'
|
||||
if param.default is inspect.Parameter.empty:
|
||||
fields[param_name] = (annotated_type, ...)
|
||||
else:
|
||||
fields[param_name] = (annotated_type, param.default)
|
||||
|
||||
input_schema = create_model(f"{function.__name__}InputSchema", **fields)
|
||||
|
||||
return cls(
|
||||
name=name or function.__name__,
|
||||
description=function.__doc__ or None,
|
||||
input_schema=input_schema,
|
||||
input_schema=_get_arguments_input_schema(function, name),
|
||||
implementation=function,
|
||||
)
|
||||
|
||||
|
||||
ToolDef = Union[Tool, Callable[..., Any]]
|
||||
def _get_arguments_input_schema(function, name):
|
||||
signature = inspect.signature(function)
|
||||
type_hints = get_type_hints(function)
|
||||
fields = {}
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param_name == "self":
|
||||
continue
|
||||
# Determine the type annotation (default to string if missing)
|
||||
annotated_type = type_hints.get(param_name, str)
|
||||
|
||||
# Handle default value if present; if there's no default, use '...'
|
||||
if param.default is inspect.Parameter.empty:
|
||||
fields[param_name] = (annotated_type, ...)
|
||||
else:
|
||||
fields[param_name] = (annotated_type, param.default)
|
||||
|
||||
return create_model(f"{name}InputSchema", **fields)
|
||||
|
||||
|
||||
class Toolbox:
|
||||
_blocked = ("method_tools", "introspect_methods", "methods")
|
||||
name: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def methods(cls):
|
||||
gathered = []
|
||||
for name in dir(cls):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
if name in cls._blocked:
|
||||
continue
|
||||
method = getattr(cls, name)
|
||||
if callable(method):
|
||||
gathered.append(method)
|
||||
return gathered
|
||||
|
||||
def method_tools(self):
|
||||
"Returns a list of llm.Tool() for each method"
|
||||
for method_name in dir(self):
|
||||
if method_name.startswith("_") or method_name in self._blocked:
|
||||
continue
|
||||
method = getattr(self, method_name)
|
||||
# The attribute must be a bound method, i.e. inspect.ismethod()
|
||||
if callable(method) and inspect.ismethod(method):
|
||||
yield Tool.function(
|
||||
method,
|
||||
name="{}_{}".format(self.__class__.__name__, method_name),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def introspect_methods(cls):
|
||||
methods = []
|
||||
for method in cls.methods():
|
||||
arguments = _get_arguments_input_schema(method, method.__name__)
|
||||
methods.append(
|
||||
{
|
||||
"name": method.__name__,
|
||||
"description": (
|
||||
method.__doc__.strip() if method.__doc__ is not None else None
|
||||
),
|
||||
"arguments": _ensure_dict_schema(arguments),
|
||||
"implementation": method,
|
||||
}
|
||||
)
|
||||
return methods
|
||||
|
||||
|
||||
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -263,6 +303,8 @@ def _wrap_tools(tools: List[ToolDef]) -> List[Tool]:
|
|||
for tool in tools:
|
||||
if isinstance(tool, Tool):
|
||||
wrapped_tools.append(tool)
|
||||
elif isinstance(tool, Toolbox):
|
||||
wrapped_tools.extend(tool.method_tools())
|
||||
elif callable(tool):
|
||||
wrapped_tools.append(Tool.function(tool))
|
||||
else:
|
||||
|
|
@ -1788,3 +1830,27 @@ def _conversation_name(text):
|
|||
if len(text) <= CONVERSATION_NAME_LENGTH:
|
||||
return text
|
||||
return text[: CONVERSATION_NAME_LENGTH - 1] + "…"
|
||||
|
||||
|
||||
def _ensure_dict_schema(schema):
|
||||
"""Convert a Pydantic model to a JSON schema dict if needed."""
|
||||
if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel):
|
||||
schema_dict = schema.model_json_schema()
|
||||
_remove_titles_recursively(schema_dict)
|
||||
return schema_dict
|
||||
return schema
|
||||
|
||||
|
||||
def _remove_titles_recursively(obj):
|
||||
"""Recursively remove all 'title' fields from a nested dictionary."""
|
||||
if isinstance(obj, dict):
|
||||
# Remove title if present
|
||||
obj.pop("title", None)
|
||||
|
||||
# Recursively process all values
|
||||
for value in obj.values():
|
||||
_remove_titles_recursively(value)
|
||||
elif isinstance(obj, list):
|
||||
# Process each item in lists
|
||||
for item in obj:
|
||||
_remove_titles_recursively(item)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import importlib
|
|||
import json
|
||||
import llm
|
||||
from llm import cli, hookimpl, plugins, get_template_loaders, get_fragment_loaders
|
||||
import pathlib
|
||||
import textwrap
|
||||
|
||||
|
||||
|
|
@ -263,46 +264,52 @@ def test_register_tools(tmpdir, logs_db):
|
|||
result = runner.invoke(cli.cli, ["tools", "list"])
|
||||
assert result.exit_code == 0
|
||||
assert result.output == (
|
||||
"upper(text: str) -> str (plugin: ToolsPlugin)\n"
|
||||
" Convert text to uppercase.\n"
|
||||
"count_chars(text: str, character: str) -> int (plugin: ToolsPlugin)\n"
|
||||
" Count the number of occurrences of a character in a word.\n"
|
||||
"output_as_json(text: str) (plugin: ToolsPlugin)\n"
|
||||
"upper(text: str) -> str (plugin: ToolsPlugin)\n\n"
|
||||
" Convert text to uppercase.\n\n"
|
||||
"count_chars(text: str, character: str) -> int (plugin: ToolsPlugin)\n\n"
|
||||
" Count the number of occurrences of a character in a word.\n\n"
|
||||
"output_as_json(text: str) (plugin: ToolsPlugin)\n\n"
|
||||
)
|
||||
# And --json
|
||||
result2 = runner.invoke(cli.cli, ["tools", "list", "--json"])
|
||||
assert result2.exit_code == 0
|
||||
assert json.loads(result2.output) == {
|
||||
"upper": {
|
||||
"description": "Convert text to uppercase.",
|
||||
"arguments": {
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
"type": "object",
|
||||
},
|
||||
"plugin": "ToolsPlugin",
|
||||
},
|
||||
"count_chars": {
|
||||
"description": "Count the number of occurrences of a character in a word.",
|
||||
"arguments": {
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"character": {"type": "string"},
|
||||
"tools": [
|
||||
{
|
||||
"name": "upper",
|
||||
"description": "Convert text to uppercase.",
|
||||
"arguments": {
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
"type": "object",
|
||||
},
|
||||
"required": ["text", "character"],
|
||||
"type": "object",
|
||||
"plugin": "ToolsPlugin",
|
||||
},
|
||||
"plugin": "ToolsPlugin",
|
||||
},
|
||||
"output_as_json": {
|
||||
"description": None,
|
||||
"arguments": {
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
"type": "object",
|
||||
{
|
||||
"name": "count_chars",
|
||||
"description": "Count the number of occurrences of a character in a word.",
|
||||
"arguments": {
|
||||
"properties": {
|
||||
"text": {"type": "string"},
|
||||
"character": {"type": "string"},
|
||||
},
|
||||
"required": ["text", "character"],
|
||||
"type": "object",
|
||||
},
|
||||
"plugin": "ToolsPlugin",
|
||||
},
|
||||
"plugin": "ToolsPlugin",
|
||||
},
|
||||
{
|
||||
"name": "output_as_json",
|
||||
"description": None,
|
||||
"arguments": {
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"],
|
||||
"type": "object",
|
||||
},
|
||||
"plugin": "ToolsPlugin",
|
||||
},
|
||||
],
|
||||
"toolboxes": [],
|
||||
}
|
||||
# And test the --tools option
|
||||
functions_path = str(tmpdir / "functions.py")
|
||||
|
|
@ -333,6 +340,7 @@ def test_register_tools(tmpdir, logs_db):
|
|||
{"tool_calls": [{"name": "upper", "arguments": {"text": "hi"}}]}
|
||||
),
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
assert result4.exit_code == 0
|
||||
assert '"output": "HI"' in result4.output
|
||||
|
|
@ -467,6 +475,253 @@ def test_register_tools(tmpdir, logs_db):
|
|||
assert llm.get_tools() == {}
|
||||
|
||||
|
||||
def test_register_toolbox(tmpdir, logs_db):
|
||||
class Memory(llm.Toolbox):
|
||||
_memory = None
|
||||
|
||||
def _get_memory(self):
|
||||
if self._memory is None:
|
||||
self._memory = {}
|
||||
return self._memory
|
||||
|
||||
def set(self, key: str, value: str):
|
||||
"Set something as a key"
|
||||
self._get_memory()[key] = value
|
||||
|
||||
def get(self, key: str):
|
||||
"Get something from a key"
|
||||
return self._get_memory().get(key) or ""
|
||||
|
||||
def append(self, key: str, value: str):
|
||||
"Append something as a key"
|
||||
memory = self._get_memory()
|
||||
memory[key] = (memory.get(key) or "") + "\n" + value
|
||||
|
||||
def keys(self):
|
||||
"Return a list of keys"
|
||||
return list(self._get_memory().keys())
|
||||
|
||||
class Filesystem(llm.Toolbox):
|
||||
def __init__(self, path: str):
|
||||
self.path = path
|
||||
|
||||
def list_files(self):
|
||||
return [str(item) for item in pathlib.Path(self.path).glob("*")]
|
||||
|
||||
# Test the Python API
|
||||
model = llm.get_model("echo")
|
||||
memory = Memory()
|
||||
conversation = model.conversation(tools=[memory])
|
||||
accumulated = []
|
||||
|
||||
def after_call(tool, tool_call, tool_result):
|
||||
accumulated.append((tool.name, tool_call.arguments, tool_result.output))
|
||||
|
||||
conversation.chain(
|
||||
json.dumps(
|
||||
{
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "Memory_set",
|
||||
"arguments": {"key": "hello", "value": "world"},
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
after_call=after_call,
|
||||
).text()
|
||||
conversation.chain(
|
||||
json.dumps(
|
||||
{"tool_calls": [{"name": "Memory_get", "arguments": {"key": "hello"}}]}
|
||||
),
|
||||
after_call=after_call,
|
||||
).text()
|
||||
assert accumulated == [
|
||||
("Memory_set", {"key": "hello", "value": "world"}, "null"),
|
||||
("Memory_get", {"key": "hello"}, "world"),
|
||||
]
|
||||
assert memory._memory == {"hello": "world"}
|
||||
|
||||
# And for the Filesystem with state
|
||||
my_dir = pathlib.Path(tmpdir / "mine")
|
||||
my_dir.mkdir()
|
||||
(my_dir / "doc.txt").write_text("hi", "utf-8")
|
||||
conversation = model.conversation(tools=[Filesystem(my_dir)])
|
||||
accumulated.clear()
|
||||
conversation.chain(
|
||||
json.dumps(
|
||||
{
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "Filesystem_list_files",
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
after_call=after_call,
|
||||
).text()
|
||||
assert accumulated == [
|
||||
("Filesystem_list_files", {}, json.dumps([str(my_dir / "doc.txt")]))
|
||||
]
|
||||
|
||||
# Now register them with a plugin and use it through the CLI
|
||||
|
||||
class ToolboxPlugin:
|
||||
__name__ = "ToolboxPlugin"
|
||||
|
||||
@hookimpl
|
||||
def register_tools(self, register):
|
||||
register(Memory)
|
||||
register(Filesystem)
|
||||
|
||||
try:
|
||||
plugins.pm.register(ToolboxPlugin(), name="ToolboxPlugin")
|
||||
tools = llm.get_tools()
|
||||
assert tools["Memory"] is Memory
|
||||
|
||||
runner = CliRunner()
|
||||
# llm tools --json
|
||||
result = runner.invoke(cli.cli, ["tools", "--json"])
|
||||
assert result.exit_code == 0
|
||||
assert json.loads(result.output) == {
|
||||
"tools": [],
|
||||
"toolboxes": [
|
||||
{
|
||||
"name": "Memory",
|
||||
"tools": [
|
||||
{
|
||||
"name": "append",
|
||||
"description": "Append something as a key",
|
||||
"arguments": {
|
||||
"properties": {
|
||||
"key": {"type": "string"},
|
||||
"value": {"type": "string"},
|
||||
},
|
||||
"required": ["key", "value"],
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get",
|
||||
"description": "Get something from a key",
|
||||
"arguments": {
|
||||
"properties": {"key": {"type": "string"}},
|
||||
"required": ["key"],
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "keys",
|
||||
"description": "Return a list of keys",
|
||||
"arguments": {"properties": {}, "type": "object"},
|
||||
},
|
||||
{
|
||||
"name": "set",
|
||||
"description": "Set something as a key",
|
||||
"arguments": {
|
||||
"properties": {
|
||||
"key": {"type": "string"},
|
||||
"value": {"type": "string"},
|
||||
},
|
||||
"required": ["key", "value"],
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Filesystem",
|
||||
"tools": [
|
||||
{
|
||||
"name": "list_files",
|
||||
"description": None,
|
||||
"arguments": {"properties": {}, "type": "object"},
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# llm tools (no JSON)
|
||||
result = runner.invoke(cli.cli, ["tools"])
|
||||
assert result.exit_code == 0
|
||||
assert result.output == (
|
||||
"Memory:\n\n"
|
||||
" append(key: str, value: str)\n\n"
|
||||
" Append something as a key\n\n"
|
||||
" get(key: str)\n\n"
|
||||
" Get something from a key\n\n"
|
||||
" keys()\n\n"
|
||||
" Return a list of keys\n\n"
|
||||
" set(key: str, value: str)\n\n"
|
||||
" Set something as a key\n\n"
|
||||
"Filesystem:\n\n"
|
||||
" list_files()\n\n"
|
||||
)
|
||||
|
||||
# Test the CLI running a toolbox prompt
|
||||
result3 = runner.invoke(
|
||||
cli.cli,
|
||||
[
|
||||
"prompt",
|
||||
"-T",
|
||||
"Memory",
|
||||
json.dumps(
|
||||
{
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "Memory_set",
|
||||
"arguments": {"key": "hi", "value": "two"},
|
||||
},
|
||||
{"name": "Memory_get", "arguments": {"key": "hi"}},
|
||||
]
|
||||
}
|
||||
),
|
||||
"-m",
|
||||
"echo",
|
||||
],
|
||||
)
|
||||
assert result3.exit_code == 0
|
||||
tool_results = json.loads(
|
||||
"[" + result3.output.split('"tool_results": [')[1].split("]")[0] + "]"
|
||||
)
|
||||
assert tool_results == [
|
||||
{"name": "Memory_set", "output": "null", "tool_call_id": None},
|
||||
{"name": "Memory_get", "output": "two", "tool_call_id": None},
|
||||
]
|
||||
|
||||
# Test the CLI running a configured toolbox prompt
|
||||
my_dir2 = pathlib.Path(tmpdir / "mine2")
|
||||
my_dir2.mkdir()
|
||||
other_path = my_dir2 / "other.txt"
|
||||
other_path.write_text("hi", "utf-8")
|
||||
result4 = runner.invoke(
|
||||
cli.cli,
|
||||
[
|
||||
"prompt",
|
||||
"-T",
|
||||
"Filesystem({})".format(json.dumps(str(my_dir2))),
|
||||
json.dumps({"tool_calls": [{"name": "Filesystem_list_files"}]}),
|
||||
"-m",
|
||||
"echo",
|
||||
],
|
||||
)
|
||||
assert result4.exit_code == 0
|
||||
tool_results = json.loads(
|
||||
"[" + result4.output.split('"tool_results": [')[1].rsplit("]", 1)[0] + "]"
|
||||
)
|
||||
assert tool_results == [
|
||||
{
|
||||
"name": "Filesystem_list_files",
|
||||
"output": json.dumps([str(other_path)]),
|
||||
"tool_call_id": None,
|
||||
}
|
||||
]
|
||||
|
||||
finally:
|
||||
plugins.pm.unregister(name="ToolboxPlugin")
|
||||
|
||||
|
||||
def test_plugins_command():
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli.cli, ["plugins"])
|
||||
|
|
|
|||
Loading…
Reference in a new issue