model.chain() method, tools=[func1, func2]

Also fixed mypy errors, but had to drop AsyncResponse for the moment.

Refs https://github.com/simonw/llm/issues/937#issuecomment-2870479021
This commit is contained in:
Simon Willison 2025-05-11 18:53:42 -07:00
parent 619daa6ff2
commit 8a3c461e46

View file

@ -165,6 +165,9 @@ class Tool:
)
ToolDef = Union[Tool, Callable[..., Any]]
@dataclass
class ToolCall:
name: str
@ -218,7 +221,7 @@ class Prompt:
if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel):
schema = schema.model_json_schema()
self.schema = schema
self.tools = tools or []
self.tools = _wrap_tools(tools or [])
self.tool_results = tool_results or []
self.options = options or {}
@ -236,6 +239,18 @@ class Prompt:
return "\n\n".join(bits)
def _wrap_tools(tools: List[ToolDef]) -> List[Tool]:
wrapped_tools = []
for tool in tools:
if isinstance(tool, Tool):
wrapped_tools.append(tool)
elif callable(tool):
wrapped_tools.append(Tool.function(tool))
else:
raise ValueError(f"Invalid tool: {tool}")
return wrapped_tools
@dataclass
class _BaseConversation:
model: "_BaseModel"
@ -894,7 +909,7 @@ class _BaseChainResponse:
prompt: Prompt,
model: "_BaseModel",
stream: bool,
conversation: _BaseConversation = None,
conversation: _BaseConversation,
key: Optional[str] = None,
details: bool = False,
chain_limit: int = 5,
@ -904,11 +919,16 @@ class _BaseChainResponse:
self.stream = stream
self._key = key
self._details = details
self._responses: List[Union[Response, AsyncResponse]] = []
# self._responses: List[Union[Response, AsyncResponse]] = []
self._responses: List[Response] = []
self.conversation = conversation
self.chain_limit = chain_limit
def responses(self, details=False) -> Iterator[Union[Response, AsyncResponse, str]]:
def responses(
self, details=False
) -> Iterator[
Union[Response, str]
]: # Iterator[Union[Response, AsyncResponse, str]]:
prompt = self.prompt
count = 0
response = Response(
@ -938,7 +958,7 @@ class _BaseChainResponse:
)
implementation = tools_by_name.get(tool_calls[0].name)
if not implementation:
# TODO: send this as an error instead
# TODO: send this as an error instead?
raise ValueError(f"Tool {tool_call.name} not found")
try:
result = implementation(**tool_call.arguments)
@ -975,7 +995,10 @@ class _BaseChainResponse:
def __iter__(self) -> Iterator[str]:
for response in self.responses():
yield from response
if isinstance(response, str):
yield response
else:
yield from response
def details(self) -> Iterator[str]:
for thing in self.responses(details=True):
@ -1106,6 +1129,35 @@ class _Model(_BaseModel):
key=key,
)
def chain(
self,
prompt: Optional[str] = None,
*,
fragments: Optional[List[str]] = None,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
system_fragments: Optional[List[str]] = None,
stream: bool = True,
schema: Optional[Union[dict, type[BaseModel]]] = None,
tools: Optional[List[Tool]] = None,
tool_results: Optional[List[ToolResult]] = None,
details: bool = False,
**options,
):
return self.conversation().chain(
prompt=prompt,
fragments=fragments,
attachments=attachments,
system=system,
system_fragments=system_fragments,
stream=stream,
schema=schema,
tools=tools,
tool_results=tool_results,
details=details,
**options,
)
class Model(_Model):
@abstractmethod