chain_limit/before_call/after_call for conversations

* chain_limit/before_call/after_call for conversations, closes #1088
* Docs for before_call/after_call including for model.conversation
This commit is contained in:
Simon Willison 2025-06-01 12:00:29 -07:00 committed by GitHub
parent b5d1c5ee90
commit ed64fc3362
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 155 additions and 51 deletions

View file

@ -148,6 +148,46 @@ for response in chain.responses():
print(chunk, end="", flush=True)
```
(python-api-tools-debug-hooks)=
#### Tool debugging hooks
Pass a function to the `before_call=` parameter of `model.chain()` to have that function called before every tool call is executed. You can raise `llm.CancelToolCall()` from it to cancel that tool call.
The function signature is `def before_call(tool: llm.Tool, tool_call: llm.ToolCall)`. Here's an example:
```python
import llm
def upper(text: str) -> str:
"Convert text to uppercase."
return text.upper()
def before_call(tool: llm.Tool, tool_call: llm.ToolCall):
print(f"About to call tool {tool.name} with arguments {tool_call.arguments}")
if tool.name == "upper" and "bad" in repr(tool_call.arguments):
raise llm.CancelToolCall("Not allowed to call upper on text containing 'bad'")
model = llm.get_model("gpt-4.1-mini")
response = model.chain(
"Convert panda to upper and badger to upper",
tools=[upper],
before_call=before_call,
)
print(response.text())
```
The `after_call=` parameter can be used to run a logging function after each tool call has been executed. The function signature is `def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult)`. This continues the previous example:
```python
def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult):
print(f"Tool {tool.name} called with arguments {tool_call.arguments} returned {tool_result.output}")
response = model.chain(
"Convert panda to upper and badger to upper",
tools=[upper],
after_call=after_call,
)
print(response.text())
```
(python-api-tools-attachments)=
#### Tools can return attachments
@ -575,6 +615,8 @@ print(conversation.chain(
"Same with pangolin"
).text())
```
The `before_call=` and `after_call=` parameters {ref}`described above <python-api-tools-debug-hooks>` can be passed directly to the `model.conversation()` method to set those options for all chained prompts in that conversation.
(python-api-listing-models)=

View file

