mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-17 05:00:25 +00:00
New response.excute_tool_calls(), refs #1007
This commit is contained in:
parent
387f89d88b
commit
ea4df34563
1 changed files with 47 additions and 43 deletions
|
|
@ -192,6 +192,10 @@ class ToolResult:
|
|||
tool_call_id: Optional[str] = None
|
||||
|
||||
|
||||
class CancelToolCall(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Prompt:
|
||||
_prompt: Optional[str]
|
||||
|
|
@ -324,7 +328,7 @@ class Conversation(_BaseConversation):
|
|||
tool_results: Optional[List[ToolResult]] = None,
|
||||
details: bool = False,
|
||||
key: Optional[str] = None,
|
||||
**options,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
self.model._validate_attachments(attachments)
|
||||
return ChainResponse(
|
||||
|
|
@ -338,7 +342,7 @@ class Conversation(_BaseConversation):
|
|||
tool_results=tool_results,
|
||||
system_fragments=system_fragments,
|
||||
model=self.model,
|
||||
options=self.model.Options(**options),
|
||||
options=self.model.Options(**(options or {})),
|
||||
),
|
||||
model=self.model,
|
||||
stream=stream,
|
||||
|
|
@ -482,6 +486,42 @@ class _BaseResponse:
|
|||
def add_tool_call(self, tool_call: ToolCall):
|
||||
self._tool_calls.append(tool_call)
|
||||
|
||||
def execute_tool_calls(
|
||||
self,
|
||||
*,
|
||||
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
|
||||
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
|
||||
) -> List[ToolResult]:
|
||||
tool_results = []
|
||||
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
|
||||
# TODO: make this work async
|
||||
for tool_call in self.tool_calls():
|
||||
tool = tools_by_name.get(tool_call.name)
|
||||
if tool is None:
|
||||
raise CancelToolCall("Unknown tool: {}".format(tool_call.name))
|
||||
if before_call:
|
||||
# This may raise CancelToolCall:
|
||||
before_call(tool, tool_call)
|
||||
if not tool.implementation:
|
||||
raise CancelToolCall(
|
||||
"No implementation available for tool: {}".format(tool_call.name)
|
||||
)
|
||||
try:
|
||||
result = tool.implementation(**tool_call.arguments)
|
||||
if not isinstance(result, str):
|
||||
result = json.dumps(result, default=repr)
|
||||
except Exception as ex:
|
||||
result = f"Error: {ex}"
|
||||
tool_result = ToolResult(
|
||||
name=tool_call.name,
|
||||
output=result,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
)
|
||||
if after_call:
|
||||
after_call(tool, tool_call, tool_result)
|
||||
tool_results.append(tool_result)
|
||||
return tool_results
|
||||
|
||||
def set_usage(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -1023,7 +1063,7 @@ class _BaseChainResponse:
|
|||
conversation: _BaseConversation,
|
||||
key: Optional[str] = None,
|
||||
details: bool = False,
|
||||
chain_limit: int = 5,
|
||||
chain_limit: int = 10,
|
||||
):
|
||||
self.prompt = prompt
|
||||
self.model = model
|
||||
|
|
@ -1055,44 +1095,8 @@ class _BaseChainResponse:
|
|||
self._responses.append(response)
|
||||
if count > self.chain_limit:
|
||||
raise ValueError(f"Chain limit of {self.chain_limit} exceeded. ")
|
||||
tool_calls = response.tool_calls()
|
||||
if not tool_calls:
|
||||
return
|
||||
tools_by_name = {
|
||||
tool.name: tool.implementation for tool in response.prompt.tools
|
||||
}
|
||||
tool_results = []
|
||||
for tool_call in tool_calls:
|
||||
if details:
|
||||
yield "\nTool call requested: {}({})\n".format(
|
||||
tool_call.name,
|
||||
", ".join(
|
||||
f"{k}={repr(v)}" for k, v in tool_call.arguments.items()
|
||||
),
|
||||
)
|
||||
implementation = tools_by_name.get(tool_calls[0].name)
|
||||
if not implementation:
|
||||
# TODO: send this as an error instead?
|
||||
raise ValueError(f"Tool {tool_call.name} not found")
|
||||
try:
|
||||
result = implementation(**tool_call.arguments)
|
||||
if not isinstance(result, str):
|
||||
result = json.dumps(result, default=repr)
|
||||
if details:
|
||||
yield f"\n{result}\n"
|
||||
except Exception as ex:
|
||||
if details:
|
||||
yield f"\nError:\ns{ex}\n"
|
||||
result = f"Error: {ex}"
|
||||
tool_results.append(
|
||||
ToolResult(
|
||||
name=tool_call.name,
|
||||
output=result,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
)
|
||||
)
|
||||
if not tool_results:
|
||||
break
|
||||
# This could raise llm.CancelToolCall:
|
||||
tool_results = response.execute_tool_calls()
|
||||
response = Response(
|
||||
Prompt(
|
||||
"",
|
||||
|
|
@ -1256,7 +1260,7 @@ class _Model(_BaseModel):
|
|||
tools: Optional[List[Tool]] = None,
|
||||
tool_results: Optional[List[ToolResult]] = None,
|
||||
details: bool = False,
|
||||
**options,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
return self.conversation().chain(
|
||||
prompt=prompt,
|
||||
|
|
@ -1269,7 +1273,7 @@ class _Model(_BaseModel):
|
|||
tools=tools,
|
||||
tool_results=tool_results,
|
||||
details=details,
|
||||
**options,
|
||||
options=options,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue