mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-27 22:44:01 +00:00
Refactor in prep for chat tools, refs #1004
This commit is contained in:
parent
a3414ed15d
commit
a31ae86c20
1 changed files with 54 additions and 48 deletions
102
llm/cli.py
102
llm/cli.py
|
|
@ -766,60 +766,17 @@ def prompt(
|
|||
if conversation:
|
||||
prompt_method = conversation.prompt
|
||||
|
||||
extra_tools = []
|
||||
if python_tools:
|
||||
for code_or_path in python_tools:
|
||||
extra_tools = _tools_from_code(code_or_path)
|
||||
tool_functions = _gather_tools(tools, python_tools)
|
||||
|
||||
if tools or python_tools:
|
||||
if tool_functions:
|
||||
prompt_method = conversation.chain
|
||||
kwargs["chain_limit"] = chain_limit
|
||||
if tools_debug:
|
||||
|
||||
def debug_tool_call(_, tool_call, tool_result):
|
||||
click.echo(
|
||||
click.style(
|
||||
"Tool call: {}({})".format(tool_call.name, tool_call.arguments),
|
||||
fg="yellow",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
" {}".format(tool_result.output),
|
||||
fg="green",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
|
||||
kwargs["after_call"] = debug_tool_call
|
||||
kwargs["after_call"] = _debug_tool_call
|
||||
if tools_approve:
|
||||
kwargs["before_call"] = _approve_tool_call
|
||||
kwargs["tools"] = tool_functions
|
||||
|
||||
def approve_tool_call(_, tool_call):
|
||||
click.echo(
|
||||
click.style(
|
||||
"Tool call: {}({})".format(tool_call.name, tool_call.arguments),
|
||||
fg="yellow",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
if not click.confirm("Approve tool call?"):
|
||||
raise CancelToolCall("User cancelled tool call")
|
||||
|
||||
kwargs["before_call"] = approve_tool_call
|
||||
# Look up all those tools
|
||||
registered_tools = get_tools()
|
||||
bad_tools = [tool for tool in tools if tool not in registered_tools]
|
||||
if bad_tools:
|
||||
raise click.ClickException(
|
||||
"Tool(s) {} not found. Available tools: {}".format(
|
||||
", ".join(bad_tools), ", ".join(registered_tools.keys())
|
||||
)
|
||||
)
|
||||
kwargs["tools"] = [registered_tools[tool] for tool in tools] + extra_tools
|
||||
try:
|
||||
if async_:
|
||||
|
||||
|
|
@ -3616,3 +3573,52 @@ def _tools_from_code(code_or_path: str) -> List[Tool]:
|
|||
if callable(value) and not name.startswith("_"):
|
||||
tools.append(Tool.function(value))
|
||||
return tools
|
||||
|
||||
|
||||
def _debug_tool_call(_, tool_call, tool_result):
|
||||
click.echo(
|
||||
click.style(
|
||||
"Tool call: {}({})".format(tool_call.name, tool_call.arguments),
|
||||
fg="yellow",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
" {}".format(tool_result.output),
|
||||
fg="green",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
|
||||
|
||||
def _approve_tool_call(_, tool_call):
|
||||
click.echo(
|
||||
click.style(
|
||||
"Tool call: {}({})".format(tool_call.name, tool_call.arguments),
|
||||
fg="yellow",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
if not click.confirm("Approve tool call?"):
|
||||
raise CancelToolCall("User cancelled tool call")
|
||||
|
||||
|
||||
def _gather_tools(tools, python_tools):
|
||||
tool_functions = []
|
||||
if python_tools:
|
||||
for code_or_path in python_tools:
|
||||
tool_functions = _tools_from_code(code_or_path)
|
||||
registered_tools = get_tools()
|
||||
bad_tools = [tool for tool in tools if tool not in registered_tools]
|
||||
if bad_tools:
|
||||
raise click.ClickException(
|
||||
"Tool(s) {} not found. Available tools: {}".format(
|
||||
", ".join(bad_tools), ", ".join(registered_tools.keys())
|
||||
)
|
||||
)
|
||||
tool_functions.extend(registered_tools[tool] for tool in tools)
|
||||
return tool_functions
|
||||
|
|
|
|||
Loading…
Reference in a new issue