Initial Conversation.chain() and ChainResponse
Refs https://github.com/simonw/llm/issues/937#issuecomment-2870365809
parent c990578934
commit 3b37854c26

1 changed file with 126 additions and 0 deletions
llm/models.py +126 −0
@@ -285,6 +285,42 @@ class Conversation(_BaseConversation):
            key=key,
        )

    def chain(
        self,
        prompt: Optional[str] = None,
        *,
        fragments: Optional[List[str]] = None,
        attachments: Optional[List[Attachment]] = None,
        system: Optional[str] = None,
        system_fragments: Optional[List[str]] = None,
        stream: bool = True,
        schema: Optional[Union[dict, type[BaseModel]]] = None,
        tools: Optional[List[Tool]] = None,
        tool_results: Optional[List[ToolResult]] = None,
        details: bool = False,
        **options,
    ):
        self.model._validate_attachments(attachments)
        return ChainResponse(
            Prompt(
                prompt,
                fragments=fragments,
                attachments=attachments,
                system=system,
                schema=schema,
                tools=tools,
                tool_results=tool_results,
                system_fragments=system_fragments,
                model=self.model,
                options=self.model.Options(**options),
            ),
            model=self.model,
            stream=stream,
            conversation=self,
            key=options.pop("key", None),
            details=details,
        )

    @classmethod
    def from_row(cls, row):
        from llm import get_model
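For context, a minimal sketch of how the new method could be called. The model ID and prompt here are made up; `llm.get_model()` and `model.conversation()` are the library's existing entry points, and `chain()` and `text()` are what this commit adds:

    import llm

    model = llm.get_model("gpt-4.1-mini")  # made-up model ID
    conversation = model.conversation()
    # chain() returns a ChainResponse; iterating it yields text chunks
    # across however many tool-call round trips the model needs
    chain_response = conversation.chain("What is 1234 * 5678?")
    print(chain_response.text())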
@@ -846,6 +882,96 @@ class AsyncResponse(_BaseResponse):
        return "<AsyncResponse prompt='{}' text='{}'>".format(self.prompt.prompt, text)


class _BaseChainResponse:
    prompt: "Prompt"
    stream: bool
    conversation: Optional["_BaseConversation"] = None
    _key: Optional[str] = None
    _responses: List[Union["Response", "AsyncResponse"]]

    def __init__(
        self,
        prompt: Prompt,
        model: "_BaseModel",
        stream: bool,
        conversation: Optional[_BaseConversation] = None,
        key: Optional[str] = None,
        details: bool = False,
        chain_limit: int = 10,
    ):
        self.prompt = prompt
        self.model = model
        self.stream = stream
        self._key = key
        self._details = details
        self._responses: List[Union[Response, AsyncResponse]] = []
        self.conversation = conversation
        self.chain_limit = chain_limit

    def responses(self) -> Iterator[Union[Response, AsyncResponse]]:
        prompt = self.prompt
        response = Response(
            prompt,
            self.model,
            self.stream,
            key=self._key,
            conversation=self.conversation,
        )
        while response:
            yield response
            tool_calls = response.tool_calls()
            if not tool_calls:
                return
            tools_by_name = {
                tool.name: tool.implementation for tool in response.prompt.tools
            }
            tool_results = []
            for tool_call in tool_calls:
                implementation = tools_by_name.get(tool_call.name)
                if not implementation:
                    # TODO: send this as an error instead
                    raise ValueError(f"Tool {tool_call.name} not found")
                try:
                    result = implementation(**tool_call.arguments)
                    if not isinstance(result, str):
                        result = json.dumps(result, default=repr)
                except Exception as ex:
                    raise ValueError(f"Error executing tool {tool_call.name}: {ex}")
                tool_results.append(
                    ToolResult(
                        name=tool_call.name,
                        output=result,
                        tool_call_id=tool_call.tool_call_id,
                    )
                )
            if not tool_results:
                break
            response = Response(
                Prompt(
                    "",
                    self.model,
                    tools=response.prompt.tools,
                    tool_results=tool_results,
                    options=self.prompt.options,
                ),
                self.model,
                stream=self.stream,
                key=self._key,
                conversation=self.conversation,
            )

    def __iter__(self) -> Iterator[str]:
        for response in self.responses():
            yield from response

    def text(self) -> str:
        return "".join(self)


class ChainResponse(_BaseChainResponse):
    "Knows how to chain multiple responses, e.g. for tool calls"


class Options(BaseModel):
    model_config = ConfigDict(extra="forbid")
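And a sketch of the tool-calling loop this enables. The `Tool(...)` constructor arguments are an assumption based on the fields the diff reads (`name`, `implementation`); `multiply` and the model ID are invented for illustration:

    import llm
    from llm.models import Tool

    def multiply(a: int, b: int) -> int:
        # Example tool implementation: an ordinary Python function
        return a * b

    model = llm.get_model("gpt-4.1-mini")  # made-up model ID
    conversation = model.conversation()
    chain = conversation.chain(
        "Use the multiply tool to work out 12 * 34",
        tools=[Tool(name="multiply", implementation=multiply)],  # assumed constructor
    )
    # responses() yields one Response per round trip: first the model's
    # tool call, then its answer once the ToolResults are sent back
    for response in chain.responses():
        print(response.text())

One observation on the design: `chain_limit` is stored in `__init__` but never consulted in `responses()`, so presumably a follow-up commit will use it to cap the number of tool round trips.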