mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-29 17:34:45 +00:00
Record which plugin a tool came from, including in DB - refs #1020
This commit is contained in:
parent
cb94047b46
commit
d2886d4692
6 changed files with 40 additions and 5 deletions
|
|
@ -354,7 +354,8 @@ CREATE TABLE [tools] (
|
|||
[hash] TEXT,
|
||||
[name] TEXT,
|
||||
[description] TEXT,
|
||||
[input_schema] TEXT
|
||||
[input_schema] TEXT,
|
||||
[plugin] TEXT
|
||||
);
|
||||
CREATE TABLE [tool_responses] (
|
||||
[tool_id] INTEGER REFERENCES [tools]([id]),
|
||||
|
|
|
|||
|
|
@ -139,6 +139,9 @@ def get_tools() -> Dict[str, Tool]:
|
|||
load_plugins()
|
||||
tools: Dict[str, Tool] = {}
|
||||
|
||||
# Variable to track current plugin name
|
||||
current_plugin_name = None
|
||||
|
||||
def register(
|
||||
tool_or_function: Union[Tool, Callable[..., Any]],
|
||||
name: Optional[str] = None,
|
||||
|
|
@ -148,6 +151,9 @@ def get_tools() -> Dict[str, Tool]:
|
|||
tool_or_function = Tool.function(tool_or_function, name=name)
|
||||
|
||||
tool = cast(Tool, tool_or_function)
|
||||
if current_plugin_name:
|
||||
tool.plugin = current_plugin_name
|
||||
|
||||
prefix = name or tool.name
|
||||
suffix = 0
|
||||
candidate = prefix
|
||||
|
|
@ -159,7 +165,16 @@ def get_tools() -> Dict[str, Tool]:
|
|||
|
||||
tools[candidate] = tool
|
||||
|
||||
pm.hook.register_tools(register=register)
|
||||
# Call each plugin's register_tools hook individually to track current_plugin_name
|
||||
for plugin in pm.get_plugins():
|
||||
current_plugin_name = pm.get_name(plugin)
|
||||
hook_caller = pm.hook.register_tools
|
||||
plugin_impls = [
|
||||
impl for impl in hook_caller.get_hookimpls() if impl.plugin is plugin
|
||||
]
|
||||
for impl in plugin_impls:
|
||||
impl.function(register=register)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2337,6 +2337,7 @@ def tools_list(json_, python_tools):
|
|||
name: {
|
||||
"description": tool.description,
|
||||
"arguments": tool.input_schema,
|
||||
"plugin": tool.plugin,
|
||||
}
|
||||
for name, tool in tools.items()
|
||||
},
|
||||
|
|
@ -2348,7 +2349,13 @@ def tools_list(json_, python_tools):
|
|||
sig = "()"
|
||||
if tool.implementation:
|
||||
sig = str(inspect.signature(tool.implementation))
|
||||
click.echo("{}{}".format(name, sig))
|
||||
click.echo(
|
||||
"{}{}{}".format(
|
||||
name,
|
||||
sig,
|
||||
" (plugin: {})".format(tool.plugin) if tool.plugin else "",
|
||||
)
|
||||
)
|
||||
if tool.description:
|
||||
click.echo(textwrap.indent(tool.description, " "))
|
||||
|
||||
|
|
|
|||
|
|
@ -368,3 +368,8 @@ def m017_tools_tables(db):
|
|||
("tool_id", "tools", "id"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@migration
|
||||
def m017_tools_plugin(db):
|
||||
db["tools"].add_column("plugin")
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ class Tool:
|
|||
description: Optional[str] = None
|
||||
input_schema: Dict = field(default_factory=dict)
|
||||
implementation: Optional[Callable] = None
|
||||
plugin: Optional[str] = None # plugin tool came from, e.g. 'llm_tools_sqlite'
|
||||
|
||||
def __post_init__(self):
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
|
|
@ -137,6 +138,8 @@ class Tool:
|
|||
"description": self.description,
|
||||
"input_schema": self.input_schema,
|
||||
}
|
||||
if self.plugin:
|
||||
to_hash["plugin"] = self.plugin
|
||||
return hashlib.sha256(json.dumps(to_hash).encode("utf-8")).hexdigest()
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -226,6 +226,7 @@ def test_register_tools(tmpdir):
|
|||
"type": "object",
|
||||
},
|
||||
implementation=upper,
|
||||
plugin="ToolsPlugin",
|
||||
),
|
||||
"count_chars": llm.Tool(
|
||||
name="count_chars",
|
||||
|
|
@ -239,6 +240,7 @@ def test_register_tools(tmpdir):
|
|||
"type": "object",
|
||||
},
|
||||
implementation=count_character_in_word,
|
||||
plugin="ToolsPlugin",
|
||||
),
|
||||
}
|
||||
# Test the CLI command
|
||||
|
|
@ -246,9 +248,9 @@ def test_register_tools(tmpdir):
|
|||
result = runner.invoke(cli.cli, ["tools", "list"])
|
||||
assert result.exit_code == 0
|
||||
assert result.output == (
|
||||
"upper(text: str) -> str\n"
|
||||
"upper(text: str) -> str (plugin: ToolsPlugin)\n"
|
||||
" Convert text to uppercase.\n"
|
||||
"count_chars(text: str, character: str) -> int\n"
|
||||
"count_chars(text: str, character: str) -> int (plugin: ToolsPlugin)\n"
|
||||
" Count the number of occurrences of a character in a word.\n"
|
||||
)
|
||||
# And --json
|
||||
|
|
@ -262,6 +264,7 @@ def test_register_tools(tmpdir):
|
|||
"required": ["text"],
|
||||
"type": "object",
|
||||
},
|
||||
"plugin": "ToolsPlugin",
|
||||
},
|
||||
"count_chars": {
|
||||
"description": "Count the number of occurrences of a character in a word.",
|
||||
|
|
@ -273,6 +276,7 @@ def test_register_tools(tmpdir):
|
|||
"required": ["text", "character"],
|
||||
"type": "object",
|
||||
},
|
||||
"plugin": "ToolsPlugin",
|
||||
},
|
||||
}
|
||||
# And test the --tools option
|
||||
|
|
|
|||
Loading…
Reference in a new issue