@ -1080,6 +1080,11 @@ def chat(
# Ensure it can see the API key
conversation.model = model
if tools_debug:
conversation.after_call = _debug_tool_call
if tools_approve:
conversation.before_call = _approve_tool_call
# Validate options
validated_options = get_model_options(model.model_id)
if options:
@ -1100,10 +1105,6 @@ def chat(
if tool_functions:
kwargs["chain_limit"] = chain_limit
if tools_debug:
kwargs["after_call"] = _debug_tool_call
if tools_approve:
kwargs["before_call"] = _approve_tool_call
kwargs["tools"] = tool_functions
should_stream = model.can_stream and not no_stream

View file

@ -259,9 +259,6 @@ class Toolbox:
return methods
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
@dataclass
class ToolCall:
name: str
@ -286,6 +283,13 @@ class ToolOutput:
attachments: List[Attachment] = field(default_factory=list)
# Anything accepted as a tool: a full Tool definition, a Toolbox of related
# tools, or a plain callable that will be introspected into a Tool.
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
# Hook signatures: before_call runs ahead of each tool execution (it may
# raise CancelToolCall to veto it), after_call runs afterwards and also
# receives the ToolResult.
BeforeCallSync = Callable[[Tool, ToolCall], None]
AfterCallSync = Callable[[Tool, ToolCall, ToolResult], None]
# Async variants may return an awaitable, so plain or `async def` hooks work.
BeforeCallAsync = Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
AfterCallAsync = Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
class CancelToolCall(Exception):
    """Raised from a before_call hook to cancel the pending tool call."""
@ -368,6 +372,7 @@ class _BaseConversation:
name: Optional[str] = None
responses: List["_BaseResponse"] = field(default_factory=list)
tools: Optional[List[Tool]] = None
chain_limit: Optional[int] = None
@classmethod
@abstractmethod
@ -377,6 +382,9 @@ class _BaseConversation:
@dataclass
class Conversation(_BaseConversation):
before_call: Optional[BeforeCallSync] = None
after_call: Optional[AfterCallSync] = None
def prompt(
self,
prompt: Optional[str] = None,
@ -424,8 +432,8 @@ class Conversation(_BaseConversation):
tools: Optional[List[Tool]] = None,
tool_results: Optional[List[ToolResult]] = None,
chain_limit: Optional[int] = None,
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
before_call: Optional[BeforeCallSync] = None,
after_call: Optional[AfterCallSync] = None,
key: Optional[str] = None,
options: Optional[dict] = None,
) -> "ChainResponse":
@ -447,9 +455,9 @@ class Conversation(_BaseConversation):
stream=stream,
conversation=self,
key=key,
before_call=before_call,
after_call=after_call,
chain_limit=chain_limit,
before_call=before_call or self.before_call,
after_call=after_call or self.after_call,
chain_limit=chain_limit if chain_limit is not None else self.chain_limit,
)
@classmethod
@ -470,6 +478,9 @@ class Conversation(_BaseConversation):
@dataclass
class AsyncConversation(_BaseConversation):
before_call: Optional[BeforeCallAsync] = None
after_call: Optional[AfterCallAsync] = None
def chain(
self,
prompt: Optional[str] = None,
@ -483,12 +494,8 @@ class AsyncConversation(_BaseConversation):
tools: Optional[List[Tool]] = None,
tool_results: Optional[List[ToolResult]] = None,
chain_limit: Optional[int] = None,
before_call: Optional[
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
] = None,
after_call: Optional[
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
] = None,
before_call: Optional[BeforeCallAsync] = None,
after_call: Optional[AfterCallAsync] = None,
key: Optional[str] = None,
options: Optional[dict] = None,
) -> "AsyncChainResponse":
@ -510,9 +517,9 @@ class AsyncConversation(_BaseConversation):
stream=stream,
conversation=self,
key=key,
before_call=before_call,
after_call=after_call,
chain_limit=chain_limit,
before_call=before_call or self.before_call,
after_call=after_call or self.after_call,
chain_limit=chain_limit if chain_limit is not None else self.chain_limit,
)
def prompt(
@ -975,12 +982,8 @@ class Response(_BaseResponse):
def execute_tool_calls(
self,
*,
before_call: Optional[
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
] = None,
after_call: Optional[
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
] = None,
before_call: Optional[BeforeCallSync] = None,
after_call: Optional[AfterCallSync] = None,
) -> List[ToolResult]:
tool_results = []
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
@ -1147,12 +1150,8 @@ class AsyncResponse(_BaseResponse):
async def execute_tool_calls(
self,
*,
before_call: Optional[
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
] = None,
after_call: Optional[
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
] = None,
before_call: Optional[BeforeCallAsync] = None,
after_call: Optional[AfterCallAsync] = None,
) -> List[ToolResult]:
tool_calls_list = await self.tool_calls()
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
@ -1437,12 +1436,8 @@ class _BaseChainResponse:
conversation: _BaseConversation,
key: Optional[str] = None,
chain_limit: Optional[int] = 10,
before_call: Optional[
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
] = None,
after_call: Optional[
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
] = None,
before_call: Optional[Union[BeforeCallSync, BeforeCallAsync]] = None,
after_call: Optional[Union[AfterCallSync, AfterCallAsync]] = None,
):
self.prompt = prompt
self.model = model
@ -1467,6 +1462,8 @@ class _BaseChainResponse:
class ChainResponse(_BaseChainResponse):
_responses: List["Response"]
before_call: Optional[BeforeCallSync] = None
after_call: Optional[AfterCallSync] = None
def responses(self) -> Iterator[Response]:
prompt = self.prompt
@ -1521,6 +1518,8 @@ class ChainResponse(_BaseChainResponse):
class AsyncChainResponse(_BaseChainResponse):
_responses: List["AsyncResponse"]
before_call: Optional[BeforeCallAsync] = None
after_call: Optional[AfterCallAsync] = None
async def responses(self) -> AsyncIterator[AsyncResponse]:
prompt = self.prompt
@ -1656,8 +1655,20 @@ class _BaseModel(ABC, _get_key_mixin):
class _Model(_BaseModel):
def conversation(self, tools: Optional[List[Tool]] = None) -> Conversation:
return Conversation(model=self, tools=tools)
def conversation(
    self,
    tools: Optional[List[Tool]] = None,
    before_call: Optional[BeforeCallSync] = None,
    after_call: Optional[AfterCallSync] = None,
    chain_limit: Optional[int] = None,
) -> Conversation:
    """
    Start a new Conversation against this model.

    The tools, before_call/after_call hooks and chain_limit given here
    become the defaults for every chained prompt in that conversation.
    """
    defaults = dict(
        model=self,
        tools=tools,
        before_call=before_call,
        after_call=after_call,
        chain_limit=chain_limit,
    )
    return Conversation(**defaults)
