register_tools() hook, get_tools() utility

Refs #991
This commit is contained in:
Simon Willison 2025-05-11 19:44:02 -07:00
parent 8a3c461e46
commit 290c0f13f3
3 changed files with 75 additions and 1 deletions

View file

@ -26,7 +26,7 @@ from .embeddings import Collection
from .templates import Template
from .plugins import pm, load_plugins
import click
from typing import Dict, List, Optional, Callable, Union
from typing import Any, Dict, List, Optional, Callable, Union
import json
import os
import pathlib
@ -132,6 +132,25 @@ def get_fragment_loaders() -> Dict[
return _get_loaders(pm.hook.register_fragment_loaders)
def get_tools() -> Dict[str, List[Tool]]:
"""Get tools registered by plugins."""
load_plugins()
tools = {}
def register(tool_or_function: Callable[..., Any], name: Optional[str] = None):
suffix = 0
if not isinstance(tool_or_function, Tool):
tool_or_function = Tool.function(tool_or_function)
prefix_to_try = tool_or_function.name
while prefix_to_try in tools:
suffix += 1
prefix_to_try = f"{prefix}_{suffix}"
tools[prefix_to_try] = tool_or_function
pm.hook.register_tools(register=register)
return tools
def get_embedding_models_with_aliases() -> List["EmbeddingModelWithAliases"]:
model_aliases = []

View file

@ -28,3 +28,8 @@ def register_template_loaders(register):
@hookspec
def register_fragment_loaders(register):
"Register additional fragment loaders with prefixes"
@hookspec
def register_tools(register):
"Register functions that can be used as tools by the LLMs"

View file

@ -188,3 +188,53 @@ def test_register_fragment_loaders(logs_db, httpx_mock):
{"content": "two:x", "source": "two"},
{"content": "three:x", "source": "three"},
]
def test_register_tools():
def upper(text: str) -> str:
"""Convert text to uppercase."""
return text.upper()
def count_character_in_word(text: str, character: str) -> int:
"""Count the number of occurrences of a character in a word."""
return text.count(character)
class ToolsPlugin:
__name__ = "ToolsPlugin"
@hookimpl
def register_tools(self, register):
register(upper)
register(llm.Tool.function(count_character_in_word), name="count_chars")
try:
plugins.pm.register(ToolsPlugin(), name="ToolsPlugin")
tools = llm.get_tools()
assert tools == {
"upper": llm.Tool(
name="upper",
description="Convert text to uppercase.",
input_schema={
"properties": {"text": {"type": "string"}},
"required": ["text"],
"type": "object",
},
implementation=upper,
),
"count_character_in_word": llm.Tool(
name="count_character_in_word",
description="Count the number of occurrences of a character in a word.",
input_schema={
"properties": {
"text": {"type": "string"},
"character": {"type": "string"},
},
"required": ["text", "character"],
"type": "object",
},
implementation=count_character_in_word,
),
}
finally:
plugins.pm.unregister(name="ToolsPlugin")
assert llm.get_tools() == {}