From ea4df34563d4f4f3be91ee62d27ceda331f22330 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Mon, 12 May 2025 20:47:15 -0700 Subject: [PATCH] New response.excute_tool_calls(), refs #1007 --- llm/models.py | 90 +++++++++++++++++++++++++++------------------------ 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/llm/models.py b/llm/models.py index 25963da..a23a2fc 100644 --- a/llm/models.py +++ b/llm/models.py @@ -192,6 +192,10 @@ class ToolResult: tool_call_id: Optional[str] = None +class CancelToolCall(Exception): + pass + + @dataclass class Prompt: _prompt: Optional[str] @@ -324,7 +328,7 @@ class Conversation(_BaseConversation): tool_results: Optional[List[ToolResult]] = None, details: bool = False, key: Optional[str] = None, - **options, + options: Optional[dict] = None, ): self.model._validate_attachments(attachments) return ChainResponse( @@ -338,7 +342,7 @@ class Conversation(_BaseConversation): tool_results=tool_results, system_fragments=system_fragments, model=self.model, - options=self.model.Options(**options), + options=self.model.Options(**(options or {})), ), model=self.model, stream=stream, @@ -482,6 +486,42 @@ class _BaseResponse: def add_tool_call(self, tool_call: ToolCall): self._tool_calls.append(tool_call) + def execute_tool_calls( + self, + *, + before_call: Optional[Callable[[Tool, ToolCall], None]] = None, + after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None, + ) -> List[ToolResult]: + tool_results = [] + tools_by_name = {tool.name: tool for tool in self.prompt.tools} + # TODO: make this work async + for tool_call in self.tool_calls(): + tool = tools_by_name.get(tool_call.name) + if tool is None: + raise CancelToolCall("Unknown tool: {}".format(tool_call.name)) + if before_call: + # This may raise CancelToolCall: + before_call(tool, tool_call) + if not tool.implementation: + raise CancelToolCall( + "No implementation available for tool: {}".format(tool_call.name) + ) + try: + result = tool.implementation(**tool_call.arguments) + if not isinstance(result, str): + result = json.dumps(result, default=repr) + except Exception as ex: + result = f"Error: {ex}" + tool_result = ToolResult( + name=tool_call.name, + output=result, + tool_call_id=tool_call.tool_call_id, + ) + if after_call: + after_call(tool, tool_call, tool_result) + tool_results.append(tool_result) + return tool_results + def set_usage( self, *, @@ -1023,7 +1063,7 @@ class _BaseChainResponse: conversation: _BaseConversation, key: Optional[str] = None, details: bool = False, - chain_limit: int = 5, + chain_limit: int = 10, ): self.prompt = prompt self.model = model @@ -1055,44 +1095,8 @@ class _BaseChainResponse: self._responses.append(response) if count > self.chain_limit: raise ValueError(f"Chain limit of {self.chain_limit} exceeded. ") - tool_calls = response.tool_calls() - if not tool_calls: - return - tools_by_name = { - tool.name: tool.implementation for tool in response.prompt.tools - } - tool_results = [] - for tool_call in tool_calls: - if details: - yield "\nTool call requested: {}({})\n".format( - tool_call.name, - ", ".join( - f"{k}={repr(v)}" for k, v in tool_call.arguments.items() - ), - ) - implementation = tools_by_name.get(tool_calls[0].name) - if not implementation: - # TODO: send this as an error instead? - raise ValueError(f"Tool {tool_call.name} not found") - try: - result = implementation(**tool_call.arguments) - if not isinstance(result, str): - result = json.dumps(result, default=repr) - if details: - yield f"\n{result}\n" - except Exception as ex: - if details: - yield f"\nError:\ns{ex}\n" - result = f"Error: {ex}" - tool_results.append( - ToolResult( - name=tool_call.name, - output=result, - tool_call_id=tool_call.tool_call_id, - ) - ) - if not tool_results: - break + # This could raise llm.CancelToolCall: + tool_results = response.execute_tool_calls() response = Response( Prompt( "", @@ -1256,7 +1260,7 @@ class _Model(_BaseModel): tools: Optional[List[Tool]] = None, tool_results: Optional[List[ToolResult]] = None, details: bool = False, - **options, + options: Optional[dict] = None, ): return self.conversation().chain( prompt=prompt, @@ -1269,7 +1273,7 @@ class _Model(_BaseModel): tools=tools, tool_results=tool_results, details=details, - **options, + options=options, )