diff --git a/llm/cli.py b/llm/cli.py
index ff35dca..2e11e2c 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -1068,6 +1068,10 @@ def chat(
             raise click.ClickException(str(ex))
         if model_id is None and template_obj.model:
             model_id = template_obj.model
+        if template_obj.tools:
+            tools = [*template_obj.tools, *tools]
+        if template_obj.functions and template_obj._functions_is_trusted:
+            python_tools = [template_obj.functions, *python_tools]
 
     # Figure out which model we are using
     if model_id is None:
diff --git a/tests/test_chat_templates.py b/tests/test_chat_templates.py
index 682a0bb..7687d78 100644
--- a/tests/test_chat_templates.py
+++ b/tests/test_chat_templates.py
@@ -60,3 +60,41 @@ def test_chat_system_fragments_only_first_turn(tmpdir, mock_model, logs_db):
     assert len(sys_frags) == 1
     assert sys_frags[0]["response_id"] == first_id
     assert sys_frags[0]["response_id"] != second_id
+
+
+@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
+def test_chat_template_loads_tools_into_logs(logs_db, templates_path):
+    # Write a template naming two tools; "chat -t mytools" must pick them up
+    (templates_path / "mytools.yaml").write_text(
+        "model: echo\n" "tools:\n" "- llm_version\n" "- llm_time\n",
+        "utf-8",
+    )
+
+    runner = CliRunner()
+    result = runner.invoke(
+        llm.cli.cli,
+        ["chat", "-t", "mytools"],
+        input="hi\nquit\n",
+        catch_exceptions=False,
+    )
+    assert result.exit_code == 0
+
+    # The single "hi" turn should have produced exactly one logged response
+    responses = list(logs_db["responses"].rows)
+    assert len(responses) == 1
+    assert responses[0]["prompt"] == "hi"
+    response_id = responses[0]["id"]
+
+    # Both template tools should be linked to that response via tool_responses
+    rows = list(
+        logs_db.query(
+            """
+            select tools.name from tools
+            join tool_responses tr on tr.tool_id = tools.id
+            where tr.response_id = ?
+            order by tools.name
+            """,
+            [response_id],
+        )
+    )
+    assert [r["name"] for r in rows] == ["llm_time", "llm_version"]