Initial Conversation.chain() and ChainResponse

Refs https://github.com/simonw/llm/issues/937#issuecomment-2870365809
This commit is contained in:
Simon Willison 2025-05-11 17:14:57 -07:00
parent c990578934
commit 3b37854c26

View file

@ -285,6 +285,42 @@ class Conversation(_BaseConversation):
key=key,
)
def chain(
self,
prompt: Optional[str] = None,
*,
fragments: Optional[List[str]] = None,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
system_fragments: Optional[List[str]] = None,
stream: bool = True,
schema: Optional[Union[dict, type[BaseModel]]] = None,
tools: Optional[List[Tool]] = None,
tool_results: Optional[List[ToolResult]] = None,
details: bool = False,
**options,
):
self.model._validate_attachments(attachments)
return ChainResponse(
Prompt(
prompt,
fragments=fragments,
attachments=attachments,
system=system,
schema=schema,
tools=tools,
tool_results=tool_results,
system_fragments=system_fragments,
model=self.model,
options=self.model.Options(**options),
),
model=self.model,
stream=stream,
conversation=self,
key=options.pop("key", None),
details=details,
)
@classmethod
def from_row(cls, row):
from llm import get_model
@ -846,6 +882,96 @@ class AsyncResponse(_BaseResponse):
return "<AsyncResponse prompt='{}' text='{}'>".format(self.prompt.prompt, text)
class _BaseChainResponse:
prompt: "Prompt"
stream: bool
conversation: Optional["_BaseConversation"] = None
_key: Optional[str] = None
_responses: List[Union["Response", "AsyncResponse"]]
def __init__(
self,
prompt: Prompt,
model: "_BaseModel",
stream: bool,
conversation: _BaseConversation = None,
key: Optional[str] = None,
details: bool = False,
chain_limit: int = 10,
):
self.prompt = prompt
self.model = model
self.stream = stream
self._key = key
self._details = details
self._responses: List[Union[Response, AsyncResponse]] = []
self.conversation = conversation
self.chain_limit = chain_limit
def responses(self) -> Iterator[Union[Response, AsyncResponse]]:
prompt = self.prompt
response = Response(
prompt,
self.model,
self.stream,
key=self._key,
conversation=self.conversation,
)
while response:
yield response
tool_calls = response.tool_calls()
if not tool_calls:
return
tools_by_name = {
tool.name: tool.implementation for tool in response.prompt.tools
}
tool_results = []
for tool_call in tool_calls:
implementation = tools_by_name.get(tool_calls[0].name)
if not implementation:
# TODO: send this as an error instead
raise ValueError(f"Tool {tool_call.name} not found")
try:
result = implementation(**tool_call.arguments)
if not isinstance(result, str):
result = json.dumps(result, default=repr)
except Exception as ex:
raise ValueError(f"Error executing tool {tool_call.name}: {ex}")
tool_results.append(
ToolResult(
name=tool_call.name,
output=result,
tool_call_id=tool_call.tool_call_id,
)
)
if not tool_results:
break
response = Response(
Prompt(
"",
self.model,
tools=response.prompt.tools,
tool_results=tool_results,
options=self.prompt.options,
),
self.model,
stream=self.stream,
key=self._key,
conversation=self.conversation,
)
def __iter__(self) -> Iterator[str]:
for response in self.responses():
yield from response
def text(self) -> str:
return "".join(self)
class ChainResponse(_BaseChainResponse):
"Know how to chain multiple responses e.g. for tool calls"
class Options(BaseModel):
model_config = ConfigDict(extra="forbid")