diff --git a/llm/__init__.py b/llm/__init__.py index d10f588..c55c85d 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -26,7 +26,7 @@ from .embeddings import Collection from .templates import Template from .plugins import pm, load_plugins import click -from typing import Dict, List, Optional, Callable, Union +from typing import Any, Dict, List, Optional, Callable, Union import json import os import pathlib @@ -132,6 +132,25 @@ def get_fragment_loaders() -> Dict[ return _get_loaders(pm.hook.register_fragment_loaders) +def get_tools() -> Dict[str, List[Tool]]: + """Get tools registered by plugins.""" + load_plugins() + tools = {} + + def register(tool_or_function: Callable[..., Any], name: Optional[str] = None): + suffix = 0 + if not isinstance(tool_or_function, Tool): + tool_or_function = Tool.function(tool_or_function) + prefix_to_try = tool_or_function.name + while prefix_to_try in tools: + suffix += 1 + prefix_to_try = f"{prefix}_{suffix}" + tools[prefix_to_try] = tool_or_function + + pm.hook.register_tools(register=register) + return tools + + def get_embedding_models_with_aliases() -> List["EmbeddingModelWithAliases"]: model_aliases = [] diff --git a/llm/hookspecs.py b/llm/hookspecs.py index 29a084a..a244b00 100644 --- a/llm/hookspecs.py +++ b/llm/hookspecs.py @@ -28,3 +28,8 @@ def register_template_loaders(register): @hookspec def register_fragment_loaders(register): "Register additional fragment loaders with prefixes" + + +@hookspec +def register_tools(register): + "Register functions that can be used as tools by the LLMs" diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 522dcc6..020973c 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -188,3 +188,53 @@ def test_register_fragment_loaders(logs_db, httpx_mock): {"content": "two:x", "source": "two"}, {"content": "three:x", "source": "three"}, ] + + +def test_register_tools(): + def upper(text: str) -> str: + """Convert text to uppercase.""" + return text.upper() + + def count_character_in_word(text: str, character: str) -> int: + """Count the number of occurrences of a character in a word.""" + return text.count(character) + + class ToolsPlugin: + __name__ = "ToolsPlugin" + + @hookimpl + def register_tools(self, register): + register(upper) + register(llm.Tool.function(count_character_in_word), name="count_chars") + + try: + plugins.pm.register(ToolsPlugin(), name="ToolsPlugin") + tools = llm.get_tools() + assert tools == { + "upper": llm.Tool( + name="upper", + description="Convert text to uppercase.", + input_schema={ + "properties": {"text": {"type": "string"}}, + "required": ["text"], + "type": "object", + }, + implementation=upper, + ), + "count_character_in_word": llm.Tool( + name="count_character_in_word", + description="Count the number of occurrences of a character in a word.", + input_schema={ + "properties": { + "text": {"type": "string"}, + "character": {"type": "string"}, + }, + "required": ["text", "character"], + "type": "object", + }, + implementation=count_character_in_word, + ), + } + finally: + plugins.pm.unregister(name="ToolsPlugin") + assert llm.get_tools() == {}