mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
Better handling of before_call cancellation, closes #1148
This commit is contained in:
parent
d96ae4ed8d
commit
3a96d52895
3 changed files with 122 additions and 12 deletions
|
|
@ -178,6 +178,8 @@ response = model.chain(
|
|||
)
|
||||
print(response.text())
|
||||
```
|
||||
If you raise `llm.CancelToolCall` in the `before_call` function the model will be informed that the tool call was cancelled.
|
||||
|
||||
The `after_call=` parameter can be used to run a logging function after each tool call has been executed. The method signature is `def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult)`. This continues the previous example:
|
||||
```python
|
||||
def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult):
|
||||
|
|
|
|||
|
|
@ -1013,12 +1013,23 @@ class Response(_BaseResponse):
|
|||
# Tool could be None if the tool was not found in the prompt tools,
|
||||
# but we still call the before_call method:
|
||||
if before_call:
|
||||
cb_result = before_call(tool, tool_call)
|
||||
if inspect.isawaitable(cb_result):
|
||||
raise TypeError(
|
||||
"Asynchronous 'before_call' callback provided to a synchronous tool execution context. "
|
||||
"Please use an async chain/response or a synchronous callback."
|
||||
try:
|
||||
cb_result = before_call(tool, tool_call)
|
||||
if inspect.isawaitable(cb_result):
|
||||
raise TypeError(
|
||||
"Asynchronous 'before_call' callback provided to a synchronous tool execution context. "
|
||||
"Please use an async chain/response or a synchronous callback."
|
||||
)
|
||||
except CancelToolCall as ex:
|
||||
tool_results.append(
|
||||
ToolResult(
|
||||
name=tool_call.name,
|
||||
output="Cancelled: " + str(ex),
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
exception=ex,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if tool is None:
|
||||
msg = 'tool "{}" does not exist'.format(tool_call.name)
|
||||
|
|
@ -1202,9 +1213,17 @@ class AsyncResponse(_BaseResponse):
|
|||
async def run_async(tc=tc, tool=tool, idx=idx):
|
||||
# before_call inside the task
|
||||
if before_call:
|
||||
cb = before_call(tool, tc)
|
||||
if inspect.isawaitable(cb):
|
||||
await cb
|
||||
try:
|
||||
cb = before_call(tool, tc)
|
||||
if inspect.isawaitable(cb):
|
||||
await cb
|
||||
except CancelToolCall as ex:
|
||||
return idx, ToolResult(
|
||||
name=tc.name,
|
||||
output="Cancelled: " + str(ex),
|
||||
tool_call_id=tc.tool_call_id,
|
||||
exception=ex,
|
||||
)
|
||||
|
||||
exception = None
|
||||
attachments = []
|
||||
|
|
@ -1245,9 +1264,23 @@ class AsyncResponse(_BaseResponse):
|
|||
else:
|
||||
# Sync implementation: do hooks and call inline
|
||||
if before_call:
|
||||
cb = before_call(tool, tc)
|
||||
if inspect.isawaitable(cb):
|
||||
await cb
|
||||
try:
|
||||
cb = before_call(tool, tc)
|
||||
if inspect.isawaitable(cb):
|
||||
await cb
|
||||
except CancelToolCall as ex:
|
||||
indexed_results.append(
|
||||
(
|
||||
idx,
|
||||
ToolResult(
|
||||
name=tc.name,
|
||||
output="Cancelled: " + str(ex),
|
||||
tool_call_id=tc.tool_call_id,
|
||||
exception=ex,
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
exception = None
|
||||
attachments = []
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from click.testing import CliRunner
|
|||
from importlib.metadata import version
|
||||
import json
|
||||
import llm
|
||||
from llm import cli
|
||||
from llm import cli, CancelToolCall
|
||||
from llm.migrations import migrate
|
||||
from llm.tools import llm_time
|
||||
import os
|
||||
|
|
@ -432,3 +432,78 @@ def test_tool_errors(async_):
|
|||
" Error: Error!<br>\n"
|
||||
" **Error**: Exception: Error!\n"
|
||||
) in log_text_result.output
|
||||
|
||||
|
||||
def test_chain_sync_cancel_only_first_of_two():
    """Cancelling the first of two tool calls via before_call must leave the
    second call untouched: the cancelled call records a CancelToolCall result,
    the other runs normally."""
    model = llm.get_model("echo")

    # Tool names are derived from the function names, so t1/t2 must keep
    # these exact identifiers.
    def t1() -> str:
        return "ran1"

    def t2() -> str:
        return "ran2"

    def before(tool, tool_call):
        # Cancel only the first tool; anything else is allowed through.
        if tool.name == "t1":
            raise CancelToolCall("skip1")
        return None

    prompt_json = json.dumps(
        {"tool_calls": [{"name": "t1"}, {"name": "t2"}]}
    )
    chain = model.chain(prompt_json, tools=[t1, t2], before_call=before)
    _ = chain.text()

    # The follow-up response carries both tool results in its prompt.
    followup = chain._responses[1]
    tool_results = followup.prompt.tool_results
    assert len(tool_results) == 2

    # First result: cancelled, with the exception preserved.
    cancelled, executed = tool_results
    assert cancelled.name == "t1"
    assert cancelled.output == "Cancelled: skip1"
    assert isinstance(cancelled.exception, CancelToolCall)

    # Second result: ran to completion with no exception.
    assert executed.name == "t2"
    assert executed.output == "ran2"
    assert executed.exception is None
|
||||
|
||||
|
||||
# 2c async equivalent
@pytest.mark.asyncio
async def test_chain_async_cancel_only_first_of_two():
    """Async mirror of the sync cancellation test: before_call (itself async)
    cancels only the first tool call; the second (an async tool) still runs."""
    async_model = llm.get_async_model("echo")

    # Function names become the tool names, so t1/t2 must stay as-is.
    def t1() -> str:
        return "ran1"

    async def t2() -> str:
        return "ran2"

    async def before(tool, tool_call):
        # Only the first tool is cancelled; everything else proceeds.
        if tool.name == "t1":
            raise CancelToolCall("skip1")
        return None

    prompt_json = json.dumps(
        {"tool_calls": [{"name": "t1"}, {"name": "t2"}]}
    )
    chain = async_model.chain(prompt_json, tools=[t1, t2], before_call=before)
    _ = await chain.text()

    # The follow-up response carries both tool results in its prompt.
    followup = chain._responses[1]
    tool_results = followup.prompt.tool_results
    assert len(tool_results) == 2

    cancelled, executed = tool_results

    # Cancelled call: result records the cancellation and its exception.
    assert cancelled.name == "t1"
    assert cancelled.output == "Cancelled: skip1"
    assert isinstance(cancelled.exception, CancelToolCall)

    # Surviving call: ran normally, no exception attached.
    assert executed.name == "t2"
    assert executed.output == "ran2"
    assert executed.exception is None
|
||||
|
|
|
|||
Loading…
Reference in a new issue