diff --git a/docs/python-api.md b/docs/python-api.md index 4493116..b6571a7 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -148,6 +148,46 @@ for response in chain.responses(): print(chunk, end="", flush=True) ``` +(python-api-tools-debug-hooks)= + +#### Tool debugging hooks + +Pass a function to the `before_call=` parameter of `model.chain()` to have that called before every tool call is executed. You can raise `llm.CancelToolCall()` to cancel that tool call. + +The method signature is `def before_call(tool: llm.Tool, tool_call: llm.ToolCall)`. Here's an example: +```python +import llm + +def upper(text: str) -> str: + "Convert text to uppercase." + return text.upper() + +def before_call(tool: llm.Tool, tool_call: llm.ToolCall): + print(f"About to call tool {tool.name} with arguments {tool_call.arguments}") + if tool.name == "upper" and "bad" in repr(tool_call.arguments): + raise llm.CancelToolCall("Not allowed to call upper on text containing 'bad'") + +model = llm.get_model("gpt-4.1-mini") +response = model.chain( + "Convert panda to upper and badger to upper", + tools=[upper], + before_call=before_call, +) +print(response.text()) +``` +The `after_call=` parameter can be used to run a logging function after each tool call has been executed. The method signature is `def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult)`. 
This continues the previous example: +```python +def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult): + print(f"Tool {tool.name} called with arguments {tool_call.arguments} returned {tool_result.output}") + +response = model.chain( + "Convert panda to upper and badger to upper", + tools=[upper], + after_call=after_call, +) +print(response.text()) +``` + (python-api-tools-attachments)= #### Tools can return attachments @@ -575,6 +615,8 @@ print(conversation.chain( "Same with pangolin" ).text()) ``` +The `before_call=` and `after_call=` parameters {ref}`described above <python-api-tools-debug-hooks>` can be passed directly to the `model.conversation()` method to set those options for all chained prompts in that conversation. + (python-api-listing-models)= diff --git a/llm/cli.py b/llm/cli.py index 0af4d0c..61d06a6 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1080,6 +1080,11 @@ def chat( # Ensure it can see the API key conversation.model = model + if tools_debug: + conversation.after_call = _debug_tool_call + if tools_approve: + conversation.before_call = _approve_tool_call + # Validate options validated_options = get_model_options(model.model_id) if options: @@ -1100,10 +1105,6 @@ def chat( if tool_functions: kwargs["chain_limit"] = chain_limit - if tools_debug: - kwargs["after_call"] = _debug_tool_call - if tools_approve: - kwargs["before_call"] = _approve_tool_call kwargs["tools"] = tool_functions should_stream = model.can_stream and not no_stream diff --git a/llm/models.py b/llm/models.py index 9391ff4..e971033 100644 --- a/llm/models.py +++ b/llm/models.py @@ -259,9 +259,6 @@ class Toolbox: return methods -ToolDef = Union[Tool, Toolbox, Callable[..., Any]] - - @dataclass class ToolCall: name: str @@ -286,6 +283,13 @@ class ToolOutput: attachments: List[Attachment] = field(default_factory=list) +ToolDef = Union[Tool, Toolbox, Callable[..., Any]] +BeforeCallSync = Callable[[Tool, ToolCall], None] +AfterCallSync = Callable[[Tool, ToolCall, ToolResult], None] 
+BeforeCallAsync = Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] +AfterCallAsync = Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] + + class CancelToolCall(Exception): pass @@ -368,6 +372,7 @@ class _BaseConversation: name: Optional[str] = None responses: List["_BaseResponse"] = field(default_factory=list) tools: Optional[List[Tool]] = None + chain_limit: Optional[int] = None @classmethod @abstractmethod @@ -377,6 +382,9 @@ class _BaseConversation: @dataclass class Conversation(_BaseConversation): + before_call: Optional[BeforeCallSync] = None + after_call: Optional[AfterCallSync] = None + def prompt( self, prompt: Optional[str] = None, @@ -424,8 +432,8 @@ class Conversation(_BaseConversation): tools: Optional[List[Tool]] = None, tool_results: Optional[List[ToolResult]] = None, chain_limit: Optional[int] = None, - before_call: Optional[Callable[[Tool, ToolCall], None]] = None, - after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None, + before_call: Optional[BeforeCallSync] = None, + after_call: Optional[AfterCallSync] = None, key: Optional[str] = None, options: Optional[dict] = None, ) -> "ChainResponse": @@ -447,9 +455,9 @@ class Conversation(_BaseConversation): stream=stream, conversation=self, key=key, - before_call=before_call, - after_call=after_call, - chain_limit=chain_limit, + before_call=before_call or self.before_call, + after_call=after_call or self.after_call, + chain_limit=chain_limit if chain_limit is not None else self.chain_limit, ) @classmethod @@ -470,6 +478,9 @@ class Conversation(_BaseConversation): @dataclass class AsyncConversation(_BaseConversation): + before_call: Optional[BeforeCallAsync] = None + after_call: Optional[AfterCallAsync] = None + def chain( self, prompt: Optional[str] = None, @@ -483,12 +494,8 @@ class AsyncConversation(_BaseConversation): tools: Optional[List[Tool]] = None, tool_results: Optional[List[ToolResult]] = None, chain_limit: Optional[int] = None, - before_call: 
Optional[ - Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] - ] = None, - after_call: Optional[ - Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] - ] = None, + before_call: Optional[BeforeCallAsync] = None, + after_call: Optional[AfterCallAsync] = None, key: Optional[str] = None, options: Optional[dict] = None, ) -> "AsyncChainResponse": @@ -510,9 +517,9 @@ class AsyncConversation(_BaseConversation): stream=stream, conversation=self, key=key, - before_call=before_call, - after_call=after_call, - chain_limit=chain_limit, + before_call=before_call or self.before_call, + after_call=after_call or self.after_call, + chain_limit=chain_limit if chain_limit is not None else self.chain_limit, ) def prompt( @@ -975,12 +982,8 @@ class Response(_BaseResponse): def execute_tool_calls( self, *, - before_call: Optional[ - Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] - ] = None, - after_call: Optional[ - Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] - ] = None, + before_call: Optional[BeforeCallSync] = None, + after_call: Optional[AfterCallSync] = None, ) -> List[ToolResult]: tool_results = [] tools_by_name = {tool.name: tool for tool in self.prompt.tools} @@ -1147,12 +1150,8 @@ class AsyncResponse(_BaseResponse): async def execute_tool_calls( self, *, - before_call: Optional[ - Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] - ] = None, - after_call: Optional[ - Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] - ] = None, + before_call: Optional[BeforeCallAsync] = None, + after_call: Optional[AfterCallAsync] = None, ) -> List[ToolResult]: tool_calls_list = await self.tool_calls() tools_by_name = {tool.name: tool for tool in self.prompt.tools} @@ -1437,12 +1436,8 @@ class _BaseChainResponse: conversation: _BaseConversation, key: Optional[str] = None, chain_limit: Optional[int] = 10, - before_call: Optional[ - Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] - ] = None, - 
after_call: Optional[ - Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] - ] = None, + before_call: Optional[Union[BeforeCallSync, BeforeCallAsync]] = None, + after_call: Optional[Union[AfterCallSync, AfterCallAsync]] = None, ): self.prompt = prompt self.model = model @@ -1467,6 +1462,8 @@ class _BaseChainResponse: class ChainResponse(_BaseChainResponse): _responses: List["Response"] + before_call: Optional[BeforeCallSync] = None + after_call: Optional[AfterCallSync] = None def responses(self) -> Iterator[Response]: prompt = self.prompt @@ -1521,6 +1518,8 @@ class ChainResponse(_BaseChainResponse): class AsyncChainResponse(_BaseChainResponse): _responses: List["AsyncResponse"] + before_call: Optional[BeforeCallAsync] = None + after_call: Optional[AfterCallAsync] = None async def responses(self) -> AsyncIterator[AsyncResponse]: prompt = self.prompt @@ -1656,8 +1655,20 @@ class _BaseModel(ABC, _get_key_mixin): class _Model(_BaseModel): - def conversation(self, tools: Optional[List[Tool]] = None) -> Conversation: - return Conversation(model=self, tools=tools) + def conversation( + self, + tools: Optional[List[Tool]] = None, + before_call: Optional[BeforeCallSync] = None, + after_call: Optional[AfterCallSync] = None, + chain_limit: Optional[int] = None, + ) -> Conversation: + return Conversation( + model=self, + tools=tools, + before_call=before_call, + after_call=after_call, + chain_limit=chain_limit, + ) def prompt( self, @@ -1705,8 +1716,8 @@ class _Model(_BaseModel): schema: Optional[Union[dict, type[BaseModel]]] = None, tools: Optional[List[Tool]] = None, tool_results: Optional[List[ToolResult]] = None, - before_call: Optional[Callable[[Tool, ToolCall], None]] = None, - after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None, + before_call: Optional[BeforeCallSync] = None, + after_call: Optional[AfterCallSync] = None, key: Optional[str] = None, options: Optional[dict] = None, ) -> ChainResponse: @@ -1753,8 +1764,20 @@ class 
KeyModel(_Model): class _AsyncModel(_BaseModel): - def conversation(self, tools: Optional[List[Tool]] = None) -> AsyncConversation: - return AsyncConversation(model=self, tools=tools) + def conversation( + self, + tools: Optional[List[Tool]] = None, + before_call: Optional[BeforeCallAsync] = None, + after_call: Optional[AfterCallAsync] = None, + chain_limit: Optional[int] = None, + ) -> AsyncConversation: + return AsyncConversation( + model=self, + tools=tools, + before_call=before_call, + after_call=after_call, + chain_limit=chain_limit, + ) def prompt( self, @@ -1802,12 +1825,8 @@ class _AsyncModel(_BaseModel): schema: Optional[Union[dict, type[BaseModel]]] = None, tools: Optional[List[Tool]] = None, tool_results: Optional[List[ToolResult]] = None, - before_call: Optional[ - Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] - ] = None, - after_call: Optional[ - Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] - ] = None, + before_call: Optional[BeforeCallAsync] = None, + after_call: Optional[AfterCallAsync] = None, key: Optional[str] = None, options: Optional[dict] = None, ) -> AsyncChainResponse: diff --git a/tests/test_tools.py b/tests/test_tools.py index ecb61b1..293d2db 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -343,3 +343,45 @@ async def test_async_tool_returning_attachment(): output = await chain_response.text() assert '"type": "image/png"' in output assert '"output": "Output"' in output + + +def test_tool_conversation_settings(): + model = llm.get_model("echo") + before_collected = [] + after_collected = [] + + def before(*args): + before_collected.append(args) + + def after(*args): + after_collected.append(args) + + conversation = model.conversation( + tools=[llm_time], before_call=before, after_call=after + ) + # Run two things + conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text() + conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text() + assert 
len(before_collected) == 2 + assert len(after_collected) == 2 + + +@pytest.mark.asyncio +async def test_tool_conversation_settings_async(): + model = llm.get_async_model("echo") + before_collected = [] + after_collected = [] + + async def before(*args): + before_collected.append(args) + + async def after(*args): + after_collected.append(args) + + conversation = model.conversation( + tools=[llm_time], before_call=before, after_call=after + ) + await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text() + await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text() + assert len(before_collected) == 2 + assert len(after_collected) == 2