diff --git a/docs/logging.md b/docs/logging.md index 4ef6746..9620234 100644 --- a/docs/logging.md +++ b/docs/logging.md @@ -405,7 +405,8 @@ CREATE TABLE "tool_results" ( [name] TEXT, [output] TEXT, [tool_call_id] TEXT, - [instance_id] INTEGER REFERENCES [tool_instances]([id]) + [instance_id] INTEGER REFERENCES [tool_instances]([id]), + [exception] TEXT ); CREATE TABLE [tool_instances] ( [id] INTEGER PRIMARY KEY, diff --git a/docs/python-api.md b/docs/python-api.md index 85e7c8d..ccb724d 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -154,15 +154,18 @@ for response in chain.responses(): Pass a function to the `before_call=` parameter of `model.chain()` to have that called before every tool call is executed. You can raise `llm.CancelToolCall()` to cancel that tool call. -The method signature is `def before_call(tool: llm.Tool, tool_call: llm.ToolCall)`. Here's an example: +The method signature is `def before_call(tool: Optional[llm.Tool], tool_call: llm.ToolCall)` - that first `tool` argument can be `None` if the model requests a tool be executed that has not been provided in the `tools=` list. + +Here's an example: ```python import llm +from typing import Optional def upper(text: str) -> str: "Convert text to uppercase." 
return text.upper() -def before_call(tool: llm.Tool, tool_call: llm.ToolCall): - print(f"About to call tool {tool.name} with arguments {tool_call.arguments}") - if tool.name == "upper" and "bad" in repr(tool_call.arguments): +def before_call(tool: Optional[llm.Tool], tool_call: llm.ToolCall): + print(f"About to call tool {tool_call.name} with arguments {tool_call.arguments}") + if tool_call.name == "upper" and "bad" in repr(tool_call.arguments): raise llm.CancelToolCall("Not allowed to call upper on text containing 'bad'") diff --git a/llm/cli.py b/llm/cli.py index a97adde..4912b4b 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1858,6 +1858,7 @@ def logs_list( 'name', tr.name, 'output', tr.output, 'tool_call_id', tr.tool_call_id, + 'exception', tr.exception, 'attachments', COALESCE( (SELECT json_group_array(json_object( 'id', a.id, @@ -2095,10 +2096,17 @@ def logs_list( desc += f"<{attachment['content_length']:,} bytes>" attachments += "\n - {}".format(desc) click.echo( - "- **{}**: `{}`
\n{}{}".format( + "- **{}**: `{}`
\n{}{}{}".format( tool_result["name"], tool_result["tool_call_id"], textwrap.indent(tool_result["output"], " "), + ( + "
\n **Error**: {}\n".format( + tool_result["exception"] + ) + if tool_result["exception"] + else "" + ), attachments, ) ) @@ -3927,12 +3935,21 @@ def _debug_tool_call(_, tool_call, tool_result): output += attachments click.echo( click.style( - textwrap.indent(output, " ") + "\n", + textwrap.indent(output, " ") + ("\n" if not tool_result.exception else ""), fg="green", bold=True, ), err=True, ) + if tool_result.exception: + click.echo( + click.style( + " Exception: {}".format(tool_result.exception), + fg="red", + bold=True, + ), + err=True, + ) def _approve_tool_call(_, tool_call): diff --git a/llm/migrations.py b/llm/migrations.py index 54cb777..f2ca046 100644 --- a/llm/migrations.py +++ b/llm/migrations.py @@ -413,3 +413,8 @@ def m020_tool_results_attachments(db): ), pk=("tool_result_id", "attachment_id"), ) + + +@migration +def m021_tool_results_exception(db): + db["tool_results"].add_column("exception", str) diff --git a/llm/models.py b/llm/models.py index 0842503..9afbe97 100644 --- a/llm/models.py +++ b/llm/models.py @@ -273,6 +273,7 @@ class ToolResult: attachments: List[Attachment] = field(default_factory=list) tool_call_id: Optional[str] = None instance: Optional[Toolbox] = None + exception: Optional[Exception] = None @dataclass @@ -284,9 +285,9 @@ class ToolOutput: ToolDef = Union[Tool, Toolbox, Callable[..., Any]] -BeforeCallSync = Callable[[Tool, ToolCall], None] +BeforeCallSync = Callable[[Optional[Tool], ToolCall], None] AfterCallSync = Callable[[Tool, ToolCall, ToolResult], None] -BeforeCallAsync = Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] +BeforeCallAsync = Callable[[Optional[Tool], ToolCall], Union[None, Awaitable[None]]] AfterCallAsync = Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] @@ -925,6 +926,16 @@ class _BaseResponse: "output": tool_result.output, "tool_call_id": tool_result.tool_call_id, "instance_id": instance_id, + "exception": ( + ( + "{}: {}".format( + tool_result.exception.__class__.__name__, + 
str(tool_result.exception), + ) + ) + if tool_result.exception + else None + ), } ) .last_pk @@ -989,16 +1000,8 @@ class Response(_BaseResponse): tools_by_name = {tool.name: tool for tool in self.prompt.tools} for tool_call in self.tool_calls(): tool = tools_by_name.get(tool_call.name) - if tool is None: - tool_results.append( - ToolResult( - name=tool_call.name, - output='Error: tool "{}" does not exist'.format(tool_call.name), - tool_call_id=tool_call.tool_call_id, - ) - ) - continue - + # Tool could be None if the tool was not found in the prompt tools, + # but we still call the before_call method: if before_call: cb_result = before_call(tool, tool_call) if inspect.isawaitable(cb_result): @@ -1007,12 +1010,25 @@ class Response(_BaseResponse): "Please use an async chain/response or a synchronous callback." ) + if tool is None: + msg = 'tool "{}" does not exist'.format(tool_call.name) + tool_results.append( + ToolResult( + name=tool_call.name, + output="Error: " + msg, + tool_call_id=tool_call.tool_call_id, + exception=KeyError(msg), + ) + ) + continue + if not tool.implementation: raise ValueError( "No implementation available for tool: {}".format(tool_call.name) ) attachments = [] + exception = None try: if asyncio.iscoroutinefunction(tool.implementation): @@ -1028,6 +1044,7 @@ class Response(_BaseResponse): result = json.dumps(result, default=repr) except Exception as ex: result = f"Error: {ex}" + exception = ex tool_result_obj = ToolResult( name=tool_call.name, @@ -1035,6 +1052,7 @@ class Response(_BaseResponse): attachments=attachments, tool_call_id=tool_call.tool_call_id, instance=_get_instance(tool.implementation), + exception=exception, ) if after_call: @@ -1161,13 +1179,15 @@ class AsyncResponse(_BaseResponse): for idx, tc in enumerate(tool_calls_list): tool = tools_by_name.get(tc.name) - if tool is None: - raise CancelToolCall(f"Unknown tool: {tc.name}") - if not tool.implementation: - raise CancelToolCall(f"No implementation for tool: {tc.name}") + 
exception: Optional[Exception] = None - # If it's an async implementation, wrap it - if inspect.iscoroutinefunction(tool.implementation): + if tool is None: + output = f'Error: tool "{tc.name}" does not exist' + exception = KeyError(tc.name) + elif not tool.implementation: + output = f'Error: tool "{tc.name}" has no implementation' + exception = KeyError(tc.name) + elif inspect.iscoroutinefunction(tool.implementation): async def run_async(tc=tc, tool=tool, idx=idx): # before_call inside the task @@ -1176,7 +1196,9 @@ class AsyncResponse(_BaseResponse): if inspect.isawaitable(cb): await cb + exception = None attachments = [] + try: result = await tool.implementation(**tc.arguments) if isinstance(result, ToolOutput): @@ -1189,6 +1211,7 @@ class AsyncResponse(_BaseResponse): ) except Exception as ex: output = f"Error: {ex}" + exception = ex tr = ToolResult( name=tc.name, @@ -1196,10 +1219,11 @@ class AsyncResponse(_BaseResponse): attachments=attachments, tool_call_id=tc.tool_call_id, instance=_get_instance(tool.implementation), + exception=exception, ) # after_call inside the task - if after_call: + if tool is not None and after_call: cb2 = after_call(tool, tc, tr) if inspect.isawaitable(cb2): await cb2 @@ -1215,34 +1239,44 @@ class AsyncResponse(_BaseResponse): if inspect.isawaitable(cb): await cb + exception = None attachments = [] - try: - res = tool.implementation(**tc.arguments) - if inspect.isawaitable(res): - res = await res - if isinstance(res, ToolOutput): - attachments.extend(res.attachments) - res = res.output - output = ( - res if isinstance(res, str) else json.dumps(res, default=repr) + + if tool is None: + output = f'Error: tool "{tc.name}" does not exist' + exception = KeyError(tc.name) + else: + try: + res = tool.implementation(**tc.arguments) + if inspect.isawaitable(res): + res = await res + if isinstance(res, ToolOutput): + attachments.extend(res.attachments) + res = res.output + output = ( + res + if isinstance(res, str) + else json.dumps(res, 
default=repr) + ) + except Exception as ex: + output = f"Error: {ex}" + exception = ex + + tr = ToolResult( + name=tc.name, + output=output, + attachments=attachments, + tool_call_id=tc.tool_call_id, + instance=_get_instance(tool.implementation), + exception=exception, ) - except Exception as ex: - output = f"Error: {ex}" - tr = ToolResult( - name=tc.name, - output=output, - attachments=attachments, - tool_call_id=tc.tool_call_id, - instance=_get_instance(tool.implementation), - ) + if tool is not None and after_call: + cb2 = after_call(tool, tc, tr) + if inspect.isawaitable(cb2): + await cb2 - if after_call: - cb2 = after_call(tool, tc, tr) - if inspect.isawaitable(cb2): - await cb2 - - indexed_results.append((idx, tr)) + indexed_results.append((idx, tr)) # Await all async tasks in parallel if async_tasks: diff --git a/tests/test_plugins.py b/tests/test_plugins.py index c47495b..d232279 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -459,12 +459,12 @@ def test_register_tools(tmpdir, logs_db): ('{"tool_calls": [{"name": "upper", "arguments": {"text": "one"}}]}', "[]"), ( "", - '[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null, "attachments": []}]', + '[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null, "exception": null, "attachments": []}]', ), ('{"tool_calls": [{"name": "upper", "arguments": {"text": "two"}}]}', "[]"), ( "", - '[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null, "attachments": []}]', + '[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null, "exception": null, "attachments": []}]', ), ( '{"tool_calls": [{"name": "upper", "arguments": {"text": "three"}}]}', @@ -472,7 +472,7 @@ def test_register_tools(tmpdir, logs_db): ), ( "", - '[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null, "attachments": []}]', + '[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null, 
"exception": null, "attachments": []}]', ), ) # Test the --td option diff --git a/tests/test_tools.py b/tests/test_tools.py index 293d2db..e803a49 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -385,3 +385,50 @@ async def test_tool_conversation_settings_async(): await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text() assert len(before_collected) == 2 assert len(after_collected) == 2 + + +ERROR_FUNCTION = """ +def trigger_error(msg: str): + raise Exception(msg) +""" + + +@pytest.mark.parametrize("async_", [False]) # Add True again +def test_tool_errors(async_): + # https://github.com/simonw/llm/issues/1107 + runner = CliRunner() + result = runner.invoke( + cli.cli, + ( + [ + "-m", + "echo", + "--functions", + ERROR_FUNCTION, + json.dumps( + { + "tool_calls": [ + {"name": "trigger_error", "arguments": {"msg": "Error!"}} + ] + } + ), + ] + + (["--async"] if async_ else []) + ), + ) + assert result.exit_code == 0 + assert '"output": "Error: Error!"' in result.output + # llm logs --json output + log_json_result = runner.invoke(cli.cli, ["logs", "--json", "-c"]) + assert log_json_result.exit_code == 0 + log_data = json.loads(log_json_result.output) + assert len(log_data) == 2 + assert log_data[1]["tool_results"][0]["exception"] == "Exception: Error!" + # llm logs -c output + log_text_result = runner.invoke(cli.cli, ["logs", "-c"]) + assert log_text_result.exit_code == 0 + assert ( + "- **trigger_error**: `None`
\n" + " Error: Error!
\n" + " **Error**: Exception: Error!\n" + ) in log_text_result.output