llm chat -c and llm -c carry forward tools, closes #1020

This commit is contained in:
Simon Willison 2025-05-23 21:10:51 -07:00
parent 0da2072ade
commit 36477cf9e5
3 changed files with 78 additions and 0 deletions

View file

@ -688,6 +688,9 @@ def prompt(
except UnknownModelError as ex:
raise click.ClickException(str(ex))
if conversation_tools := _get_conversation_tools(conversation, tools):
tools = conversation_tools
# Figure out which model we are using
if model_id is None:
if conversation:
@ -985,6 +988,9 @@ def chat(
except UnknownModelError as ex:
raise click.ClickException(str(ex))
if conversation_tools := _get_conversation_tools(conversation, tools):
tools = conversation_tools
template_obj = None
if template:
params = dict(param)
@ -3690,3 +3696,12 @@ def _gather_tools(tools, python_tools):
)
tool_functions.extend(registered_tools[tool] for tool in tools)
return tool_functions
def _get_conversation_tools(conversation, tools):
if conversation and not tools and conversation.responses:
# Copy plugin tools from first response in conversation
initial_tools = conversation.responses[0].prompt.tools
if initial_tools:
# Only tools from plugins:
return [tool.name for tool in initial_tools if tool.plugin]

View file

@ -579,6 +579,7 @@ class _BaseResponse:
# In this case we don't have a reference to the actual Python code
# but that's OK, we should not need it for prompts deserialized from DB
implementation=None,
plugin=tool_row["plugin"],
)
for tool_row in db.query(
"""

View file

@ -4,6 +4,7 @@ import importlib
import json
import llm
from llm import cli, hookimpl, plugins, get_template_loaders, get_fragment_loaders
import re
import textwrap
@ -315,6 +316,67 @@ def test_register_tools(tmpdir, logs_db):
tool_row = [row for row in logs_db["tools"].rows][0]
assert tool_row["name"] == "upper"
assert tool_row["plugin"] == "ToolsPlugin"
# Start with a tool, use llm -c to reuse the same tool
result5 = runner.invoke(
cli.cli,
[
"prompt",
"-m",
"echo",
"--tool",
"upper",
json.dumps(
{"tool_calls": [{"name": "upper", "arguments": {"text": "one"}}]}
),
"--td",
],
)
assert result5.exit_code == 0
assert (
runner.invoke(
cli.cli,
[
"-c",
json.dumps(
{
"tool_calls": [
{"name": "upper", "arguments": {"text": "two"}}
]
}
),
],
).exit_code
== 0
)
# Now do it again with llm chat -c
assert (
runner.invoke(
cli.cli,
["chat", "-c"],
input=(
json.dumps(
{
"tool_calls": [
{"name": "upper", "arguments": {"text": "three"}}
]
}
)
+ "\nquit\n"
),
catch_exceptions=False,
).exit_code
== 0
)
# Should have logged three tool uses in llm logs -c -n 0
log_output = runner.invoke(cli.cli, ["logs", "-c", "-n", "10"]).output
log_pattern = re.compile(
r"""tool_calls.*?"text": "one".*?ONE.*?"""
r"""tool_calls.*?"text": "two".*?TWO.*?"""
r"""tool_calls.*?"text": "three".*?THREE""",
re.DOTALL,
)
assert log_pattern.search(log_output)
finally:
plugins.pm.unregister(name="ToolsPlugin")
assert llm.get_tools() == {}