From a31ae86c204692180d8866473142a9d03b9016e3 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 20 May 2025 20:23:48 -0700 Subject: [PATCH] Refactor in prep for chat tools, refs #1004 --- llm/cli.py | 102 ++++++++++++++++++++++++++++------------------------- 1 file changed, 54 insertions(+), 48 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index e659094..8ebe102 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -766,60 +766,17 @@ def prompt( if conversation: prompt_method = conversation.prompt - extra_tools = [] - if python_tools: - for code_or_path in python_tools: - extra_tools = _tools_from_code(code_or_path) + tool_functions = _gather_tools(tools, python_tools) - if tools or python_tools: + if tool_functions: prompt_method = conversation.chain kwargs["chain_limit"] = chain_limit if tools_debug: - - def debug_tool_call(_, tool_call, tool_result): - click.echo( - click.style( - "Tool call: {}({})".format(tool_call.name, tool_call.arguments), - fg="yellow", - bold=True, - ), - err=True, - ) - click.echo( - click.style( - " {}".format(tool_result.output), - fg="green", - bold=True, - ), - err=True, - ) - - kwargs["after_call"] = debug_tool_call + kwargs["after_call"] = _debug_tool_call if tools_approve: + kwargs["before_call"] = _approve_tool_call + kwargs["tools"] = tool_functions - def approve_tool_call(_, tool_call): - click.echo( - click.style( - "Tool call: {}({})".format(tool_call.name, tool_call.arguments), - fg="yellow", - bold=True, - ), - err=True, - ) - if not click.confirm("Approve tool call?"): - raise CancelToolCall("User cancelled tool call") - - kwargs["before_call"] = approve_tool_call - # Look up all those tools - registered_tools = get_tools() - bad_tools = [tool for tool in tools if tool not in registered_tools] - if bad_tools: - raise click.ClickException( - "Tool(s) {} not found. Available tools: {}".format( - ", ".join(bad_tools), ", ".join(registered_tools.keys()) - ) - ) - kwargs["tools"] = [registered_tools[tool] for tool in tools] + extra_tools try: if async_: @@ -3616,3 +3573,52 @@ def _tools_from_code(code_or_path: str) -> List[Tool]: if callable(value) and not name.startswith("_"): tools.append(Tool.function(value)) return tools + + +def _debug_tool_call(_, tool_call, tool_result): + click.echo( + click.style( + "Tool call: {}({})".format(tool_call.name, tool_call.arguments), + fg="yellow", + bold=True, + ), + err=True, + ) + click.echo( + click.style( + " {}".format(tool_result.output), + fg="green", + bold=True, + ), + err=True, + ) + + +def _approve_tool_call(_, tool_call): + click.echo( + click.style( + "Tool call: {}({})".format(tool_call.name, tool_call.arguments), + fg="yellow", + bold=True, + ), + err=True, + ) + if not click.confirm("Approve tool call?"): + raise CancelToolCall("User cancelled tool call") + + +def _gather_tools(tools, python_tools): + tool_functions = [] + if python_tools: + for code_or_path in python_tools: + tool_functions = _tools_from_code(code_or_path) + registered_tools = get_tools() + bad_tools = [tool for tool in tools if tool not in registered_tools] + if bad_tools: + raise click.ClickException( + "Tool(s) {} not found. Available tools: {}".format( + ", ".join(bad_tools), ", ".join(registered_tools.keys()) + ) + ) + tool_functions.extend(registered_tools[tool] for tool in tools) + return tool_functions