From 8a3c461e46d6f1cb4bc43c03052b674cafa82513 Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Sun, 11 May 2025 18:53:42 -0700
Subject: [PATCH] model.chain() method, tools=[func1, func2]

Also fixed mypy errors, but had to drop AsyncResponse for the moment.

Refs https://github.com/simonw/llm/issues/937#issuecomment-2870479021
---
 llm/models.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 58 insertions(+), 6 deletions(-)

diff --git a/llm/models.py b/llm/models.py
index 1c2216a..0c0f6ab 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -165,6 +165,9 @@ class Tool:
         )
 
 
+ToolDef = Union[Tool, Callable[..., Any]]
+
+
 @dataclass
 class ToolCall:
     name: str
@@ -218,7 +221,7 @@ class Prompt:
         if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel):
             schema = schema.model_json_schema()
         self.schema = schema
-        self.tools = tools or []
+        self.tools = _wrap_tools(tools or [])
         self.tool_results = tool_results or []
         self.options = options or {}
 
@@ -236,6 +239,18 @@ class Prompt:
         return "\n\n".join(bits)
 
 
+def _wrap_tools(tools: List[ToolDef]) -> List[Tool]:
+    wrapped_tools = []
+    for tool in tools:
+        if isinstance(tool, Tool):
+            wrapped_tools.append(tool)
+        elif callable(tool):
+            wrapped_tools.append(Tool.function(tool))
+        else:
+            raise ValueError(f"Invalid tool: {tool}")
+    return wrapped_tools
+
+
 @dataclass
 class _BaseConversation:
     model: "_BaseModel"
@@ -894,7 +909,7 @@ class _BaseChainResponse:
         prompt: Prompt,
         model: "_BaseModel",
         stream: bool,
-        conversation: _BaseConversation = None,
+        conversation: _BaseConversation,
         key: Optional[str] = None,
         details: bool = False,
         chain_limit: int = 5,
@@ -904,11 +919,16 @@ class _BaseChainResponse:
         self.stream = stream
         self._key = key
         self._details = details
-        self._responses: List[Union[Response, AsyncResponse]] = []
+        # self._responses: List[Union[Response, AsyncResponse]] = []
+        self._responses: List[Response] = []
         self.conversation = conversation
         self.chain_limit = chain_limit
 
-    def responses(self, details=False) -> Iterator[Union[Response, AsyncResponse, str]]:
+    def responses(
+        self, details=False
+    ) -> Iterator[
+        Union[Response, str]
+    ]:  # Iterator[Union[Response, AsyncResponse, str]]:
         prompt = self.prompt
         count = 0
         response = Response(
@@ -938,7 +958,7 @@ class _BaseChainResponse:
             )
             implementation = tools_by_name.get(tool_calls[0].name)
             if not implementation:
-                # TODO: send this as an error instead
+                # TODO: send this as an error instead?
                 raise ValueError(f"Tool {tool_call.name} not found")
             try:
                 result = implementation(**tool_call.arguments)
@@ -975,7 +995,10 @@ class _BaseChainResponse:
 
     def __iter__(self) -> Iterator[str]:
         for response in self.responses():
-            yield from response
+            if isinstance(response, str):
+                yield response
+            else:
+                yield from response
 
     def details(self) -> Iterator[str]:
         for thing in self.responses(details=True):
@@ -1106,6 +1129,35 @@ class _Model(_BaseModel):
             key=key,
         )
 
+    def chain(
+        self,
+        prompt: Optional[str] = None,
+        *,
+        fragments: Optional[List[str]] = None,
+        attachments: Optional[List[Attachment]] = None,
+        system: Optional[str] = None,
+        system_fragments: Optional[List[str]] = None,
+        stream: bool = True,
+        schema: Optional[Union[dict, type[BaseModel]]] = None,
+        tools: Optional[List[Tool]] = None,
+        tool_results: Optional[List[ToolResult]] = None,
+        details: bool = False,
+        **options,
+    ):
+        return self.conversation().chain(
+            prompt=prompt,
+            fragments=fragments,
+            attachments=attachments,
+            system=system,
+            system_fragments=system_fragments,
+            stream=stream,
+            schema=schema,
+            tools=tools,
+            tool_results=tool_results,
+            details=details,
+            **options,
+        )
+
 
 class Model(_Model):
     @abstractmethod