mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-18 20:21:03 +00:00
llm chat -c and llm -c carry forward tools, closes #1020
This commit is contained in:
parent
0da2072ade
commit
36477cf9e5
3 changed files with 78 additions and 0 deletions
15
llm/cli.py
15
llm/cli.py
|
|
@ -688,6 +688,9 @@ def prompt(
|
|||
except UnknownModelError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
if conversation_tools := _get_conversation_tools(conversation, tools):
|
||||
tools = conversation_tools
|
||||
|
||||
# Figure out which model we are using
|
||||
if model_id is None:
|
||||
if conversation:
|
||||
|
|
@ -985,6 +988,9 @@ def chat(
|
|||
except UnknownModelError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
if conversation_tools := _get_conversation_tools(conversation, tools):
|
||||
tools = conversation_tools
|
||||
|
||||
template_obj = None
|
||||
if template:
|
||||
params = dict(param)
|
||||
|
|
@ -3690,3 +3696,12 @@ def _gather_tools(tools, python_tools):
|
|||
)
|
||||
tool_functions.extend(registered_tools[tool] for tool in tools)
|
||||
return tool_functions
|
||||
|
||||
|
||||
def _get_conversation_tools(conversation, tools):
|
||||
if conversation and not tools and conversation.responses:
|
||||
# Copy plugin tools from first response in conversation
|
||||
initial_tools = conversation.responses[0].prompt.tools
|
||||
if initial_tools:
|
||||
# Only tools from plugins:
|
||||
return [tool.name for tool in initial_tools if tool.plugin]
|
||||
|
|
|
|||
|
|
@ -579,6 +579,7 @@ class _BaseResponse:
|
|||
# In this case we don't have a reference to the actual Python code
|
||||
# but that's OK, we should not need it for prompts deserialized from DB
|
||||
implementation=None,
|
||||
plugin=tool_row["plugin"],
|
||||
)
|
||||
for tool_row in db.query(
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import importlib
|
|||
import json
|
||||
import llm
|
||||
from llm import cli, hookimpl, plugins, get_template_loaders, get_fragment_loaders
|
||||
import re
|
||||
import textwrap
|
||||
|
||||
|
||||
|
|
@ -315,6 +316,67 @@ def test_register_tools(tmpdir, logs_db):
|
|||
tool_row = [row for row in logs_db["tools"].rows][0]
|
||||
assert tool_row["name"] == "upper"
|
||||
assert tool_row["plugin"] == "ToolsPlugin"
|
||||
|
||||
# Start with a tool, use llm -c to reuse the same tool
|
||||
result5 = runner.invoke(
|
||||
cli.cli,
|
||||
[
|
||||
"prompt",
|
||||
"-m",
|
||||
"echo",
|
||||
"--tool",
|
||||
"upper",
|
||||
json.dumps(
|
||||
{"tool_calls": [{"name": "upper", "arguments": {"text": "one"}}]}
|
||||
),
|
||||
"--td",
|
||||
],
|
||||
)
|
||||
assert result5.exit_code == 0
|
||||
assert (
|
||||
runner.invoke(
|
||||
cli.cli,
|
||||
[
|
||||
"-c",
|
||||
json.dumps(
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "upper", "arguments": {"text": "two"}}
|
||||
]
|
||||
}
|
||||
),
|
||||
],
|
||||
).exit_code
|
||||
== 0
|
||||
)
|
||||
# Now do it again with llm chat -c
|
||||
assert (
|
||||
runner.invoke(
|
||||
cli.cli,
|
||||
["chat", "-c"],
|
||||
input=(
|
||||
json.dumps(
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "upper", "arguments": {"text": "three"}}
|
||||
]
|
||||
}
|
||||
)
|
||||
+ "\nquit\n"
|
||||
),
|
||||
catch_exceptions=False,
|
||||
).exit_code
|
||||
== 0
|
||||
)
|
||||
# Should have logged three tool uses in llm logs -c -n 0
|
||||
log_output = runner.invoke(cli.cli, ["logs", "-c", "-n", "10"]).output
|
||||
log_pattern = re.compile(
|
||||
r"""tool_calls.*?"text": "one".*?ONE.*?"""
|
||||
r"""tool_calls.*?"text": "two".*?TWO.*?"""
|
||||
r"""tool_calls.*?"text": "three".*?THREE""",
|
||||
re.DOTALL,
|
||||
)
|
||||
assert log_pattern.search(log_output)
|
||||
finally:
|
||||
plugins.pm.unregister(name="ToolsPlugin")
|
||||
assert llm.get_tools() == {}
|
||||
|
|
|
|||
Loading…
Reference in a new issue