def prompt(
self,
@ -1705,8 +1716,8 @@ class _Model(_BaseModel):
schema: Optional[Union[dict, type[BaseModel]]] = None,
tools: Optional[List[Tool]] = None,
tool_results: Optional[List[ToolResult]] = None,
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
before_call: Optional[BeforeCallSync] = None,
after_call: Optional[AfterCallSync] = None,
key: Optional[str] = None,
options: Optional[dict] = None,
) -> ChainResponse:
@ -1753,8 +1764,20 @@ class KeyModel(_Model):
class _AsyncModel(_BaseModel):
def conversation(self, tools: Optional[List[Tool]] = None) -> AsyncConversation:
return AsyncConversation(model=self, tools=tools)
def conversation(
    self,
    tools: Optional[List[Tool]] = None,
    before_call: Optional[BeforeCallAsync] = None,
    after_call: Optional[AfterCallAsync] = None,
    chain_limit: Optional[int] = None,
) -> AsyncConversation:
    """
    Start a new AsyncConversation against this model.

    The tools, before_call/after_call hooks and chain_limit given here
    become the defaults for every chained prompt in that conversation.
    """
    defaults = dict(
        model=self,
        tools=tools,
        before_call=before_call,
        after_call=after_call,
        chain_limit=chain_limit,
    )
    return AsyncConversation(**defaults)
def prompt(
self,
@ -1802,12 +1825,8 @@ class _AsyncModel(_BaseModel):
schema: Optional[Union[dict, type[BaseModel]]] = None,
tools: Optional[List[Tool]] = None,
tool_results: Optional[List[ToolResult]] = None,
before_call: Optional[
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
] = None,
after_call: Optional[
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
] = None,
before_call: Optional[BeforeCallAsync] = None,
after_call: Optional[AfterCallAsync] = None,
key: Optional[str] = None,
options: Optional[dict] = None,
) -> AsyncChainResponse:

View file

@ -343,3 +343,45 @@ async def test_async_tool_returning_attachment():
output = await chain_response.text()
assert '"type": "image/png"' in output
assert '"output": "Output"' in output
def test_tool_conversation_settings():
    """Hooks passed to model.conversation() fire for every chained prompt."""
    model = llm.get_model("echo")
    seen_before = []
    seen_after = []
    conversation = model.conversation(
        tools=[llm_time],
        before_call=lambda *args: seen_before.append(args),
        after_call=lambda *args: seen_after.append(args),
    )
    prompt = json.dumps({"tool_calls": [{"name": "llm_time"}]})
    # Two chained prompts should trigger each hook once per tool call
    for _ in range(2):
        conversation.chain(prompt).text()
    assert len(seen_before) == 2
    assert len(seen_after) == 2
@pytest.mark.asyncio
async def test_tool_conversation_settings_async():
    """Async variant: conversation-level hooks fire for each chained prompt."""
    model = llm.get_async_model("echo")
    seen_before = []
    seen_after = []

    async def record_before(*args):
        seen_before.append(args)

    async def record_after(*args):
        seen_after.append(args)

    conversation = model.conversation(
        tools=[llm_time], before_call=record_before, after_call=record_after
    )
    prompt = json.dumps({"tool_calls": [{"name": "llm_time"}]})
    # Two chained prompts should trigger each hook once per tool call
    for _ in range(2):
        await conversation.chain(prompt).text()
    assert len(seen_before) == 2
    assert len(seen_after) == 2