diff --git a/docs/python-api.md b/docs/python-api.md index ccb724d..7809555 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -178,6 +178,8 @@ response = model.chain( ) print(response.text()) ``` +If you raise `llm.CancelToolCall` in the `before_call` function the model will be informed that the tool call was cancelled. + The `after_call=` parameter can be used to run a logging function after each tool call has been executed. The method signature is `def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult)`. This continues the previous example: ```python def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult): diff --git a/llm/models.py b/llm/models.py index bebe0a8..c0ad882 100644 --- a/llm/models.py +++ b/llm/models.py @@ -1013,12 +1013,23 @@ class Response(_BaseResponse): # Tool could be None if the tool was not found in the prompt tools, # but we still call the before_call method: if before_call: - cb_result = before_call(tool, tool_call) - if inspect.isawaitable(cb_result): - raise TypeError( - "Asynchronous 'before_call' callback provided to a synchronous tool execution context. " - "Please use an async chain/response or a synchronous callback." + try: + cb_result = before_call(tool, tool_call) + if inspect.isawaitable(cb_result): + raise TypeError( + "Asynchronous 'before_call' callback provided to a synchronous tool execution context. " + "Please use an async chain/response or a synchronous callback." + ) + except CancelToolCall as ex: + tool_results.append( + ToolResult( + name=tool_call.name, + output="Cancelled: " + str(ex), + tool_call_id=tool_call.tool_call_id, + exception=ex, + ) ) + continue if tool is None: msg = 'tool "{}" does not exist'.format(tool_call.name) @@ -1202,9 +1213,17 @@ class AsyncResponse(_BaseResponse): async def run_async(tc=tc, tool=tool, idx=idx): # before_call inside the task if before_call: - cb = before_call(tool, tc) - if inspect.isawaitable(cb): - await cb + try: + cb = before_call(tool, tc) + if inspect.isawaitable(cb): + await cb + except CancelToolCall as ex: + return idx, ToolResult( + name=tc.name, + output="Cancelled: " + str(ex), + tool_call_id=tc.tool_call_id, + exception=ex, + ) exception = None attachments = [] @@ -1245,9 +1264,23 @@ class AsyncResponse(_BaseResponse): else: # Sync implementation: do hooks and call inline if before_call: - cb = before_call(tool, tc) - if inspect.isawaitable(cb): - await cb + try: + cb = before_call(tool, tc) + if inspect.isawaitable(cb): + await cb + except CancelToolCall as ex: + indexed_results.append( + ( + idx, + ToolResult( + name=tc.name, + output="Cancelled: " + str(ex), + tool_call_id=tc.tool_call_id, + exception=ex, + ), + ) + ) + continue exception = None attachments = [] diff --git a/tests/test_tools.py b/tests/test_tools.py index 6fd7070..ad144aa 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -3,7 +3,7 @@ from click.testing import CliRunner from importlib.metadata import version import json import llm -from llm import cli +from llm import cli, CancelToolCall from llm.migrations import migrate from llm.tools import llm_time import os @@ -432,3 +432,78 @@ def test_tool_errors(async_): " Error: Error!
\n" " **Error**: Exception: Error!\n" ) in log_text_result.output + + +def test_chain_sync_cancel_only_first_of_two(): + model = llm.get_model("echo") + + def t1() -> str: + return "ran1" + + def t2() -> str: + return "ran2" + + def before(tool, tool_call): + if tool.name == "t1": + raise CancelToolCall("skip1") + # allow t2 + return None + + calls = [ + {"name": "t1"}, + {"name": "t2"}, + ] + payload = json.dumps({"tool_calls": calls}) + chain = model.chain(payload, tools=[t1, t2], before_call=before) + _ = chain.text() + + # second response has two results + second = chain._responses[1] + results = second.prompt.tool_results + assert len(results) == 2 + + # first cancelled, second executed + assert results[0].name == "t1" + assert results[0].output == "Cancelled: skip1" + assert isinstance(results[0].exception, CancelToolCall) + + assert results[1].name == "t2" + assert results[1].output == "ran2" + assert results[1].exception is None + + +# 2c async equivalent +@pytest.mark.asyncio +async def test_chain_async_cancel_only_first_of_two(): + async_model = llm.get_async_model("echo") + + def t1() -> str: + return "ran1" + + async def t2() -> str: + return "ran2" + + async def before(tool, tool_call): + if tool.name == "t1": + raise CancelToolCall("skip1") + return None + + calls = [ + {"name": "t1"}, + {"name": "t2"}, + ] + payload = json.dumps({"tool_calls": calls}) + chain = async_model.chain(payload, tools=[t1, t2], before_call=before) + _ = await chain.text() + + second = chain._responses[1] + results = second.prompt.tool_results + assert len(results) == 2 + + assert results[0].name == "t1" + assert results[0].output == "Cancelled: skip1" + assert isinstance(results[0].exception, CancelToolCall) + + assert results[1].name == "t2" + assert results[1].output == "ran2" + assert results[1].exception is None