New response.excute_tool_calls(), refs #1007

This commit is contained in:
Simon Willison 2025-05-12 20:47:15 -07:00
parent 387f89d88b
commit ea4df34563

View file

@ -192,6 +192,10 @@ class ToolResult:
tool_call_id: Optional[str] = None
class CancelToolCall(Exception):
pass
@dataclass
class Prompt:
_prompt: Optional[str]
@ -324,7 +328,7 @@ class Conversation(_BaseConversation):
tool_results: Optional[List[ToolResult]] = None,
details: bool = False,
key: Optional[str] = None,
**options,
options: Optional[dict] = None,
):
self.model._validate_attachments(attachments)
return ChainResponse(
@ -338,7 +342,7 @@ class Conversation(_BaseConversation):
tool_results=tool_results,
system_fragments=system_fragments,
model=self.model,
options=self.model.Options(**options),
options=self.model.Options(**(options or {})),
),
model=self.model,
stream=stream,
@ -482,6 +486,42 @@ class _BaseResponse:
def add_tool_call(self, tool_call: ToolCall):
self._tool_calls.append(tool_call)
def execute_tool_calls(
self,
*,
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
) -> List[ToolResult]:
tool_results = []
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
# TODO: make this work async
for tool_call in self.tool_calls():
tool = tools_by_name.get(tool_call.name)
if tool is None:
raise CancelToolCall("Unknown tool: {}".format(tool_call.name))
if before_call:
# This may raise CancelToolCall:
before_call(tool, tool_call)
if not tool.implementation:
raise CancelToolCall(
"No implementation available for tool: {}".format(tool_call.name)
)
try:
result = tool.implementation(**tool_call.arguments)
if not isinstance(result, str):
result = json.dumps(result, default=repr)
except Exception as ex:
result = f"Error: {ex}"
tool_result = ToolResult(
name=tool_call.name,
output=result,
tool_call_id=tool_call.tool_call_id,
)
if after_call:
after_call(tool, tool_call, tool_result)
tool_results.append(tool_result)
return tool_results
def set_usage(
self,
*,
@ -1023,7 +1063,7 @@ class _BaseChainResponse:
conversation: _BaseConversation,
key: Optional[str] = None,
details: bool = False,
chain_limit: int = 5,
chain_limit: int = 10,
):
self.prompt = prompt
self.model = model
@ -1055,44 +1095,8 @@ class _BaseChainResponse:
self._responses.append(response)
if count > self.chain_limit:
raise ValueError(f"Chain limit of {self.chain_limit} exceeded. ")
tool_calls = response.tool_calls()
if not tool_calls:
return
tools_by_name = {
tool.name: tool.implementation for tool in response.prompt.tools
}
tool_results = []
for tool_call in tool_calls:
if details:
yield "\nTool call requested: {}({})\n".format(
tool_call.name,
", ".join(
f"{k}={repr(v)}" for k, v in tool_call.arguments.items()
),
)
implementation = tools_by_name.get(tool_calls[0].name)
if not implementation:
# TODO: send this as an error instead?
raise ValueError(f"Tool {tool_call.name} not found")
try:
result = implementation(**tool_call.arguments)
if not isinstance(result, str):
result = json.dumps(result, default=repr)
if details:
yield f"\n{result}\n"
except Exception as ex:
if details:
yield f"\nError:\ns{ex}\n"
result = f"Error: {ex}"
tool_results.append(
ToolResult(
name=tool_call.name,
output=result,
tool_call_id=tool_call.tool_call_id,
)
)
if not tool_results:
break
# This could raise llm.CancelToolCall:
tool_results = response.execute_tool_calls()
response = Response(
Prompt(
"",
@ -1256,7 +1260,7 @@ class _Model(_BaseModel):
tools: Optional[List[Tool]] = None,
tool_results: Optional[List[ToolResult]] = None,
details: bool = False,
**options,
options: Optional[dict] = None,
):
return self.conversation().chain(
prompt=prompt,
@ -1269,7 +1273,7 @@ class _Model(_BaseModel):
tools=tools,
tool_results=tool_results,
details=details,
**options,
options=options,
)