mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-15 01:03:11 +00:00
parent
e6dcd414a5
commit
30e0c4abe8
7 changed files with 158 additions and 51 deletions
|
|
@ -405,7 +405,8 @@ CREATE TABLE "tool_results" (
|
|||
[name] TEXT,
|
||||
[output] TEXT,
|
||||
[tool_call_id] TEXT,
|
||||
[instance_id] INTEGER REFERENCES [tool_instances]([id])
|
||||
[instance_id] INTEGER REFERENCES [tool_instances]([id]),
|
||||
[exception] TEXT
|
||||
);
|
||||
CREATE TABLE [tool_instances] (
|
||||
[id] INTEGER PRIMARY KEY,
|
||||
|
|
|
|||
|
|
@ -154,15 +154,18 @@ for response in chain.responses():
|
|||
|
||||
Pass a function to the `before_call=` parameter of `model.chain()` to have it called before every tool call is executed. Raise `llm.CancelToolCall()` from that function to cancel the tool call.
|
||||
|
||||
The method signature is `def before_call(tool: llm.Tool, tool_call: llm.ToolCall)`. Here's an example:
|
||||
The method signature is `def before_call(tool: Optional[llm.Tool], tool_call: llm.ToolCall)`. The first argument, `tool`, will be `None` if the model requests a tool that was not provided in the `tools=` list.
|
||||
|
||||
Here's an example:
|
||||
```python
|
||||
import llm
|
||||
from typing import Optional
|
||||
|
||||
def upper(text: str) -> str:
|
||||
"Convert text to uppercase."
|
||||
return text.upper()
|
||||
|
||||
def before_call(tool: llm.Tool, tool_call: llm.ToolCall):
|
||||
def before_call(tool: Optional[llm.Tool], tool_call: llm.ToolCall):
    """Log each tool call and veto calls to ``upper`` whose arguments contain 'bad'.

    ``tool`` is ``None`` when the model requests a tool that was not passed
    in ``tools=``, so the name is taken from ``tool_call`` in that case
    instead of dereferencing ``tool``.

    Raises:
        llm.CancelToolCall: if the call targets ``upper`` and the repr of
            its arguments contains the text ``bad``.
    """
    # Fall back to the requested name so an unknown tool does not crash the hook.
    tool_name = tool.name if tool is not None else tool_call.name
    print(f"About to call tool {tool_name} with arguments {tool_call.arguments}")
    if tool_name == "upper" and "bad" in repr(tool_call.arguments):
        raise llm.CancelToolCall("Not allowed to call upper on text containing 'bad'")
|
||||
|
|
|
|||
21
llm/cli.py
21
llm/cli.py
|
|
@ -1858,6 +1858,7 @@ def logs_list(
|
|||
'name', tr.name,
|
||||
'output', tr.output,
|
||||
'tool_call_id', tr.tool_call_id,
|
||||
'exception', tr.exception,
|
||||
'attachments', COALESCE(
|
||||
(SELECT json_group_array(json_object(
|
||||
'id', a.id,
|
||||
|
|
@ -2095,10 +2096,17 @@ def logs_list(
|
|||
desc += f"<{attachment['content_length']:,} bytes>"
|
||||
attachments += "\n - {}".format(desc)
|
||||
click.echo(
|
||||
"- **{}**: `{}`<br>\n{}{}".format(
|
||||
"- **{}**: `{}`<br>\n{}{}{}".format(
|
||||
tool_result["name"],
|
||||
tool_result["tool_call_id"],
|
||||
textwrap.indent(tool_result["output"], " "),
|
||||
(
|
||||
"<br>\n **Error**: {}\n".format(
|
||||
tool_result["exception"]
|
||||
)
|
||||
if tool_result["exception"]
|
||||
else ""
|
||||
),
|
||||
attachments,
|
||||
)
|
||||
)
|
||||
|
|
@ -3927,12 +3935,21 @@ def _debug_tool_call(_, tool_call, tool_result):
|
|||
output += attachments
|
||||
click.echo(
|
||||
click.style(
|
||||
textwrap.indent(output, " ") + "\n",
|
||||
textwrap.indent(output, " ") + ("\n" if not tool_result.exception else ""),
|
||||
fg="green",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
if tool_result.exception:
|
||||
click.echo(
|
||||
click.style(
|
||||
" Exception: {}".format(tool_result.exception),
|
||||
fg="red",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
|
||||
|
||||
def _approve_tool_call(_, tool_call):
|
||||
|
|
|
|||
|
|
@ -413,3 +413,8 @@ def m020_tool_results_attachments(db):
|
|||
),
|
||||
pk=("tool_result_id", "attachment_id"),
|
||||
)
|
||||
|
||||
|
||||
@migration
def m021_tool_results_exception(db):
    """Add an ``exception`` column (declared as ``str``) to ``tool_results``."""
    tool_results = db["tool_results"]
    tool_results.add_column("exception", str)
|
||||
|
|
|
|||
120
llm/models.py
120
llm/models.py
|
|
@ -273,6 +273,7 @@ class ToolResult:
|
|||
attachments: List[Attachment] = field(default_factory=list)
|
||||
tool_call_id: Optional[str] = None
|
||||
instance: Optional[Toolbox] = None
|
||||
exception: Optional[Exception] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -284,9 +285,9 @@ class ToolOutput:
|
|||
|
||||
|
||||
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
|
||||
BeforeCallSync = Callable[[Tool, ToolCall], None]
|
||||
BeforeCallSync = Callable[[Optional[Tool], ToolCall], None]
|
||||
AfterCallSync = Callable[[Tool, ToolCall, ToolResult], None]
|
||||
BeforeCallAsync = Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
BeforeCallAsync = Callable[[Optional[Tool], ToolCall], Union[None, Awaitable[None]]]
|
||||
AfterCallAsync = Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
|
||||
|
||||
|
|
@ -925,6 +926,16 @@ class _BaseResponse:
|
|||
"output": tool_result.output,
|
||||
"tool_call_id": tool_result.tool_call_id,
|
||||
"instance_id": instance_id,
|
||||
"exception": (
|
||||
(
|
||||
"{}: {}".format(
|
||||
tool_result.exception.__class__.__name__,
|
||||
str(tool_result.exception),
|
||||
)
|
||||
)
|
||||
if tool_result.exception
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
.last_pk
|
||||
|
|
@ -989,16 +1000,8 @@ class Response(_BaseResponse):
|
|||
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
|
||||
for tool_call in self.tool_calls():
|
||||
tool = tools_by_name.get(tool_call.name)
|
||||
if tool is None:
|
||||
tool_results.append(
|
||||
ToolResult(
|
||||
name=tool_call.name,
|
||||
output='Error: tool "{}" does not exist'.format(tool_call.name),
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# Tool could be None if the tool was not found in the prompt tools,
|
||||
# but we still call the before_call method:
|
||||
if before_call:
|
||||
cb_result = before_call(tool, tool_call)
|
||||
if inspect.isawaitable(cb_result):
|
||||
|
|
@ -1007,12 +1010,25 @@ class Response(_BaseResponse):
|
|||
"Please use an async chain/response or a synchronous callback."
|
||||
)
|
||||
|
||||
if tool is None:
|
||||
msg = 'tool "{}" does not exist'.format(tool_call.name)
|
||||
tool_results.append(
|
||||
ToolResult(
|
||||
name=tool_call.name,
|
||||
output="Error: " + msg,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
exception=KeyError(msg),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if not tool.implementation:
|
||||
raise ValueError(
|
||||
"No implementation available for tool: {}".format(tool_call.name)
|
||||
)
|
||||
|
||||
attachments = []
|
||||
exception = None
|
||||
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(tool.implementation):
|
||||
|
|
@ -1028,6 +1044,7 @@ class Response(_BaseResponse):
|
|||
result = json.dumps(result, default=repr)
|
||||
except Exception as ex:
|
||||
result = f"Error: {ex}"
|
||||
exception = ex
|
||||
|
||||
tool_result_obj = ToolResult(
|
||||
name=tool_call.name,
|
||||
|
|
@ -1035,6 +1052,7 @@ class Response(_BaseResponse):
|
|||
attachments=attachments,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
instance=_get_instance(tool.implementation),
|
||||
exception=exception,
|
||||
)
|
||||
|
||||
if after_call:
|
||||
|
|
@ -1161,13 +1179,15 @@ class AsyncResponse(_BaseResponse):
|
|||
|
||||
for idx, tc in enumerate(tool_calls_list):
|
||||
tool = tools_by_name.get(tc.name)
|
||||
if tool is None:
|
||||
raise CancelToolCall(f"Unknown tool: {tc.name}")
|
||||
if not tool.implementation:
|
||||
raise CancelToolCall(f"No implementation for tool: {tc.name}")
|
||||
exception: Optional[Exception] = None
|
||||
|
||||
# If it's an async implementation, wrap it
|
||||
if inspect.iscoroutinefunction(tool.implementation):
|
||||
if tool is None:
|
||||
output = f'Error: tool "{tc.name}" does not exist'
|
||||
exception = KeyError(tc.name)
|
||||
elif not tool.implementation:
|
||||
output = f'Error: tool "{tc.name}" has no implementation'
|
||||
exception = KeyError(tc.name)
|
||||
elif inspect.iscoroutinefunction(tool.implementation):
|
||||
|
||||
async def run_async(tc=tc, tool=tool, idx=idx):
|
||||
# before_call inside the task
|
||||
|
|
@ -1176,7 +1196,9 @@ class AsyncResponse(_BaseResponse):
|
|||
if inspect.isawaitable(cb):
|
||||
await cb
|
||||
|
||||
exception = None
|
||||
attachments = []
|
||||
|
||||
try:
|
||||
result = await tool.implementation(**tc.arguments)
|
||||
if isinstance(result, ToolOutput):
|
||||
|
|
@ -1189,6 +1211,7 @@ class AsyncResponse(_BaseResponse):
|
|||
)
|
||||
except Exception as ex:
|
||||
output = f"Error: {ex}"
|
||||
exception = ex
|
||||
|
||||
tr = ToolResult(
|
||||
name=tc.name,
|
||||
|
|
@ -1196,10 +1219,11 @@ class AsyncResponse(_BaseResponse):
|
|||
attachments=attachments,
|
||||
tool_call_id=tc.tool_call_id,
|
||||
instance=_get_instance(tool.implementation),
|
||||
exception=exception,
|
||||
)
|
||||
|
||||
# after_call inside the task
|
||||
if after_call:
|
||||
if tool is not None and after_call:
|
||||
cb2 = after_call(tool, tc, tr)
|
||||
if inspect.isawaitable(cb2):
|
||||
await cb2
|
||||
|
|
@ -1215,34 +1239,44 @@ class AsyncResponse(_BaseResponse):
|
|||
if inspect.isawaitable(cb):
|
||||
await cb
|
||||
|
||||
exception = None
|
||||
attachments = []
|
||||
try:
|
||||
res = tool.implementation(**tc.arguments)
|
||||
if inspect.isawaitable(res):
|
||||
res = await res
|
||||
if isinstance(res, ToolOutput):
|
||||
attachments.extend(res.attachments)
|
||||
res = res.output
|
||||
output = (
|
||||
res if isinstance(res, str) else json.dumps(res, default=repr)
|
||||
|
||||
if tool is None:
|
||||
output = f'Error: tool "{tc.name}" does not exist'
|
||||
exception = KeyError(tc.name)
|
||||
else:
|
||||
try:
|
||||
res = tool.implementation(**tc.arguments)
|
||||
if inspect.isawaitable(res):
|
||||
res = await res
|
||||
if isinstance(res, ToolOutput):
|
||||
attachments.extend(res.attachments)
|
||||
res = res.output
|
||||
output = (
|
||||
res
|
||||
if isinstance(res, str)
|
||||
else json.dumps(res, default=repr)
|
||||
)
|
||||
except Exception as ex:
|
||||
output = f"Error: {ex}"
|
||||
exception = ex
|
||||
|
||||
tr = ToolResult(
|
||||
name=tc.name,
|
||||
output=output,
|
||||
attachments=attachments,
|
||||
tool_call_id=tc.tool_call_id,
|
||||
instance=_get_instance(tool.implementation),
|
||||
exception=exception,
|
||||
)
|
||||
except Exception as ex:
|
||||
output = f"Error: {ex}"
|
||||
|
||||
tr = ToolResult(
|
||||
name=tc.name,
|
||||
output=output,
|
||||
attachments=attachments,
|
||||
tool_call_id=tc.tool_call_id,
|
||||
instance=_get_instance(tool.implementation),
|
||||
)
|
||||
if tool is not None and after_call:
|
||||
cb2 = after_call(tool, tc, tr)
|
||||
if inspect.isawaitable(cb2):
|
||||
await cb2
|
||||
|
||||
if after_call:
|
||||
cb2 = after_call(tool, tc, tr)
|
||||
if inspect.isawaitable(cb2):
|
||||
await cb2
|
||||
|
||||
indexed_results.append((idx, tr))
|
||||
indexed_results.append((idx, tr))
|
||||
|
||||
# Await all async tasks in parallel
|
||||
if async_tasks:
|
||||
|
|
|
|||
|
|
@ -459,12 +459,12 @@ def test_register_tools(tmpdir, logs_db):
|
|||
('{"tool_calls": [{"name": "upper", "arguments": {"text": "one"}}]}', "[]"),
|
||||
(
|
||||
"",
|
||||
'[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null, "attachments": []}]',
|
||||
'[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null, "exception": null, "attachments": []}]',
|
||||
),
|
||||
('{"tool_calls": [{"name": "upper", "arguments": {"text": "two"}}]}', "[]"),
|
||||
(
|
||||
"",
|
||||
'[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null, "attachments": []}]',
|
||||
'[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null, "exception": null, "attachments": []}]',
|
||||
),
|
||||
(
|
||||
'{"tool_calls": [{"name": "upper", "arguments": {"text": "three"}}]}',
|
||||
|
|
@ -472,7 +472,7 @@ def test_register_tools(tmpdir, logs_db):
|
|||
),
|
||||
(
|
||||
"",
|
||||
'[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null, "attachments": []}]',
|
||||
'[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null, "exception": null, "attachments": []}]',
|
||||
),
|
||||
)
|
||||
# Test the --td option
|
||||
|
|
|
|||
|
|
@ -385,3 +385,50 @@ async def test_tool_conversation_settings_async():
|
|||
await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
|
||||
assert len(before_collected) == 2
|
||||
assert len(after_collected) == 2
|
||||
|
||||
|
||||
# Source code passed to the CLI via --functions: defines a tool that always
# raises, so the tests can exercise how tool exceptions are reported and logged.
# The function body must be indented for the snippet to be valid Python.
ERROR_FUNCTION = """
def trigger_error(msg: str):
    raise Exception(msg)
"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_", [False])  # Add True again
def test_tool_errors(async_):
    """Check that a raising tool's error appears in the CLI output and in the logs.

    Runs the ``echo`` model with a prompt requesting ``trigger_error``, then
    inspects both ``llm logs --json -c`` and ``llm logs -c``.
    """
    # https://github.com/simonw/llm/issues/1107
    prompt = json.dumps(
        {
            "tool_calls": [
                {"name": "trigger_error", "arguments": {"msg": "Error!"}}
            ]
        }
    )
    cli_args = ["-m", "echo", "--functions", ERROR_FUNCTION, prompt]
    if async_:
        cli_args.append("--async")

    runner = CliRunner()
    result = runner.invoke(cli.cli, cli_args)
    assert result.exit_code == 0
    assert '"output": "Error: Error!"' in result.output

    # llm logs --json output
    log_json_result = runner.invoke(cli.cli, ["logs", "--json", "-c"])
    assert log_json_result.exit_code == 0
    log_data = json.loads(log_json_result.output)
    assert len(log_data) == 2
    first_tool_result = log_data[1]["tool_results"][0]
    assert first_tool_result["exception"] == "Exception: Error!"

    # llm logs -c output
    log_text_result = runner.invoke(cli.cli, ["logs", "-c"])
    assert log_text_result.exit_code == 0
    expected_markdown = (
        "- **trigger_error**: `None`<br>\n"
        " Error: Error!<br>\n"
        " **Error**: Exception: Error!\n"
    )
    assert expected_markdown in log_text_result.output
|
||||
|
|
|
|||
Loading…
Reference in a new issue