diff --git a/docs/logging.md b/docs/logging.md
index 4ef6746..9620234 100644
--- a/docs/logging.md
+++ b/docs/logging.md
@@ -405,7 +405,8 @@ CREATE TABLE "tool_results" (
[name] TEXT,
[output] TEXT,
[tool_call_id] TEXT,
- [instance_id] INTEGER REFERENCES [tool_instances]([id])
+ [instance_id] INTEGER REFERENCES [tool_instances]([id]),
+ [exception] TEXT
);
CREATE TABLE [tool_instances] (
[id] INTEGER PRIMARY KEY,
diff --git a/docs/python-api.md b/docs/python-api.md
index 85e7c8d..ccb724d 100644
--- a/docs/python-api.md
+++ b/docs/python-api.md
@@ -154,15 +154,18 @@ for response in chain.responses():
Pass a function to the `before_call=` parameter of `model.chain()` to have that called before every tool call is executed. You can raise `llm.CancelToolCall()` to cancel that tool call.
-The method signature is `def before_call(tool: llm.Tool, tool_call: llm.ToolCall)`. Here's an example:
+The method signature is `def before_call(tool: Optional[llm.Tool], tool_call: llm.ToolCall)` - that first `tool` argument can be `None` if the model requests a tool be executed that has not been provided in the `tools=` list.
+
+Here's an example:
```python
import llm
+from typing import Optional
def upper(text: str) -> str:
"Convert text to uppercase."
return text.upper()
-def before_call(tool: llm.Tool, tool_call: llm.ToolCall):
+def before_call(tool: Optional[llm.Tool], tool_call: llm.ToolCall):
print(f"About to call tool {tool.name} with arguments {tool_call.arguments}")
if tool.name == "upper" and "bad" in repr(tool_call.arguments):
raise llm.CancelToolCall("Not allowed to call upper on text containing 'bad'")
diff --git a/llm/cli.py b/llm/cli.py
index a97adde..4912b4b 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -1858,6 +1858,7 @@ def logs_list(
'name', tr.name,
'output', tr.output,
'tool_call_id', tr.tool_call_id,
+ 'exception', tr.exception,
'attachments', COALESCE(
(SELECT json_group_array(json_object(
'id', a.id,
@@ -2095,10 +2096,17 @@ def logs_list(
desc += f"<{attachment['content_length']:,} bytes>"
attachments += "\n - {}".format(desc)
click.echo(
- "- **{}**: `{}`
\n{}{}".format(
+ "- **{}**: `{}`
\n{}{}{}".format(
tool_result["name"],
tool_result["tool_call_id"],
textwrap.indent(tool_result["output"], " "),
+ (
+ "
\n **Error**: {}\n".format(
+ tool_result["exception"]
+ )
+ if tool_result["exception"]
+ else ""
+ ),
attachments,
)
)
@@ -3927,12 +3935,21 @@ def _debug_tool_call(_, tool_call, tool_result):
output += attachments
click.echo(
click.style(
- textwrap.indent(output, " ") + "\n",
+ textwrap.indent(output, " ") + ("\n" if not tool_result.exception else ""),
fg="green",
bold=True,
),
err=True,
)
+ if tool_result.exception:
+ click.echo(
+ click.style(
+ " Exception: {}".format(tool_result.exception),
+ fg="red",
+ bold=True,
+ ),
+ err=True,
+ )
def _approve_tool_call(_, tool_call):
diff --git a/llm/migrations.py b/llm/migrations.py
index 54cb777..f2ca046 100644
--- a/llm/migrations.py
+++ b/llm/migrations.py
@@ -413,3 +413,8 @@ def m020_tool_results_attachments(db):
),
pk=("tool_result_id", "attachment_id"),
)
+
+
+@migration
+def m021_tool_results_exception(db):
+ db["tool_results"].add_column("exception", str)
diff --git a/llm/models.py b/llm/models.py
index 0842503..9afbe97 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -273,6 +273,7 @@ class ToolResult:
attachments: List[Attachment] = field(default_factory=list)
tool_call_id: Optional[str] = None
instance: Optional[Toolbox] = None
+ exception: Optional[Exception] = None
@dataclass
@@ -284,9 +285,9 @@ class ToolOutput:
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
-BeforeCallSync = Callable[[Tool, ToolCall], None]
+BeforeCallSync = Callable[[Optional[Tool], ToolCall], None]
AfterCallSync = Callable[[Tool, ToolCall, ToolResult], None]
-BeforeCallAsync = Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
+BeforeCallAsync = Callable[[Optional[Tool], ToolCall], Union[None, Awaitable[None]]]
AfterCallAsync = Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
@@ -925,6 +926,16 @@ class _BaseResponse:
"output": tool_result.output,
"tool_call_id": tool_result.tool_call_id,
"instance_id": instance_id,
+ "exception": (
+ (
+ "{}: {}".format(
+ tool_result.exception.__class__.__name__,
+ str(tool_result.exception),
+ )
+ )
+ if tool_result.exception
+ else None
+ ),
}
)
.last_pk
@@ -989,16 +1000,8 @@ class Response(_BaseResponse):
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
for tool_call in self.tool_calls():
tool = tools_by_name.get(tool_call.name)
- if tool is None:
- tool_results.append(
- ToolResult(
- name=tool_call.name,
- output='Error: tool "{}" does not exist'.format(tool_call.name),
- tool_call_id=tool_call.tool_call_id,
- )
- )
- continue
-
+ # Tool could be None if the tool was not found in the prompt tools,
+ # but we still call the before_call method:
if before_call:
cb_result = before_call(tool, tool_call)
if inspect.isawaitable(cb_result):
@@ -1007,12 +1010,25 @@ class Response(_BaseResponse):
"Please use an async chain/response or a synchronous callback."
)
+ if tool is None:
+ msg = 'tool "{}" does not exist'.format(tool_call.name)
+ tool_results.append(
+ ToolResult(
+ name=tool_call.name,
+ output="Error: " + msg,
+ tool_call_id=tool_call.tool_call_id,
+ exception=KeyError(msg),
+ )
+ )
+ continue
+
if not tool.implementation:
raise ValueError(
"No implementation available for tool: {}".format(tool_call.name)
)
attachments = []
+ exception = None
try:
if asyncio.iscoroutinefunction(tool.implementation):
@@ -1028,6 +1044,7 @@ class Response(_BaseResponse):
result = json.dumps(result, default=repr)
except Exception as ex:
result = f"Error: {ex}"
+ exception = ex
tool_result_obj = ToolResult(
name=tool_call.name,
@@ -1035,6 +1052,7 @@ class Response(_BaseResponse):
attachments=attachments,
tool_call_id=tool_call.tool_call_id,
instance=_get_instance(tool.implementation),
+ exception=exception,
)
if after_call:
@@ -1161,13 +1179,15 @@ class AsyncResponse(_BaseResponse):
for idx, tc in enumerate(tool_calls_list):
tool = tools_by_name.get(tc.name)
- if tool is None:
- raise CancelToolCall(f"Unknown tool: {tc.name}")
- if not tool.implementation:
- raise CancelToolCall(f"No implementation for tool: {tc.name}")
+ exception: Optional[Exception] = None
- # If it's an async implementation, wrap it
- if inspect.iscoroutinefunction(tool.implementation):
+ if tool is None:
+ output = f'Error: tool "{tc.name}" does not exist'
+ exception = KeyError(tc.name)
+ elif not tool.implementation:
+ output = f'Error: tool "{tc.name}" has no implementation'
+ exception = KeyError(tc.name)
+ elif inspect.iscoroutinefunction(tool.implementation):
async def run_async(tc=tc, tool=tool, idx=idx):
# before_call inside the task
@@ -1176,7 +1196,9 @@ class AsyncResponse(_BaseResponse):
if inspect.isawaitable(cb):
await cb
+ exception = None
attachments = []
+
try:
result = await tool.implementation(**tc.arguments)
if isinstance(result, ToolOutput):
@@ -1189,6 +1211,7 @@ class AsyncResponse(_BaseResponse):
)
except Exception as ex:
output = f"Error: {ex}"
+ exception = ex
tr = ToolResult(
name=tc.name,
@@ -1196,10 +1219,11 @@ class AsyncResponse(_BaseResponse):
attachments=attachments,
tool_call_id=tc.tool_call_id,
instance=_get_instance(tool.implementation),
+ exception=exception,
)
# after_call inside the task
- if after_call:
+ if tool is not None and after_call:
cb2 = after_call(tool, tc, tr)
if inspect.isawaitable(cb2):
await cb2
@@ -1215,34 +1239,44 @@ class AsyncResponse(_BaseResponse):
if inspect.isawaitable(cb):
await cb
+ exception = None
attachments = []
- try:
- res = tool.implementation(**tc.arguments)
- if inspect.isawaitable(res):
- res = await res
- if isinstance(res, ToolOutput):
- attachments.extend(res.attachments)
- res = res.output
- output = (
- res if isinstance(res, str) else json.dumps(res, default=repr)
+
+ if tool is None:
+ output = f'Error: tool "{tc.name}" does not exist'
+ exception = KeyError(tc.name)
+ else:
+ try:
+ res = tool.implementation(**tc.arguments)
+ if inspect.isawaitable(res):
+ res = await res
+ if isinstance(res, ToolOutput):
+ attachments.extend(res.attachments)
+ res = res.output
+ output = (
+ res
+ if isinstance(res, str)
+ else json.dumps(res, default=repr)
+ )
+ except Exception as ex:
+ output = f"Error: {ex}"
+ exception = ex
+
+ tr = ToolResult(
+ name=tc.name,
+ output=output,
+ attachments=attachments,
+ tool_call_id=tc.tool_call_id,
+ instance=_get_instance(tool.implementation),
+ exception=exception,
)
- except Exception as ex:
- output = f"Error: {ex}"
- tr = ToolResult(
- name=tc.name,
- output=output,
- attachments=attachments,
- tool_call_id=tc.tool_call_id,
- instance=_get_instance(tool.implementation),
- )
+ if tool is not None and after_call:
+ cb2 = after_call(tool, tc, tr)
+ if inspect.isawaitable(cb2):
+ await cb2
- if after_call:
- cb2 = after_call(tool, tc, tr)
- if inspect.isawaitable(cb2):
- await cb2
-
- indexed_results.append((idx, tr))
+ indexed_results.append((idx, tr))
# Await all async tasks in parallel
if async_tasks:
diff --git a/tests/test_plugins.py b/tests/test_plugins.py
index c47495b..d232279 100644
--- a/tests/test_plugins.py
+++ b/tests/test_plugins.py
@@ -459,12 +459,12 @@ def test_register_tools(tmpdir, logs_db):
('{"tool_calls": [{"name": "upper", "arguments": {"text": "one"}}]}', "[]"),
(
"",
- '[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null, "attachments": []}]',
+ '[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null, "exception": null, "attachments": []}]',
),
('{"tool_calls": [{"name": "upper", "arguments": {"text": "two"}}]}', "[]"),
(
"",
- '[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null, "attachments": []}]',
+ '[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null, "exception": null, "attachments": []}]',
),
(
'{"tool_calls": [{"name": "upper", "arguments": {"text": "three"}}]}',
@@ -472,7 +472,7 @@ def test_register_tools(tmpdir, logs_db):
),
(
"",
- '[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null, "attachments": []}]',
+ '[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null, "exception": null, "attachments": []}]',
),
)
# Test the --td option
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 293d2db..e803a49 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -385,3 +385,50 @@ async def test_tool_conversation_settings_async():
await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
assert len(before_collected) == 2
assert len(after_collected) == 2
+
+
+ERROR_FUNCTION = """
+def trigger_error(msg: str):
+ raise Exception(msg)
+"""
+
+
+@pytest.mark.parametrize("async_", [False]) # Add True again
+def test_tool_errors(async_):
+ # https://github.com/simonw/llm/issues/1107
+ runner = CliRunner()
+ result = runner.invoke(
+ cli.cli,
+ (
+ [
+ "-m",
+ "echo",
+ "--functions",
+ ERROR_FUNCTION,
+ json.dumps(
+ {
+ "tool_calls": [
+ {"name": "trigger_error", "arguments": {"msg": "Error!"}}
+ ]
+ }
+ ),
+ ]
+ + (["--async"] if async_ else [])
+ ),
+ )
+ assert result.exit_code == 0
+ assert '"output": "Error: Error!"' in result.output
+ # llm logs --json output
+ log_json_result = runner.invoke(cli.cli, ["logs", "--json", "-c"])
+ assert log_json_result.exit_code == 0
+ log_data = json.loads(log_json_result.output)
+ assert len(log_data) == 2
+ assert log_data[1]["tool_results"][0]["exception"] == "Exception: Error!"
+ # llm logs -c output
+ log_text_result = runner.invoke(cli.cli, ["logs", "-c"])
+ assert log_text_result.exit_code == 0
+ assert (
+ "- **trigger_error**: `None`
\n"
+ " Error: Error!
\n"
+ " **Error**: Exception: Error!\n"
+ ) in log_text_result.output