Better handling of before_call cancellation, closes #1148

This commit is contained in:
Simon Willison 2025-06-01 18:36:55 -07:00
parent d96ae4ed8d
commit 3a96d52895
3 changed files with 122 additions and 12 deletions

View file

@ -178,6 +178,8 @@ response = model.chain(
)
print(response.text())
```
If you raise `llm.CancelToolCall` in the `before_call` function the model will be informed that the tool call was cancelled.
The `after_call=` parameter can be used to run a logging function after each tool call has been executed. The method signature is `def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult)`. This continues the previous example:
```python
def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult):

View file

@ -1013,12 +1013,23 @@ class Response(_BaseResponse):
# Tool could be None if the tool was not found in the prompt tools,
# but we still call the before_call method:
if before_call:
cb_result = before_call(tool, tool_call)
if inspect.isawaitable(cb_result):
raise TypeError(
"Asynchronous 'before_call' callback provided to a synchronous tool execution context. "
"Please use an async chain/response or a synchronous callback."
try:
cb_result = before_call(tool, tool_call)
if inspect.isawaitable(cb_result):
raise TypeError(
"Asynchronous 'before_call' callback provided to a synchronous tool execution context. "
"Please use an async chain/response or a synchronous callback."
)
except CancelToolCall as ex:
tool_results.append(
ToolResult(
name=tool_call.name,
output="Cancelled: " + str(ex),
tool_call_id=tool_call.tool_call_id,
exception=ex,
)
)
continue
if tool is None:
msg = 'tool "{}" does not exist'.format(tool_call.name)
@ -1202,9 +1213,17 @@ class AsyncResponse(_BaseResponse):
async def run_async(tc=tc, tool=tool, idx=idx):
# before_call inside the task
if before_call:
cb = before_call(tool, tc)
if inspect.isawaitable(cb):
await cb
try:
cb = before_call(tool, tc)
if inspect.isawaitable(cb):
await cb
except CancelToolCall as ex:
return idx, ToolResult(
name=tc.name,
output="Cancelled: " + str(ex),
tool_call_id=tc.tool_call_id,
exception=ex,
)
exception = None
attachments = []
@ -1245,9 +1264,23 @@ class AsyncResponse(_BaseResponse):
else:
# Sync implementation: do hooks and call inline
if before_call:
cb = before_call(tool, tc)
if inspect.isawaitable(cb):
await cb
try:
cb = before_call(tool, tc)
if inspect.isawaitable(cb):
await cb
except CancelToolCall as ex:
indexed_results.append(
(
idx,
ToolResult(
name=tc.name,
output="Cancelled: " + str(ex),
tool_call_id=tc.tool_call_id,
exception=ex,
),
)
)
continue
exception = None
attachments = []

View file

@ -3,7 +3,7 @@ from click.testing import CliRunner
from importlib.metadata import version
import json
import llm
from llm import cli
from llm import cli, CancelToolCall
from llm.migrations import migrate
from llm.tools import llm_time
import os
@ -432,3 +432,78 @@ def test_tool_errors(async_):
" Error: Error!<br>\n"
" **Error**: Exception: Error!\n"
) in log_text_result.output
def test_chain_sync_cancel_only_first_of_two():
model = llm.get_model("echo")
def t1() -> str:
return "ran1"
def t2() -> str:
return "ran2"
def before(tool, tool_call):
if tool.name == "t1":
raise CancelToolCall("skip1")
# allow t2
return None
calls = [
{"name": "t1"},
{"name": "t2"},
]
payload = json.dumps({"tool_calls": calls})
chain = model.chain(payload, tools=[t1, t2], before_call=before)
_ = chain.text()
# second response has two results
second = chain._responses[1]
results = second.prompt.tool_results
assert len(results) == 2
# first cancelled, second executed
assert results[0].name == "t1"
assert results[0].output == "Cancelled: skip1"
assert isinstance(results[0].exception, CancelToolCall)
assert results[1].name == "t2"
assert results[1].output == "ran2"
assert results[1].exception is None
# 2c async equivalent
@pytest.mark.asyncio
async def test_chain_async_cancel_only_first_of_two():
async_model = llm.get_async_model("echo")
def t1() -> str:
return "ran1"
async def t2() -> str:
return "ran2"
async def before(tool, tool_call):
if tool.name == "t1":
raise CancelToolCall("skip1")
return None
calls = [
{"name": "t1"},
{"name": "t2"},
]
payload = json.dumps({"tool_calls": calls})
chain = async_model.chain(payload, tools=[t1, t2], before_call=before)
_ = await chain.text()
second = chain._responses[1]
results = second.prompt.tool_results
assert len(results) == 2
assert results[0].name == "t1"
assert results[0].output == "Cancelled: skip1"
assert isinstance(results[0].exception, CancelToolCall)
assert results[1].name == "t2"
assert results[1].output == "ran2"
assert results[1].exception is None