mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
Async tool support (#1063)
* Sync models can now call async tools, refs #987 * Test for async tool functions in sync context, refs #987 * Test for asyncio tools, plus test that they run in parallel * Docs for async tool usage
This commit is contained in:
parent
0ee1ba3a65
commit
3cb875fa3d
6 changed files with 561 additions and 86 deletions
|
|
@ -89,6 +89,19 @@ for chunk in model.chain(
|
|||
```
|
||||
This will stream each of the chain of responses in turn as they are generated.
|
||||
|
||||
You can access the individual responses that make up the chain using `chain.responses()`. This can be iterated over as the chain executes like this:
|
||||
|
||||
```python
|
||||
chain = model.chain(
|
||||
"Convert panda to upper",
|
||||
tools=[upper],
|
||||
)
|
||||
for response in chain.responses():
|
||||
print(response.prompt)
|
||||
for chunk in response:
|
||||
print(chunk, end="", flush=True)
|
||||
```
|
||||
|
||||
(python-api-system-prompts)=
|
||||
|
||||
### System prompts
|
||||
|
|
@ -351,10 +364,9 @@ model = llm.get_async_model("gpt-4o")
|
|||
You can then run a prompt using `await model.prompt(...)`:
|
||||
|
||||
```python
|
||||
response = await model.prompt(
|
||||
print(await model.prompt(
|
||||
"Five surprising names for a pet pelican"
|
||||
)
|
||||
print(await response.text())
|
||||
).text())
|
||||
```
|
||||
Or use `async for chunk in ...` to stream the response as it is generated:
|
||||
```python
|
||||
|
|
@ -365,6 +377,56 @@ async for chunk in model.prompt(
|
|||
```
|
||||
This `await model.prompt()` method takes the same arguments as the synchronous `model.prompt()` method, for options and attachments and `key=` and suchlike.
|
||||
|
||||
(python-api-async-tools)=
|
||||
|
||||
### Tool functions can be sync or async
|
||||
|
||||
{ref}`Tool functions <python-api-tools>` can be both synchronous or asynchronous. The latter are defined using `async def tool_name(...)`. Either kind of function can be passed to the `tools=[...]` parameter.
|
||||
|
||||
If an `async def` function is used in a synchronous context LLM will automatically execute it in a thread pool using `asyncio.run()`. This means the following will work even in non-asynchronous Python scripts:
|
||||
|
||||
```python
|
||||
async def hello(name: str) -> str:
|
||||
"Say hello to name"
|
||||
return "Hello there " + name
|
||||
|
||||
model = llm.get_model("gpt-4.1-mini")
|
||||
chain_response = model.chain(
|
||||
"Say hello to Percival", tools=[hello]
|
||||
)
|
||||
print(chain_response.text())
|
||||
```
|
||||
|
||||
### Tool use for async models
|
||||
|
||||
Tool use is also supported for async models, using either synchronous or asynchronous tool functions. Synchronous functions will block the event loop so only use those in asynchronous context if you are certain they are extremely fast.
|
||||
|
||||
The `response.execute_tool_calls()` and `chain_response.text()` and `chain_response.responses()` methods must all be awaited when run against asynchronous models:
|
||||
|
||||
```python
|
||||
import llm
|
||||
model = llm.get_async_model("gpt-4.1")
|
||||
|
||||
def upper(string):
|
||||
"Converts string to uppercase"
|
||||
return string.upper()
|
||||
|
||||
chain = model.chain(
|
||||
"Convert panda to uppercase then pelican to uppercase",
|
||||
tools=[upper],
|
||||
after_call=print
|
||||
)
|
||||
print(await chain.text())
|
||||
```
|
||||
|
||||
To iterate over the chained response output as it arrives use `async for`:
|
||||
```python
|
||||
async for chunk in model.chain(
|
||||
"Convert panda to uppercase then pelican to uppercase",
|
||||
tools=[upper]
|
||||
):
|
||||
print(chunk, end="", flush=True)
|
||||
```
|
||||
(python-api-conversations)=
|
||||
|
||||
## Conversations
|
||||
|
|
|
|||
|
|
@ -763,16 +763,38 @@ class AsyncChat(_Shared, AsyncKeyModel):
|
|||
**kwargs,
|
||||
)
|
||||
chunks = []
|
||||
tool_calls = {}
|
||||
async for chunk in completion:
|
||||
if chunk.usage:
|
||||
usage = chunk.usage.model_dump()
|
||||
chunks.append(chunk)
|
||||
if chunk.usage:
|
||||
usage = chunk.usage.model_dump()
|
||||
if chunk.choices and chunk.choices[0].delta:
|
||||
for tool_call in chunk.choices[0].delta.tool_calls or []:
|
||||
index = tool_call.index
|
||||
if index not in tool_calls:
|
||||
tool_calls[index] = tool_call
|
||||
tool_calls[
|
||||
index
|
||||
].function.arguments += tool_call.function.arguments
|
||||
try:
|
||||
content = chunk.choices[0].delta.content
|
||||
except IndexError:
|
||||
content = None
|
||||
if content is not None:
|
||||
yield content
|
||||
if tool_calls:
|
||||
for value in tool_calls.values():
|
||||
# value.function looks like this:
|
||||
# ChoiceDeltaToolCallFunction(arguments='{"city":"San Francisco"}', name='get_weather')
|
||||
response.add_tool_call(
|
||||
llm.ToolCall(
|
||||
tool_call_id=value.id,
|
||||
name=value.function.name,
|
||||
arguments=json.loads(value.function.arguments),
|
||||
)
|
||||
)
|
||||
response.response_json = remove_dict_none_values(combine_chunks(chunks))
|
||||
else:
|
||||
completion = await client.chat.completions.create(
|
||||
|
|
@ -783,6 +805,14 @@ class AsyncChat(_Shared, AsyncKeyModel):
|
|||
)
|
||||
response.response_json = remove_dict_none_values(completion.model_dump())
|
||||
usage = completion.usage.model_dump()
|
||||
for tool_call in completion.choices[0].message.tool_calls or []:
|
||||
response.add_tool_call(
|
||||
llm.ToolCall(
|
||||
tool_call_id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
arguments=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
)
|
||||
if completion.choices[0].message.content is not None:
|
||||
yield completion.choices[0].message.content
|
||||
self.set_usage(response, usage)
|
||||
|
|
|
|||
461
llm/models.py
461
llm/models.py
|
|
@ -12,6 +12,8 @@ import time
|
|||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
|
|
@ -375,6 +377,53 @@ class Conversation(_BaseConversation):
|
|||
|
||||
@dataclass
|
||||
class AsyncConversation(_BaseConversation):
|
||||
def chain(
|
||||
self,
|
||||
prompt: Optional[str] = None,
|
||||
*,
|
||||
fragments: Optional[List[str]] = None,
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
system: Optional[str] = None,
|
||||
system_fragments: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
schema: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_results: Optional[List[ToolResult]] = None,
|
||||
chain_limit: Optional[int] = None,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
details: bool = False,
|
||||
key: Optional[str] = None,
|
||||
options: Optional[dict] = None,
|
||||
) -> "AsyncChainResponse":
|
||||
self.model._validate_attachments(attachments)
|
||||
return AsyncChainResponse(
|
||||
Prompt(
|
||||
prompt,
|
||||
fragments=fragments,
|
||||
attachments=attachments,
|
||||
system=system,
|
||||
schema=schema,
|
||||
tools=tools,
|
||||
tool_results=tool_results,
|
||||
system_fragments=system_fragments,
|
||||
model=self.model,
|
||||
options=self.model.Options(**(options or {})),
|
||||
),
|
||||
model=self.model,
|
||||
stream=stream,
|
||||
conversation=self,
|
||||
key=key,
|
||||
details=details,
|
||||
before_call=before_call,
|
||||
after_call=after_call,
|
||||
chain_limit=chain_limit,
|
||||
)
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
prompt: Optional[str] = None,
|
||||
|
|
@ -784,37 +833,59 @@ class Response(_BaseResponse):
|
|||
def execute_tool_calls(
|
||||
self,
|
||||
*,
|
||||
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
|
||||
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
) -> List[ToolResult]:
|
||||
tool_results = []
|
||||
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
|
||||
# TODO: make this work async
|
||||
for tool_call in self.tool_calls():
|
||||
tool = tools_by_name.get(tool_call.name)
|
||||
if tool is None:
|
||||
raise CancelToolCall("Unknown tool: {}".format(tool_call.name))
|
||||
|
||||
if before_call:
|
||||
# This may raise CancelToolCall:
|
||||
before_call(tool, tool_call)
|
||||
cb_result = before_call(tool, tool_call)
|
||||
if inspect.isawaitable(cb_result):
|
||||
raise TypeError(
|
||||
"Asynchronous 'before_call' callback provided to a synchronous tool execution context. "
|
||||
"Please use an async chain/response or a synchronous callback."
|
||||
)
|
||||
|
||||
if not tool.implementation:
|
||||
raise CancelToolCall(
|
||||
"No implementation available for tool: {}".format(tool_call.name)
|
||||
)
|
||||
|
||||
try:
|
||||
result = tool.implementation(**tool_call.arguments)
|
||||
if asyncio.iscoroutinefunction(tool.implementation):
|
||||
result = asyncio.run(tool.implementation(**tool_call.arguments))
|
||||
else:
|
||||
result = tool.implementation(**tool_call.arguments)
|
||||
|
||||
if not isinstance(result, str):
|
||||
result = json.dumps(result, default=repr)
|
||||
except Exception as ex:
|
||||
result = f"Error: {ex}"
|
||||
tool_result = ToolResult(
|
||||
|
||||
tool_result_obj = ToolResult(
|
||||
name=tool_call.name,
|
||||
output=result,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
)
|
||||
|
||||
if after_call:
|
||||
after_call(tool, tool_call, tool_result)
|
||||
tool_results.append(tool_result)
|
||||
cb_result = after_call(tool, tool_call, tool_result_obj)
|
||||
if inspect.isawaitable(cb_result):
|
||||
raise TypeError(
|
||||
"Asynchronous 'after_call' callback provided to a synchronous tool execution context. "
|
||||
"Please use an async chain/response or a synchronous callback."
|
||||
)
|
||||
tool_results.append(tool_result_obj)
|
||||
return tool_results
|
||||
|
||||
def tool_calls(self) -> List[ToolCall]:
|
||||
|
|
@ -901,30 +972,131 @@ class AsyncResponse(_BaseResponse):
|
|||
self.done_callbacks.append(callback)
|
||||
else:
|
||||
if callable(callback):
|
||||
callback = callback(self)
|
||||
if asyncio.iscoroutine(callback):
|
||||
# Ensure we handle both sync and async callbacks correctly
|
||||
processed_callback = callback(self)
|
||||
if inspect.isawaitable(processed_callback):
|
||||
await processed_callback
|
||||
elif inspect.isawaitable(callback):
|
||||
await callback
|
||||
|
||||
async def _on_done(self):
|
||||
for callback in self.done_callbacks:
|
||||
if callable(callback):
|
||||
callback = callback(self)
|
||||
if asyncio.iscoroutine(callback):
|
||||
await callback
|
||||
for callback_func in self.done_callbacks:
|
||||
if callable(callback_func):
|
||||
processed_callback = callback_func(self)
|
||||
if inspect.isawaitable(processed_callback):
|
||||
await processed_callback
|
||||
elif inspect.isawaitable(callback_func):
|
||||
await callback_func
|
||||
|
||||
async def execute_tool_calls(
|
||||
self,
|
||||
*,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
) -> List[ToolResult]:
|
||||
tool_calls_list = await self.tool_calls()
|
||||
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
|
||||
|
||||
indexed_results: List[tuple[int, ToolResult]] = []
|
||||
async_tasks: List[asyncio.Task] = []
|
||||
|
||||
for idx, tc in enumerate(tool_calls_list):
|
||||
tool = tools_by_name.get(tc.name)
|
||||
if tool is None:
|
||||
raise CancelToolCall(f"Unknown tool: {tc.name}")
|
||||
if not tool.implementation:
|
||||
raise CancelToolCall(f"No implementation for tool: {tc.name}")
|
||||
|
||||
# If it's an async implementation, wrap it
|
||||
if inspect.iscoroutinefunction(tool.implementation):
|
||||
|
||||
async def run_async(tc=tc, tool=tool, idx=idx):
|
||||
# before_call inside the task
|
||||
if before_call:
|
||||
cb = before_call(tool, tc)
|
||||
if inspect.isawaitable(cb):
|
||||
await cb
|
||||
|
||||
try:
|
||||
result = await tool.implementation(**tc.arguments)
|
||||
output = (
|
||||
result
|
||||
if isinstance(result, str)
|
||||
else json.dumps(result, default=repr)
|
||||
)
|
||||
except Exception as ex:
|
||||
output = f"Error: {ex}"
|
||||
|
||||
tr = ToolResult(
|
||||
name=tc.name,
|
||||
output=output,
|
||||
tool_call_id=tc.tool_call_id,
|
||||
)
|
||||
|
||||
# after_call inside the task
|
||||
if after_call:
|
||||
cb2 = after_call(tool, tc, tr)
|
||||
if inspect.isawaitable(cb2):
|
||||
await cb2
|
||||
|
||||
return idx, tr
|
||||
|
||||
async_tasks.append(asyncio.create_task(run_async()))
|
||||
|
||||
else:
|
||||
# Sync implementation: do hooks and call inline
|
||||
if before_call:
|
||||
cb = before_call(tool, tc)
|
||||
if inspect.isawaitable(cb):
|
||||
await cb
|
||||
|
||||
try:
|
||||
res = tool.implementation(**tc.arguments)
|
||||
if inspect.isawaitable(res):
|
||||
res = await res
|
||||
output = (
|
||||
res if isinstance(res, str) else json.dumps(res, default=repr)
|
||||
)
|
||||
except Exception as ex:
|
||||
output = f"Error: {ex}"
|
||||
|
||||
tr = ToolResult(
|
||||
name=tc.name,
|
||||
output=output,
|
||||
tool_call_id=tc.tool_call_id,
|
||||
)
|
||||
|
||||
if after_call:
|
||||
cb2 = after_call(tool, tc, tr)
|
||||
if inspect.isawaitable(cb2):
|
||||
await cb2
|
||||
|
||||
indexed_results.append((idx, tr))
|
||||
|
||||
# Await all async tasks in parallel
|
||||
if async_tasks:
|
||||
indexed_results.extend(await asyncio.gather(*async_tasks))
|
||||
|
||||
# Reorder by original index
|
||||
indexed_results.sort(key=lambda x: x[0])
|
||||
return [tr for _, tr in indexed_results]
|
||||
|
||||
def __aiter__(self):
|
||||
self._start = time.monotonic()
|
||||
self._start_utcnow = datetime.datetime.now(datetime.timezone.utc)
|
||||
if self._done:
|
||||
self._iter_chunks = list(self._chunks) # Make a copy for iteration
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> str:
|
||||
if self._done:
|
||||
if not self._chunks:
|
||||
raise StopAsyncIteration
|
||||
chunk = self._chunks.pop(0)
|
||||
if not self._chunks:
|
||||
raise StopAsyncIteration
|
||||
return chunk
|
||||
if hasattr(self, "_iter_chunks") and self._iter_chunks:
|
||||
return self._iter_chunks.pop(0)
|
||||
raise StopAsyncIteration
|
||||
|
||||
if not hasattr(self, "_generator"):
|
||||
if isinstance(self.model, AsyncModel):
|
||||
|
|
@ -955,13 +1127,17 @@ class AsyncResponse(_BaseResponse):
|
|||
self.conversation.responses.append(self)
|
||||
self._end = time.monotonic()
|
||||
self._done = True
|
||||
if hasattr(self, "_generator"):
|
||||
del self._generator
|
||||
await self._on_done()
|
||||
raise
|
||||
|
||||
async def _force(self):
|
||||
if not self._done:
|
||||
async for _ in self:
|
||||
pass
|
||||
temp_chunks = []
|
||||
async for chunk in self:
|
||||
temp_chunks.append(chunk)
|
||||
# This should populate self._chunks
|
||||
return self
|
||||
|
||||
def text_or_raise(self) -> str:
|
||||
|
|
@ -1007,20 +1183,44 @@ class AsyncResponse(_BaseResponse):
|
|||
|
||||
async def to_sync_response(self) -> Response:
|
||||
await self._force()
|
||||
# This conversion might be tricky if the model is AsyncModel,
|
||||
# as Response expects a sync Model. For simplicity, we'll assume
|
||||
# the primary use case is data transfer after completion.
|
||||
# The model type on the new Response might need careful handling
|
||||
# if it's intended for further execution.
|
||||
# For now, let's assume self.model can be cast or is compatible.
|
||||
sync_model = self.model
|
||||
if not isinstance(self.model, (Model, KeyModel)):
|
||||
# This is a placeholder. A proper conversion or shared base might be needed
|
||||
# if the sync_response needs to be fully functional with its model.
|
||||
# For now, we pass the async model, which might limit what sync_response can do.
|
||||
pass
|
||||
|
||||
response = Response(
|
||||
self.prompt,
|
||||
self.model,
|
||||
sync_model, # This might need adjustment based on how Model/AsyncModel relate
|
||||
self.stream,
|
||||
conversation=self.conversation,
|
||||
# conversation type needs to be compatible too.
|
||||
conversation=(
|
||||
self.conversation.to_sync_conversation()
|
||||
if self.conversation
|
||||
and hasattr(self.conversation, "to_sync_conversation")
|
||||
else None
|
||||
),
|
||||
)
|
||||
response._chunks = self._chunks
|
||||
response._done = True
|
||||
response.id = self.id
|
||||
response._chunks = list(self._chunks) # Copy chunks
|
||||
response._done = self._done
|
||||
response._end = self._end
|
||||
response._start = self._start
|
||||
response._start_utcnow = self._start_utcnow
|
||||
response.input_tokens = self.input_tokens
|
||||
response.output_tokens = self.output_tokens
|
||||
response.token_details = self.token_details
|
||||
response._prompt_json = self._prompt_json
|
||||
response.response_json = self.response_json
|
||||
response._tool_calls = list(self._tool_calls)
|
||||
response.attachments = list(self.attachments)
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
|
|
@ -1059,7 +1259,6 @@ class _BaseChainResponse:
|
|||
stream: bool
|
||||
conversation: Optional["_BaseConversation"] = None
|
||||
_key: Optional[str] = None
|
||||
_responses: List[Union["Response", "AsyncResponse"]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -1070,51 +1269,65 @@ class _BaseChainResponse:
|
|||
key: Optional[str] = None,
|
||||
details: bool = False,
|
||||
chain_limit: Optional[int] = 10,
|
||||
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
|
||||
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
):
|
||||
self.prompt = prompt
|
||||
self.model = model
|
||||
self.stream = stream
|
||||
self._key = key
|
||||
self._details = details
|
||||
# self._responses: List[Union[Response, AsyncResponse]] = []
|
||||
self._responses: List[Response] = []
|
||||
self._responses: List[Any] = []
|
||||
self.conversation = conversation
|
||||
self.chain_limit = chain_limit
|
||||
self.before_call = before_call
|
||||
self.after_call = after_call
|
||||
|
||||
def responses(
|
||||
self, details=False
|
||||
) -> Iterator[
|
||||
Union[Response, str]
|
||||
]: # Iterator[Union[Response, AsyncResponse, str]]:
|
||||
def log_to_db(self, db):
|
||||
for response in self._responses:
|
||||
if isinstance(response, AsyncResponse):
|
||||
sync_response = asyncio.run(response.to_sync_response())
|
||||
elif isinstance(response, Response):
|
||||
sync_response = response
|
||||
else:
|
||||
assert False, "Should have been a Response or AsyncResponse"
|
||||
sync_response.log_to_db(db)
|
||||
|
||||
|
||||
class ChainResponse(_BaseChainResponse):
|
||||
_responses: List["Response"]
|
||||
|
||||
def responses(self) -> Iterator[Response]:
|
||||
prompt = self.prompt
|
||||
count = 0
|
||||
response = Response(
|
||||
current_response: Optional[Response] = Response(
|
||||
prompt,
|
||||
self.model,
|
||||
self.stream,
|
||||
key=self._key,
|
||||
conversation=self.conversation,
|
||||
)
|
||||
while response:
|
||||
while current_response:
|
||||
count += 1
|
||||
yield response
|
||||
self._responses.append(response)
|
||||
if self.chain_limit and count > self.chain_limit:
|
||||
raise ValueError(f"Chain limit of {self.chain_limit} exceeded. ")
|
||||
yield current_response
|
||||
self._responses.append(current_response)
|
||||
if self.chain_limit and count >= self.chain_limit:
|
||||
raise ValueError(f"Chain limit of {self.chain_limit} exceeded.")
|
||||
|
||||
# This could raise llm.CancelToolCall:
|
||||
tool_results = response.execute_tool_calls(
|
||||
tool_results = current_response.execute_tool_calls(
|
||||
before_call=self.before_call, after_call=self.after_call
|
||||
)
|
||||
if tool_results:
|
||||
response = Response(
|
||||
current_response = Response(
|
||||
Prompt(
|
||||
"",
|
||||
"", # Next prompt is empty, tools drive it
|
||||
self.model,
|
||||
tools=response.prompt.tools,
|
||||
tools=current_response.prompt.tools,
|
||||
tool_results=tool_results,
|
||||
options=self.prompt.options,
|
||||
),
|
||||
|
|
@ -1124,27 +1337,71 @@ class _BaseChainResponse:
|
|||
conversation=self.conversation,
|
||||
)
|
||||
else:
|
||||
current_response = None
|
||||
break
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
for response in self.responses():
|
||||
if isinstance(response, str):
|
||||
yield response
|
||||
else:
|
||||
yield from response
|
||||
for response_item in self.responses():
|
||||
yield from response_item
|
||||
|
||||
def text(self) -> str:
|
||||
return "".join(self)
|
||||
|
||||
def log_to_db(self, db):
|
||||
for response in self._responses:
|
||||
if isinstance(response, AsyncResponse):
|
||||
response = asyncio.run(response.to_sync_response())
|
||||
response.log_to_db(db)
|
||||
|
||||
class AsyncChainResponse(_BaseChainResponse):
|
||||
_responses: List["AsyncResponse"]
|
||||
|
||||
class ChainResponse(_BaseChainResponse):
|
||||
"Know how to chain multiple responses e.g. for tool calls"
|
||||
async def responses(self) -> AsyncIterator[AsyncResponse]:
|
||||
prompt = self.prompt
|
||||
count = 0
|
||||
current_response: Optional[AsyncResponse] = AsyncResponse(
|
||||
prompt,
|
||||
self.model,
|
||||
self.stream,
|
||||
key=self._key,
|
||||
conversation=self.conversation,
|
||||
)
|
||||
while current_response:
|
||||
count += 1
|
||||
yield current_response
|
||||
self._responses.append(current_response)
|
||||
|
||||
if self.chain_limit and count >= self.chain_limit:
|
||||
raise ValueError(f"Chain limit of {self.chain_limit} exceeded.")
|
||||
|
||||
# This could raise llm.CancelToolCall:
|
||||
tool_results = await current_response.execute_tool_calls(
|
||||
before_call=self.before_call, after_call=self.after_call
|
||||
)
|
||||
if tool_results:
|
||||
prompt = Prompt(
|
||||
"",
|
||||
self.model,
|
||||
tools=current_response.prompt.tools,
|
||||
tool_results=tool_results,
|
||||
options=self.prompt.options,
|
||||
)
|
||||
current_response = AsyncResponse(
|
||||
prompt,
|
||||
self.model,
|
||||
stream=self.stream,
|
||||
key=self._key,
|
||||
conversation=self.conversation,
|
||||
)
|
||||
else:
|
||||
current_response = None
|
||||
break
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[str]:
|
||||
async for response_item in self.responses():
|
||||
async for chunk in response_item:
|
||||
yield chunk
|
||||
|
||||
async def text(self) -> str:
|
||||
all_chunks = []
|
||||
async for chunk in self:
|
||||
all_chunks.append(chunk)
|
||||
return "".join(all_chunks)
|
||||
|
||||
|
||||
class Options(BaseModel):
|
||||
|
|
@ -1171,13 +1428,13 @@ class _get_key_mixin:
|
|||
return self.key
|
||||
|
||||
# Attempt to load a key using llm.get_key()
|
||||
key = get_key(
|
||||
key_value = get_key(
|
||||
explicit_key=explicit_key,
|
||||
key_alias=self.needs_key,
|
||||
env_var=self.key_env_var,
|
||||
)
|
||||
if key:
|
||||
return key
|
||||
if key_value:
|
||||
return key_value
|
||||
|
||||
# Show a useful error message
|
||||
message = "No key found - add one using 'llm keys set {}'".format(
|
||||
|
|
@ -1241,7 +1498,7 @@ class _Model(_BaseModel):
|
|||
tool_results: Optional[List[ToolResult]] = None,
|
||||
**options,
|
||||
) -> Response:
|
||||
key = options.pop("key", None)
|
||||
key_value = options.pop("key", None)
|
||||
self._validate_attachments(attachments)
|
||||
return Response(
|
||||
Prompt(
|
||||
|
|
@ -1258,7 +1515,7 @@ class _Model(_BaseModel):
|
|||
),
|
||||
self,
|
||||
stream,
|
||||
key=key,
|
||||
key=key_value,
|
||||
)
|
||||
|
||||
def chain(
|
||||
|
|
@ -1278,7 +1535,7 @@ class _Model(_BaseModel):
|
|||
details: bool = False,
|
||||
key: Optional[str] = None,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
) -> ChainResponse:
|
||||
return self.conversation().chain(
|
||||
prompt=prompt,
|
||||
fragments=fragments,
|
||||
|
|
@ -1340,7 +1597,7 @@ class _AsyncModel(_BaseModel):
|
|||
stream: bool = True,
|
||||
**options,
|
||||
) -> AsyncResponse:
|
||||
key = options.pop("key", None)
|
||||
key_value = options.pop("key", None)
|
||||
self._validate_attachments(attachments)
|
||||
return AsyncResponse(
|
||||
Prompt(
|
||||
|
|
@ -1357,11 +1614,47 @@ class _AsyncModel(_BaseModel):
|
|||
),
|
||||
self,
|
||||
stream,
|
||||
key=key,
|
||||
key=key_value,
|
||||
)
|
||||
|
||||
def chain(self, *args, **kwargs):
|
||||
raise NotImplementedError("AsyncModel does not yet support tools")
|
||||
def chain(
|
||||
self,
|
||||
prompt: Optional[str] = None,
|
||||
*,
|
||||
fragments: Optional[List[str]] = None,
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
system: Optional[str] = None,
|
||||
system_fragments: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
schema: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_results: Optional[List[ToolResult]] = None,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
details: bool = False,
|
||||
key: Optional[str] = None,
|
||||
options: Optional[dict] = None,
|
||||
) -> AsyncChainResponse:
|
||||
return self.conversation().chain(
|
||||
prompt=prompt,
|
||||
fragments=fragments,
|
||||
attachments=attachments,
|
||||
system=system,
|
||||
system_fragments=system_fragments,
|
||||
stream=stream,
|
||||
schema=schema,
|
||||
tools=tools,
|
||||
tool_results=tool_results,
|
||||
before_call=before_call,
|
||||
after_call=after_call,
|
||||
details=details,
|
||||
key=key,
|
||||
options=options,
|
||||
)
|
||||
|
||||
|
||||
class AsyncModel(_AsyncModel):
|
||||
|
|
@ -1373,7 +1666,9 @@ class AsyncModel(_AsyncModel):
|
|||
response: AsyncResponse,
|
||||
conversation: Optional[AsyncConversation],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
yield ""
|
||||
if False: # Ensure it's a generator type, but don't actually yield.
|
||||
yield ""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncKeyModel(_AsyncModel):
|
||||
|
|
@ -1386,7 +1681,9 @@ class AsyncKeyModel(_AsyncModel):
|
|||
conversation: Optional[AsyncConversation],
|
||||
key: Optional[str],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
yield ""
|
||||
if False: # Ensure it's a generator type
|
||||
yield ""
|
||||
pass
|
||||
|
||||
|
||||
class EmbeddingModel(ABC, _get_key_mixin):
|
||||
|
|
@ -1418,20 +1715,20 @@ class EmbeddingModel(ABC, _get_key_mixin):
|
|||
) -> Iterator[List[float]]:
|
||||
"Embed multiple items in batches according to the model batch_size"
|
||||
iter_items = iter(items)
|
||||
batch_size = self.batch_size if batch_size is None else batch_size
|
||||
effective_batch_size = self.batch_size if batch_size is None else batch_size
|
||||
if (not self.supports_binary) or (not self.supports_text):
|
||||
|
||||
def checking_iter(items):
|
||||
for item in items:
|
||||
self._check(item)
|
||||
yield item
|
||||
def checking_iter(inner_items):
|
||||
for item_to_check in inner_items:
|
||||
self._check(item_to_check)
|
||||
yield item_to_check
|
||||
|
||||
iter_items = checking_iter(items)
|
||||
if batch_size is None:
|
||||
if effective_batch_size is None:
|
||||
yield from self.embed_batch(iter_items)
|
||||
return
|
||||
while True:
|
||||
batch_items = list(islice(iter_items, batch_size))
|
||||
batch_items = list(islice(iter_items, effective_batch_size))
|
||||
if not batch_items:
|
||||
break
|
||||
yield from self.embed_batch(batch_items)
|
||||
|
|
@ -1457,14 +1754,14 @@ class ModelWithAliases:
|
|||
aliases: Set[str]
|
||||
|
||||
def matches(self, query: str) -> bool:
|
||||
query = query.lower()
|
||||
query_lower = query.lower()
|
||||
all_strings: List[str] = []
|
||||
all_strings.extend(self.aliases)
|
||||
if self.model:
|
||||
all_strings.append(str(self.model))
|
||||
if self.async_model:
|
||||
all_strings.append(str(self.async_model.model_id))
|
||||
return any(query in alias.lower() for alias in all_strings)
|
||||
return any(query_lower in alias.lower() for alias in all_strings)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -1473,11 +1770,11 @@ class EmbeddingModelWithAliases:
|
|||
aliases: Set[str]
|
||||
|
||||
def matches(self, query: str) -> bool:
|
||||
query = query.lower()
|
||||
query_lower = query.lower()
|
||||
all_strings: List[str] = []
|
||||
all_strings.extend(self.aliases)
|
||||
all_strings.append(str(self.model))
|
||||
return any(query in alias.lower() for alias in all_strings)
|
||||
return any(query_lower in alias.lower() for alias in all_strings)
|
||||
|
||||
|
||||
def _conversation_name(text):
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ test = [
|
|||
"types-click",
|
||||
"types-PyYAML",
|
||||
"types-setuptools",
|
||||
"llm-echo==0.3a2",
|
||||
"llm-echo==0.3a3",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
|
|
|||
|
|
@ -470,3 +470,11 @@ def collection():
|
|||
@pytest.fixture(scope="module")
|
||||
def vcr_config():
|
||||
return {"filter_headers": ["Authorization"]}
|
||||
|
||||
|
||||
def extract_braces(s):
|
||||
first = s.find("{")
|
||||
last = s.rfind("}")
|
||||
if first != -1 and last != -1 and first < last:
|
||||
return s[first : last + 1]
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import asyncio
|
||||
import json
|
||||
import llm
|
||||
from llm.migrations import migrate
|
||||
import os
|
||||
import pytest
|
||||
import sqlite_utils
|
||||
import time
|
||||
|
||||
|
||||
API_KEY = os.environ.get("PYTEST_OPENAI_API_KEY", None) or "badkey"
|
||||
|
|
@ -92,3 +95,78 @@ def test_tool_use_chain_of_two_calls(vcr):
|
|||
assert second.tool_calls()[0].arguments == {"population": 123124}
|
||||
assert third.prompt.tool_results[0].output == "true"
|
||||
assert third.tool_calls() == []
|
||||
|
||||
|
||||
def test_tool_use_async_tool_function():
|
||||
async def hello():
|
||||
return "world"
|
||||
|
||||
model = llm.get_model("echo")
|
||||
chain_response = model.chain(
|
||||
json.dumps({"tool_calls": [{"name": "hello"}]}), tools=[hello]
|
||||
)
|
||||
output = chain_response.text()
|
||||
# That's two JSON objects separated by '\n}{\n'
|
||||
bits = output.split("\n}{\n")
|
||||
assert len(bits) == 2
|
||||
objects = [json.loads(bits[0] + "}"), json.loads("{" + bits[1])]
|
||||
assert objects == [
|
||||
{"prompt": "", "system": "", "attachments": [], "stream": True, "previous": []},
|
||||
{
|
||||
"prompt": "",
|
||||
"system": "",
|
||||
"attachments": [],
|
||||
"stream": True,
|
||||
"previous": [{"prompt": '{"tool_calls": [{"name": "hello"}]}'}],
|
||||
"tool_results": [
|
||||
{"name": "hello", "output": "world", "tool_call_id": None}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_tools_run_tools_in_parallel():
|
||||
start_timestamps = []
|
||||
|
||||
start_ns = time.monotonic_ns()
|
||||
|
||||
async def hello():
|
||||
start_timestamps.append(("hello", time.monotonic_ns() - start_ns))
|
||||
await asyncio.sleep(0.2)
|
||||
return "world"
|
||||
|
||||
async def hello2():
|
||||
start_timestamps.append(("hello2", time.monotonic_ns() - start_ns))
|
||||
await asyncio.sleep(0.2)
|
||||
return "world2"
|
||||
|
||||
model = llm.get_async_model("echo")
|
||||
chain_response = model.chain(
|
||||
json.dumps({"tool_calls": [{"name": "hello"}, {"name": "hello2"}]}),
|
||||
tools=[hello, hello2],
|
||||
)
|
||||
output = await chain_response.text()
|
||||
# That's two JSON objects separated by '\n}{\n'
|
||||
bits = output.split("\n}{\n")
|
||||
assert len(bits) == 2
|
||||
objects = [json.loads(bits[0] + "}"), json.loads("{" + bits[1])]
|
||||
assert objects == [
|
||||
{"prompt": "", "system": "", "attachments": [], "stream": True, "previous": []},
|
||||
{
|
||||
"prompt": "",
|
||||
"system": "",
|
||||
"attachments": [],
|
||||
"stream": True,
|
||||
"previous": [
|
||||
{"prompt": '{"tool_calls": [{"name": "hello"}, {"name": "hello2"}]}'}
|
||||
],
|
||||
"tool_results": [
|
||||
{"name": "hello", "output": "world", "tool_call_id": None},
|
||||
{"name": "hello2", "output": "world2", "tool_call_id": None},
|
||||
],
|
||||
},
|
||||
]
|
||||
delta_ns = start_timestamps[1][1] - start_timestamps[0][1]
|
||||
# They should have run in parallel so it should be less than 0.02s difference
|
||||
assert delta_ns < (100_000_000 * 0.2)
|
||||
|
|
|
|||
Loading…
Reference in a new issue