From 3b37854c26aa8bdd49b74837bd4fd2758ef832c4 Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Sun, 11 May 2025 17:14:57 -0700
Subject: [PATCH] Initial Conversation.chain() and ChainResponse

Refs https://github.com/simonw/llm/issues/937#issuecomment-2870365809
---
 llm/models.py | 126 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 126 insertions(+)

diff --git a/llm/models.py b/llm/models.py
index 61656a0..4f170bf 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -285,6 +285,42 @@ class Conversation(_BaseConversation):
             key=key,
         )
 
+    def chain(
+        self,
+        prompt: Optional[str] = None,
+        *,
+        fragments: Optional[List[str]] = None,
+        attachments: Optional[List[Attachment]] = None,
+        system: Optional[str] = None,
+        system_fragments: Optional[List[str]] = None,
+        stream: bool = True,
+        schema: Optional[Union[dict, type[BaseModel]]] = None,
+        tools: Optional[List[Tool]] = None,
+        tool_results: Optional[List[ToolResult]] = None,
+        details: bool = False,
+        **options,
+    ):
+        # Pop "key" before **options reaches Options(extra="forbid") below
+        key = options.pop("key", None)
+        self.model._validate_attachments(attachments)
+        return ChainResponse(
+            Prompt(
+                prompt,
+                fragments=fragments,
+                attachments=attachments,
+                system=system,
+                schema=schema,
+                tools=tools,
+                tool_results=tool_results,
+                system_fragments=system_fragments,
+                model=self.model,
+                options=self.model.Options(**options),
+            ),
+            model=self.model,
+            stream=stream,
+            conversation=self,
+            key=key,
+            details=details,
+        )
+
     @classmethod
     def from_row(cls, row):
         from llm import get_model
@@ -846,6 +882,96 @@ class AsyncResponse(_BaseResponse):
         return "<AsyncResponse prompt='{}' text='{}'>".format(self.prompt.prompt, text)
 
 
+class _BaseChainResponse:
+    prompt: "Prompt"
+    stream: bool
+    conversation: Optional["_BaseConversation"] = None
+    _key: Optional[str] = None
+    _responses: List[Union["Response", "AsyncResponse"]]
+
+    def __init__(
+        self,
+        prompt: Prompt,
+        model: "_BaseModel",
+        stream: bool,
+        conversation: Optional[_BaseConversation] = None,
+        key: Optional[str] = None,
+        details: bool = False,
+        chain_limit: int = 10,
+    ):
+        self.prompt = prompt
+        self.model = model
+        self.stream = stream
+        self._key = key
+        self._details = details
+        self._responses: List[Union[Response, AsyncResponse]] = []
+        self.conversation = conversation
+        self.chain_limit = chain_limit
+
+    def responses(self) -> Iterator[Union[Response, AsyncResponse]]:
+        prompt = self.prompt
+        response = Response(
+            prompt,
+            self.model,
+            self.stream,
+            key=self._key,
+            conversation=self.conversation,
+        )
+        while response:
+            yield response
+            tool_calls = response.tool_calls()
+            if not tool_calls:
+                return
+            tools_by_name = {
+                tool.name: tool.implementation for tool in response.prompt.tools
+            }
+            tool_results = []
+            for tool_call in tool_calls:
+                implementation = tools_by_name.get(tool_call.name)
+                if not implementation:
+                    # TODO: send this as an error instead
+                    raise ValueError(f"Tool {tool_call.name} not found")
+                try:
+                    result = implementation(**tool_call.arguments)
+                    if not isinstance(result, str):
+                        result = json.dumps(result, default=repr)
+                except Exception as ex:
+                    raise ValueError(f"Error executing tool {tool_call.name}: {ex}")
+                tool_results.append(
+                    ToolResult(
+                        name=tool_call.name,
+                        output=result,
+                        tool_call_id=tool_call.tool_call_id,
+                    )
+                )
+            if not tool_results:
+                break
+            response = Response(
+                Prompt(
+                    "",
+                    self.model,
+                    tools=response.prompt.tools,
+                    tool_results=tool_results,
+                    options=self.prompt.options,
+                ),
+                self.model,
+                stream=self.stream,
+                key=self._key,
+                conversation=self.conversation,
+            )
+
+    def __iter__(self) -> Iterator[str]:
+        for response in self.responses():
+            yield from response
+
+    def text(self) -> str:
+        return "".join(self)
+
+
+class ChainResponse(_BaseChainResponse):
+    "Know how to chain multiple responses e.g. for tool calls"
+
+
 class Options(BaseModel):
     model_config = ConfigDict(extra="forbid")
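
Usage sketch (not part of the diff): a minimal example of how the new
Conversation.chain() API might be exercised. Assumptions beyond the patch:
llm.get_model() and model.conversation() behave as in the existing library,
and the Tool constructor accepts name= and implementation= keyword arguments
(the patch itself only reads tool.name and tool.implementation). The model ID
and the get_weather tool are hypothetical.

    import llm
    from llm.models import Tool

    # Hypothetical tool implementation: any callable works; responses()
    # json.dumps()'s non-string results before sending them back.
    def get_weather(city: str) -> str:
        return f"Sunny in {city}"

    model = llm.get_model("gpt-4.1-mini")  # assumed model ID
    conversation = model.conversation()

    # chain() returns a ChainResponse; iterating it streams text from each
    # response in turn, executing tool calls between responses.
    chain = conversation.chain(
        "What is the weather in Paris?",
        tools=[Tool(name="get_weather", implementation=get_weather)],
    )
    for chunk in chain:
        print(chunk, end="")

    # Alternatively, chain.text() runs the chain and returns the
    # concatenated text of every response.

Note that in this initial version text() and __iter__ both call responses()
afresh, so iterating a ChainResponse twice would re-prompt the model and
re-execute tools; the _responses list is initialized but not yet used for
caching, and chain_limit is stored but not yet enforced.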