Refactor in prep for chat tools, refs #1004

This commit is contained in:
Simon Willison 2025-05-20 20:23:48 -07:00
parent a3414ed15d
commit a31ae86c20

View file

@ -766,60 +766,17 @@ def prompt(
if conversation:
prompt_method = conversation.prompt
extra_tools = []
if python_tools:
for code_or_path in python_tools:
extra_tools = _tools_from_code(code_or_path)
tool_functions = _gather_tools(tools, python_tools)
if tools or python_tools:
if tool_functions:
prompt_method = conversation.chain
kwargs["chain_limit"] = chain_limit
if tools_debug:
def debug_tool_call(_, tool_call, tool_result):
click.echo(
click.style(
"Tool call: {}({})".format(tool_call.name, tool_call.arguments),
fg="yellow",
bold=True,
),
err=True,
)
click.echo(
click.style(
" {}".format(tool_result.output),
fg="green",
bold=True,
),
err=True,
)
kwargs["after_call"] = debug_tool_call
kwargs["after_call"] = _debug_tool_call
if tools_approve:
kwargs["before_call"] = _approve_tool_call
kwargs["tools"] = tool_functions
def approve_tool_call(_, tool_call):
click.echo(
click.style(
"Tool call: {}({})".format(tool_call.name, tool_call.arguments),
fg="yellow",
bold=True,
),
err=True,
)
if not click.confirm("Approve tool call?"):
raise CancelToolCall("User cancelled tool call")
kwargs["before_call"] = approve_tool_call
# Look up all those tools
registered_tools = get_tools()
bad_tools = [tool for tool in tools if tool not in registered_tools]
if bad_tools:
raise click.ClickException(
"Tool(s) {} not found. Available tools: {}".format(
", ".join(bad_tools), ", ".join(registered_tools.keys())
)
)
kwargs["tools"] = [registered_tools[tool] for tool in tools] + extra_tools
try:
if async_:
@ -3616,3 +3573,52 @@ def _tools_from_code(code_or_path: str) -> List[Tool]:
if callable(value) and not name.startswith("_"):
tools.append(Tool.function(value))
return tools
def _debug_tool_call(_, tool_call, tool_result):
click.echo(
click.style(
"Tool call: {}({})".format(tool_call.name, tool_call.arguments),
fg="yellow",
bold=True,
),
err=True,
)
click.echo(
click.style(
" {}".format(tool_result.output),
fg="green",
bold=True,
),
err=True,
)
def _approve_tool_call(_, tool_call):
click.echo(
click.style(
"Tool call: {}({})".format(tool_call.name, tool_call.arguments),
fg="yellow",
bold=True,
),
err=True,
)
if not click.confirm("Approve tool call?"):
raise CancelToolCall("User cancelled tool call")
def _gather_tools(tools, python_tools):
tool_functions = []
if python_tools:
for code_or_path in python_tools:
tool_functions = _tools_from_code(code_or_path)
registered_tools = get_tools()
bad_tools = [tool for tool in tools if tool not in registered_tools]
if bad_tools:
raise click.ClickException(
"Tool(s) {} not found. Available tools: {}".format(
", ".join(bad_tools), ", ".join(registered_tools.keys())
)
)
tool_functions.extend(registered_tools[tool] for tool in tools)
return tool_functions