mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-17 05:00:25 +00:00
model.chain() method, tools=[func1, func2]
Also fixed mypy errors, but had to drop AsyncResponse for the moment. Refs https://github.com/simonw/llm/issues/937#issuecomment-2870479021
This commit is contained in:
parent
619daa6ff2
commit
8a3c461e46
1 changed files with 58 additions and 6 deletions
|
|
@ -165,6 +165,9 @@ class Tool:
|
|||
)
|
||||
|
||||
|
||||
ToolDef = Union[Tool, Callable[..., Any]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
name: str
|
||||
|
|
@ -218,7 +221,7 @@ class Prompt:
|
|||
if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel):
|
||||
schema = schema.model_json_schema()
|
||||
self.schema = schema
|
||||
self.tools = tools or []
|
||||
self.tools = _wrap_tools(tools or [])
|
||||
self.tool_results = tool_results or []
|
||||
self.options = options or {}
|
||||
|
||||
|
|
@ -236,6 +239,18 @@ class Prompt:
|
|||
return "\n\n".join(bits)
|
||||
|
||||
|
||||
def _wrap_tools(tools: List[ToolDef]) -> List[Tool]:
|
||||
wrapped_tools = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, Tool):
|
||||
wrapped_tools.append(tool)
|
||||
elif callable(tool):
|
||||
wrapped_tools.append(Tool.function(tool))
|
||||
else:
|
||||
raise ValueError(f"Invalid tool: {tool}")
|
||||
return wrapped_tools
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BaseConversation:
|
||||
model: "_BaseModel"
|
||||
|
|
@ -894,7 +909,7 @@ class _BaseChainResponse:
|
|||
prompt: Prompt,
|
||||
model: "_BaseModel",
|
||||
stream: bool,
|
||||
conversation: _BaseConversation = None,
|
||||
conversation: _BaseConversation,
|
||||
key: Optional[str] = None,
|
||||
details: bool = False,
|
||||
chain_limit: int = 5,
|
||||
|
|
@ -904,11 +919,16 @@ class _BaseChainResponse:
|
|||
self.stream = stream
|
||||
self._key = key
|
||||
self._details = details
|
||||
self._responses: List[Union[Response, AsyncResponse]] = []
|
||||
# self._responses: List[Union[Response, AsyncResponse]] = []
|
||||
self._responses: List[Response] = []
|
||||
self.conversation = conversation
|
||||
self.chain_limit = chain_limit
|
||||
|
||||
def responses(self, details=False) -> Iterator[Union[Response, AsyncResponse, str]]:
|
||||
def responses(
|
||||
self, details=False
|
||||
) -> Iterator[
|
||||
Union[Response, str]
|
||||
]: # Iterator[Union[Response, AsyncResponse, str]]:
|
||||
prompt = self.prompt
|
||||
count = 0
|
||||
response = Response(
|
||||
|
|
@ -938,7 +958,7 @@ class _BaseChainResponse:
|
|||
)
|
||||
implementation = tools_by_name.get(tool_calls[0].name)
|
||||
if not implementation:
|
||||
# TODO: send this as an error instead
|
||||
# TODO: send this as an error instead?
|
||||
raise ValueError(f"Tool {tool_call.name} not found")
|
||||
try:
|
||||
result = implementation(**tool_call.arguments)
|
||||
|
|
@ -975,7 +995,10 @@ class _BaseChainResponse:
|
|||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
for response in self.responses():
|
||||
yield from response
|
||||
if isinstance(response, str):
|
||||
yield response
|
||||
else:
|
||||
yield from response
|
||||
|
||||
def details(self) -> Iterator[str]:
|
||||
for thing in self.responses(details=True):
|
||||
|
|
@ -1106,6 +1129,35 @@ class _Model(_BaseModel):
|
|||
key=key,
|
||||
)
|
||||
|
||||
def chain(
|
||||
self,
|
||||
prompt: Optional[str] = None,
|
||||
*,
|
||||
fragments: Optional[List[str]] = None,
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
system: Optional[str] = None,
|
||||
system_fragments: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
schema: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_results: Optional[List[ToolResult]] = None,
|
||||
details: bool = False,
|
||||
**options,
|
||||
):
|
||||
return self.conversation().chain(
|
||||
prompt=prompt,
|
||||
fragments=fragments,
|
||||
attachments=attachments,
|
||||
system=system,
|
||||
system_fragments=system_fragments,
|
||||
stream=stream,
|
||||
schema=schema,
|
||||
tools=tools,
|
||||
tool_results=tool_results,
|
||||
details=details,
|
||||
**options,
|
||||
)
|
||||
|
||||
|
||||
class Model(_Model):
|
||||
@abstractmethod
|
||||
|
|
|
|||
Loading…
Reference in a new issue