mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
Use tools in templates with llm chat, closes #1239
This commit is contained in:
parent
2f206d0e26
commit
c41c122239
2 changed files with 42 additions and 0 deletions
|
|
@ -1068,6 +1068,10 @@ def chat(
|
|||
raise click.ClickException(str(ex))
|
||||
if model_id is None and template_obj.model:
|
||||
model_id = template_obj.model
|
||||
if template_obj.tools:
|
||||
tools = [*template_obj.tools, *tools]
|
||||
if template_obj.functions and template_obj._functions_is_trusted:
|
||||
python_tools = [template_obj.functions, *python_tools]
|
||||
|
||||
# Figure out which model we are using
|
||||
if model_id is None:
|
||||
|
|
|
|||
|
|
@ -60,3 +60,41 @@ def test_chat_system_fragments_only_first_turn(tmpdir, mock_model, logs_db):
|
|||
assert len(sys_frags) == 1
|
||||
assert sys_frags[0]["response_id"] == first_id
|
||||
assert sys_frags[0]["response_id"] != second_id
|
||||
|
||||
|
||||
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
|
||||
def test_chat_template_loads_tools_into_logs(logs_db, templates_path):
|
||||
# Template that specifies tools; ensure chat picks them up
|
||||
(templates_path / "mytools.yaml").write_text(
|
||||
"model: echo\n" "tools:\n" "- llm_version\n" "- llm_time\n",
|
||||
"utf-8",
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
llm.cli.cli,
|
||||
["chat", "-t", "mytools"],
|
||||
input="hi\nquit\n",
|
||||
catch_exceptions=False,
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Verify a single response was logged for the conversation
|
||||
responses = list(logs_db["responses"].rows)
|
||||
assert len(responses) == 1
|
||||
assert responses[0]["prompt"] == "hi"
|
||||
response_id = responses[0]["id"]
|
||||
|
||||
# Tools from the template should be recorded against that response
|
||||
rows = list(
|
||||
logs_db.query(
|
||||
"""
|
||||
select tools.name from tools
|
||||
join tool_responses tr on tr.tool_id = tools.id
|
||||
where tr.response_id = ?
|
||||
order by tools.name
|
||||
""",
|
||||
[response_id],
|
||||
)
|
||||
)
|
||||
assert [r["name"] for r in rows] == ["llm_time", "llm_version"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue