Async tool support (#1063)

* Sync models can now call async tools, refs #987
* Test for async tool functions in sync context, refs #987
* Test for asyncio tools, plus test that they run in parallel
* Docs for async tool usage
Simon Willison 2025-05-21 21:42:19 -07:00 committed by GitHub
parent 0ee1ba3a65
commit 3cb875fa3d
6 changed files with 561 additions and 86 deletions


@@ -89,6 +89,19 @@ for chunk in model.chain(
```
This will stream each of the responses in the chain in turn as they are generated.
You can access the individual responses that make up the chain using `chain.responses()`. This can be iterated over as the chain executes like this:
```python
chain = model.chain(
    "Convert panda to upper",
    tools=[upper],
)
for response in chain.responses():
    print(response.prompt)
    for chunk in response:
        print(chunk, end="", flush=True)
```
(python-api-system-prompts)=
### System prompts
@@ -351,10 +364,9 @@ model = llm.get_async_model("gpt-4o")
You can then run a prompt using `await model.prompt(...)`:
```python
print(await model.prompt(
    "Five surprising names for a pet pelican"
).text())
```
Or use `async for chunk in ...` to stream the response as it is generated:
@@ -365,6 +377,56 @@ async for chunk in model.prompt(
```python
async for chunk in model.prompt(
    "Five surprising names for a pet pelican"
):
    print(chunk, end="", flush=True)
```
This `await model.prompt()` method takes the same arguments as the synchronous `model.prompt()` method, for options and attachments and `key=` and suchlike.
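For example, a prompt combining an attachment with a model option might look like this (a sketch - the image URL is a placeholder, and `max_tokens` is an option supported by the OpenAI models):
```python
import llm

model = llm.get_async_model("gpt-4.1-mini")
response = await model.prompt(
    "Describe this image",
    # Placeholder URL, for illustration only:
    attachments=[llm.Attachment(url="https://example.com/pelican.jpg")],
    # Model options are passed as keyword arguments:
    max_tokens=200,
)
print(await response.text())
```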
(python-api-async-tools)=
### Tool functions can be sync or async
{ref}`Tool functions <python-api-tools>` can be either synchronous or asynchronous. The latter are defined using `async def tool_name(...)`. Either kind of function can be passed to the `tools=[...]` parameter.
If an `async def` function is used in a synchronous context LLM will automatically execute it using `asyncio.run()`. This means the following will work even in non-asynchronous Python scripts:
```python
async def hello(name: str) -> str:
    "Say hello to name"
    return "Hello there " + name

model = llm.get_model("gpt-4.1-mini")
chain_response = model.chain(
    "Say hello to Percival", tools=[hello]
)
print(chain_response.text())
```
### Tool use for async models
Tool use is also supported for async models, using either synchronous or asynchronous tool functions. Synchronous functions will block the event loop, so only use them in an asynchronous context if you are certain they are fast.
The `response.execute_tool_calls()`, `chain_response.text()` and `chain_response.responses()` methods must all be awaited when run against asynchronous models:
```python
import llm

model = llm.get_async_model("gpt-4.1")

def upper(string: str) -> str:
    "Converts string to uppercase"
    return string.upper()

chain = model.chain(
    "Convert panda to uppercase then pelican to uppercase",
    tools=[upper],
    after_call=print,
)
print(await chain.text())
```
To iterate over the chained response output as it arrives use `async for`:
```python
async for chunk in model.chain(
    "Convert panda to uppercase then pelican to uppercase",
    tools=[upper]
):
    print(chunk, end="", flush=True)
```
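Tool callbacks can themselves be asynchronous when used with async models - `before_call` and `after_call` functions defined with `async def` will be awaited. Here is a sketch combining an async `after_call` callback (the `log_call` name is illustrative) with `async for` iteration over the individual responses in the chain:
```python
async def log_call(tool, tool_call, tool_result):
    # An async after_call callback - LLM awaits it automatically
    print(tool.name, tool_call.arguments, tool_result.output)

chain = model.chain(
    "Convert panda to uppercase then pelican to uppercase",
    tools=[upper],
    after_call=log_call,
)
async for response in chain.responses():
    async for chunk in response:
        print(chunk, end="", flush=True)
```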
(python-api-conversations)=
## Conversations


@@ -763,16 +763,38 @@ class AsyncChat(_Shared, AsyncKeyModel):
        **kwargs,
    )
    chunks = []
    tool_calls = {}
    async for chunk in completion:
        chunks.append(chunk)
        if chunk.usage:
            usage = chunk.usage.model_dump()
        if chunk.choices and chunk.choices[0].delta:
            for tool_call in chunk.choices[0].delta.tool_calls or []:
                index = tool_call.index
                if index not in tool_calls:
                    tool_calls[index] = tool_call
                else:
                    tool_calls[
                        index
                    ].function.arguments += tool_call.function.arguments
        try:
            content = chunk.choices[0].delta.content
        except IndexError:
            content = None
        if content is not None:
            yield content
    if tool_calls:
        for value in tool_calls.values():
            # value.function looks like this:
            # ChoiceDeltaToolCallFunction(arguments='{"city":"San Francisco"}', name='get_weather')
            response.add_tool_call(
                llm.ToolCall(
                    tool_call_id=value.id,
                    name=value.function.name,
                    arguments=json.loads(value.function.arguments),
                )
            )
    response.response_json = remove_dict_none_values(combine_chunks(chunks))
else:
    completion = await client.chat.completions.create(
@@ -783,6 +805,14 @@ class AsyncChat(_Shared, AsyncKeyModel):
    )
    response.response_json = remove_dict_none_values(completion.model_dump())
    usage = completion.usage.model_dump()
    for tool_call in completion.choices[0].message.tool_calls or []:
        response.add_tool_call(
            llm.ToolCall(
                tool_call_id=tool_call.id,
                name=tool_call.function.name,
                arguments=json.loads(tool_call.function.arguments),
            )
        )
    if completion.choices[0].message.content is not None:
        yield completion.choices[0].message.content
self.set_usage(response, usage)


@@ -12,6 +12,8 @@ import time
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Iterable,
@@ -375,6 +377,53 @@ class Conversation(_BaseConversation):

@dataclass
class AsyncConversation(_BaseConversation):
    def chain(
        self,
        prompt: Optional[str] = None,
        *,
        fragments: Optional[List[str]] = None,
        attachments: Optional[List[Attachment]] = None,
        system: Optional[str] = None,
        system_fragments: Optional[List[str]] = None,
        stream: bool = True,
        schema: Optional[Union[dict, type[BaseModel]]] = None,
        tools: Optional[List[Tool]] = None,
        tool_results: Optional[List[ToolResult]] = None,
        chain_limit: Optional[int] = None,
        before_call: Optional[
            Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
        ] = None,
        after_call: Optional[
            Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
        ] = None,
        details: bool = False,
        key: Optional[str] = None,
        options: Optional[dict] = None,
    ) -> "AsyncChainResponse":
        self.model._validate_attachments(attachments)
        return AsyncChainResponse(
            Prompt(
                prompt,
                fragments=fragments,
                attachments=attachments,
                system=system,
                schema=schema,
                tools=tools,
                tool_results=tool_results,
                system_fragments=system_fragments,
                model=self.model,
                options=self.model.Options(**(options or {})),
            ),
            model=self.model,
            stream=stream,
            conversation=self,
            key=key,
            details=details,
            before_call=before_call,
            after_call=after_call,
            chain_limit=chain_limit,
        )

    def prompt(
        self,
        prompt: Optional[str] = None,
@@ -784,37 +833,59 @@ class Response(_BaseResponse):
    def execute_tool_calls(
        self,
        *,
        before_call: Optional[
            Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
        ] = None,
        after_call: Optional[
            Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
        ] = None,
    ) -> List[ToolResult]:
        tool_results = []
        tools_by_name = {tool.name: tool for tool in self.prompt.tools}
        for tool_call in self.tool_calls():
            tool = tools_by_name.get(tool_call.name)
            if tool is None:
                raise CancelToolCall("Unknown tool: {}".format(tool_call.name))
            if before_call:
                # This may raise CancelToolCall:
                cb_result = before_call(tool, tool_call)
                if inspect.isawaitable(cb_result):
                    raise TypeError(
                        "Asynchronous 'before_call' callback provided to a synchronous tool execution context. "
                        "Please use an async chain/response or a synchronous callback."
                    )
            if not tool.implementation:
                raise CancelToolCall(
                    "No implementation available for tool: {}".format(tool_call.name)
                )
            try:
                if asyncio.iscoroutinefunction(tool.implementation):
                    result = asyncio.run(tool.implementation(**tool_call.arguments))
                else:
                    result = tool.implementation(**tool_call.arguments)
                if not isinstance(result, str):
                    result = json.dumps(result, default=repr)
            except Exception as ex:
                result = f"Error: {ex}"
            tool_result_obj = ToolResult(
                name=tool_call.name,
                output=result,
                tool_call_id=tool_call.tool_call_id,
            )
            if after_call:
                cb_result = after_call(tool, tool_call, tool_result_obj)
                if inspect.isawaitable(cb_result):
                    raise TypeError(
                        "Asynchronous 'after_call' callback provided to a synchronous tool execution context. "
                        "Please use an async chain/response or a synchronous callback."
                    )
            tool_results.append(tool_result_obj)
        return tool_results

    def tool_calls(self) -> List[ToolCall]:
@@ -901,30 +972,131 @@ class AsyncResponse(_BaseResponse):
            self.done_callbacks.append(callback)
        else:
            if callable(callback):
                # Ensure we handle both sync and async callbacks correctly
                processed_callback = callback(self)
                if inspect.isawaitable(processed_callback):
                    await processed_callback
            elif inspect.isawaitable(callback):
                await callback

    async def _on_done(self):
        for callback_func in self.done_callbacks:
            if callable(callback_func):
                processed_callback = callback_func(self)
                if inspect.isawaitable(processed_callback):
                    await processed_callback
            elif inspect.isawaitable(callback_func):
                await callback_func
    async def execute_tool_calls(
        self,
        *,
        before_call: Optional[
            Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
        ] = None,
        after_call: Optional[
            Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
        ] = None,
    ) -> List[ToolResult]:
        tool_calls_list = await self.tool_calls()
        tools_by_name = {tool.name: tool for tool in self.prompt.tools}
        indexed_results: List[tuple[int, ToolResult]] = []
        async_tasks: List[asyncio.Task] = []
        for idx, tc in enumerate(tool_calls_list):
            tool = tools_by_name.get(tc.name)
            if tool is None:
                raise CancelToolCall(f"Unknown tool: {tc.name}")
            if not tool.implementation:
                raise CancelToolCall(f"No implementation for tool: {tc.name}")
            # If it's an async implementation, wrap it
            if inspect.iscoroutinefunction(tool.implementation):

                async def run_async(tc=tc, tool=tool, idx=idx):
                    # before_call inside the task
                    if before_call:
                        cb = before_call(tool, tc)
                        if inspect.isawaitable(cb):
                            await cb
                    try:
                        result = await tool.implementation(**tc.arguments)
                        output = (
                            result
                            if isinstance(result, str)
                            else json.dumps(result, default=repr)
                        )
                    except Exception as ex:
                        output = f"Error: {ex}"
                    tr = ToolResult(
                        name=tc.name,
                        output=output,
                        tool_call_id=tc.tool_call_id,
                    )
                    # after_call inside the task
                    if after_call:
                        cb2 = after_call(tool, tc, tr)
                        if inspect.isawaitable(cb2):
                            await cb2
                    return idx, tr

                async_tasks.append(asyncio.create_task(run_async()))
            else:
                # Sync implementation: do hooks and call inline
                if before_call:
                    cb = before_call(tool, tc)
                    if inspect.isawaitable(cb):
                        await cb
                try:
                    res = tool.implementation(**tc.arguments)
                    if inspect.isawaitable(res):
                        res = await res
                    output = (
                        res if isinstance(res, str) else json.dumps(res, default=repr)
                    )
                except Exception as ex:
                    output = f"Error: {ex}"
                tr = ToolResult(
                    name=tc.name,
                    output=output,
                    tool_call_id=tc.tool_call_id,
                )
                if after_call:
                    cb2 = after_call(tool, tc, tr)
                    if inspect.isawaitable(cb2):
                        await cb2
                indexed_results.append((idx, tr))
        # Await all async tasks in parallel
        if async_tasks:
            indexed_results.extend(await asyncio.gather(*async_tasks))
        # Reorder by original index
        indexed_results.sort(key=lambda x: x[0])
        return [tr for _, tr in indexed_results]
    def __aiter__(self):
        self._start = time.monotonic()
        self._start_utcnow = datetime.datetime.now(datetime.timezone.utc)
        if self._done:
            self._iter_chunks = list(self._chunks)  # Make a copy for iteration
        return self

    async def __anext__(self) -> str:
        if self._done:
            if hasattr(self, "_iter_chunks") and self._iter_chunks:
                return self._iter_chunks.pop(0)
            raise StopAsyncIteration
        if not hasattr(self, "_generator"):
            if isinstance(self.model, AsyncModel):
@@ -955,13 +1127,17 @@ class AsyncResponse(_BaseResponse):
                self.conversation.responses.append(self)
            self._end = time.monotonic()
            self._done = True
            if hasattr(self, "_generator"):
                del self._generator
            await self._on_done()
            raise

    async def _force(self):
        if not self._done:
            temp_chunks = []
            async for chunk in self:
                temp_chunks.append(chunk)
            # This should populate self._chunks
        return self

    def text_or_raise(self) -> str:
@@ -1007,20 +1183,44 @@
    async def to_sync_response(self) -> Response:
        await self._force()
        # This conversion is tricky when the model is an AsyncModel, since
        # Response expects a sync Model. The primary use case is data
        # transfer after completion, so we pass the async model through,
        # which may limit what the resulting sync response can do.
        sync_model = self.model
        response = Response(
            self.prompt,
            sync_model,
            self.stream,
            conversation=(
                self.conversation.to_sync_conversation()
                if self.conversation
                and hasattr(self.conversation, "to_sync_conversation")
                else None
            ),
        )
        response.id = self.id
        response._chunks = list(self._chunks)  # Copy chunks
        response._done = self._done
        response._end = self._end
        response._start = self._start
        response._start_utcnow = self._start_utcnow
        response.input_tokens = self.input_tokens
        response.output_tokens = self.output_tokens
        response.token_details = self.token_details
        response._prompt_json = self._prompt_json
        response.response_json = self.response_json
        response._tool_calls = list(self._tool_calls)
        response.attachments = list(self.attachments)
        return response

    @classmethod
@@ -1059,7 +1259,6 @@ class _BaseChainResponse:
    stream: bool
    conversation: Optional["_BaseConversation"] = None
    _key: Optional[str] = None

    def __init__(
        self,
@@ -1070,51 +1269,65 @@
        key: Optional[str] = None,
        details: bool = False,
        chain_limit: Optional[int] = 10,
        before_call: Optional[
            Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
        ] = None,
        after_call: Optional[
            Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
        ] = None,
    ):
        self.prompt = prompt
        self.model = model
        self.stream = stream
        self._key = key
        self._details = details
        self._responses: List[Any] = []
        self.conversation = conversation
        self.chain_limit = chain_limit
        self.before_call = before_call
        self.after_call = after_call

    def log_to_db(self, db):
        for response in self._responses:
            if isinstance(response, AsyncResponse):
                sync_response = asyncio.run(response.to_sync_response())
            elif isinstance(response, Response):
                sync_response = response
            else:
                assert False, "Should have been a Response or AsyncResponse"
            sync_response.log_to_db(db)


class ChainResponse(_BaseChainResponse):
    "Know how to chain multiple responses e.g. for tool calls"

    _responses: List["Response"]

    def responses(self) -> Iterator[Response]:
        prompt = self.prompt
        count = 0
        current_response: Optional[Response] = Response(
            prompt,
            self.model,
            self.stream,
            key=self._key,
            conversation=self.conversation,
        )
        while current_response:
            count += 1
            yield current_response
            self._responses.append(current_response)
            if self.chain_limit and count >= self.chain_limit:
                raise ValueError(f"Chain limit of {self.chain_limit} exceeded.")
            # This could raise llm.CancelToolCall:
            tool_results = current_response.execute_tool_calls(
                before_call=self.before_call, after_call=self.after_call
            )
            if tool_results:
                current_response = Response(
                    Prompt(
                        "",  # Next prompt is empty, tools drive it
                        self.model,
                        tools=current_response.prompt.tools,
                        tool_results=tool_results,
                        options=self.prompt.options,
                    ),
@@ -1124,27 +1337,71 @@
                    conversation=self.conversation,
                )
            else:
                current_response = None
                break

    def __iter__(self) -> Iterator[str]:
        for response_item in self.responses():
            yield from response_item

    def text(self) -> str:
        return "".join(self)
class AsyncChainResponse(_BaseChainResponse):
    _responses: List["AsyncResponse"]

    async def responses(self) -> AsyncIterator[AsyncResponse]:
        prompt = self.prompt
        count = 0
        current_response: Optional[AsyncResponse] = AsyncResponse(
            prompt,
            self.model,
            self.stream,
            key=self._key,
            conversation=self.conversation,
        )
        while current_response:
            count += 1
            yield current_response
            self._responses.append(current_response)
            if self.chain_limit and count >= self.chain_limit:
                raise ValueError(f"Chain limit of {self.chain_limit} exceeded.")
            # This could raise llm.CancelToolCall:
            tool_results = await current_response.execute_tool_calls(
                before_call=self.before_call, after_call=self.after_call
            )
            if tool_results:
                prompt = Prompt(
                    "",
                    self.model,
                    tools=current_response.prompt.tools,
                    tool_results=tool_results,
                    options=self.prompt.options,
                )
                current_response = AsyncResponse(
                    prompt,
                    self.model,
                    stream=self.stream,
                    key=self._key,
                    conversation=self.conversation,
                )
            else:
                current_response = None
                break

    async def __aiter__(self) -> AsyncIterator[str]:
        async for response_item in self.responses():
            async for chunk in response_item:
                yield chunk

    async def text(self) -> str:
        all_chunks = []
        async for chunk in self:
            all_chunks.append(chunk)
        return "".join(all_chunks)

class Options(BaseModel):
@@ -1171,13 +1428,13 @@ class _get_key_mixin:
            return self.key
        # Attempt to load a key using llm.get_key()
        key_value = get_key(
            explicit_key=explicit_key,
            key_alias=self.needs_key,
            env_var=self.key_env_var,
        )
        if key_value:
            return key_value
        # Show a useful error message
        message = "No key found - add one using 'llm keys set {}'".format(
@@ -1241,7 +1498,7 @@ class _Model(_BaseModel):
        tool_results: Optional[List[ToolResult]] = None,
        **options,
    ) -> Response:
        key_value = options.pop("key", None)
        self._validate_attachments(attachments)
        return Response(
            Prompt(
@@ -1258,7 +1515,7 @@ class _Model(_BaseModel):
            ),
            self,
            stream,
            key=key_value,
        )
    def chain(
@@ -1278,7 +1535,7 @@ class _Model(_BaseModel):
        details: bool = False,
        key: Optional[str] = None,
        options: Optional[dict] = None,
    ) -> ChainResponse:
        return self.conversation().chain(
            prompt=prompt,
            fragments=fragments,
@@ -1340,7 +1597,7 @@ class _AsyncModel(_BaseModel):
        stream: bool = True,
        **options,
    ) -> AsyncResponse:
        key_value = options.pop("key", None)
        self._validate_attachments(attachments)
        return AsyncResponse(
            Prompt(
@@ -1357,11 +1614,47 @@ class _AsyncModel(_BaseModel):
            ),
            self,
            stream,
            key=key_value,
        )

    def chain(
        self,
        prompt: Optional[str] = None,
        *,
        fragments: Optional[List[str]] = None,
        attachments: Optional[List[Attachment]] = None,
        system: Optional[str] = None,
        system_fragments: Optional[List[str]] = None,
        stream: bool = True,
        schema: Optional[Union[dict, type[BaseModel]]] = None,
        tools: Optional[List[Tool]] = None,
        tool_results: Optional[List[ToolResult]] = None,
        before_call: Optional[
            Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
        ] = None,
        after_call: Optional[
            Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
        ] = None,
        details: bool = False,
        key: Optional[str] = None,
        options: Optional[dict] = None,
    ) -> AsyncChainResponse:
        return self.conversation().chain(
            prompt=prompt,
            fragments=fragments,
            attachments=attachments,
            system=system,
            system_fragments=system_fragments,
            stream=stream,
            schema=schema,
            tools=tools,
            tool_results=tool_results,
            before_call=before_call,
            after_call=after_call,
            details=details,
            key=key,
            options=options,
        )

class AsyncModel(_AsyncModel):
@@ -1373,7 +1666,9 @@ class AsyncModel(_AsyncModel):
        response: AsyncResponse,
        conversation: Optional[AsyncConversation],
    ) -> AsyncGenerator[str, None]:
        if False:  # Ensure it's a generator type, but don't actually yield
            yield ""
        pass

class AsyncKeyModel(_AsyncModel):
@@ -1386,7 +1681,9 @@ class AsyncKeyModel(_AsyncModel):
        conversation: Optional[AsyncConversation],
        key: Optional[str],
    ) -> AsyncGenerator[str, None]:
        if False:  # Ensure it's a generator type
            yield ""
        pass

class EmbeddingModel(ABC, _get_key_mixin):
@@ -1418,20 +1715,20 @@ class EmbeddingModel(ABC, _get_key_mixin):
    ) -> Iterator[List[float]]:
        "Embed multiple items in batches according to the model batch_size"
        iter_items = iter(items)
        effective_batch_size = self.batch_size if batch_size is None else batch_size
        if (not self.supports_binary) or (not self.supports_text):

            def checking_iter(inner_items):
                for item_to_check in inner_items:
                    self._check(item_to_check)
                    yield item_to_check

            iter_items = checking_iter(items)
        if effective_batch_size is None:
            yield from self.embed_batch(iter_items)
            return
        while True:
            batch_items = list(islice(iter_items, effective_batch_size))
            if not batch_items:
                break
            yield from self.embed_batch(batch_items)
@@ -1457,14 +1754,14 @@ class ModelWithAliases:
    aliases: Set[str]

    def matches(self, query: str) -> bool:
        query_lower = query.lower()
        all_strings: List[str] = []
        all_strings.extend(self.aliases)
        if self.model:
            all_strings.append(str(self.model))
        if self.async_model:
            all_strings.append(str(self.async_model.model_id))
        return any(query_lower in alias.lower() for alias in all_strings)

@dataclass
@@ -1473,11 +1770,11 @@ class EmbeddingModelWithAliases:
    aliases: Set[str]

    def matches(self, query: str) -> bool:
        query_lower = query.lower()
        all_strings: List[str] = []
        all_strings.extend(self.aliases)
        all_strings.append(str(self.model))
        return any(query_lower in alias.lower() for alias in all_strings)

def _conversation_name(text):


@@ -67,7 +67,7 @@ test = [
    "types-click",
    "types-PyYAML",
    "types-setuptools",
    "llm-echo==0.3a3",
]

[build-system]


@@ -470,3 +470,11 @@ def collection():
@pytest.fixture(scope="module")
def vcr_config():
    return {"filter_headers": ["Authorization"]}


def extract_braces(s):
    # Return the substring from the first "{" through the last "}", or None
    first = s.find("{")
    last = s.rfind("}")
    if first != -1 and last != -1 and first < last:
        return s[first : last + 1]
    return None


@@ -1,8 +1,11 @@
import asyncio
import json
import llm
from llm.migrations import migrate
import os
import pytest
import sqlite_utils
import time

API_KEY = os.environ.get("PYTEST_OPENAI_API_KEY", None) or "badkey"
@@ -92,3 +95,78 @@ def test_tool_use_chain_of_two_calls(vcr):
    assert second.tool_calls()[0].arguments == {"population": 123124}
    assert third.prompt.tool_results[0].output == "true"
    assert third.tool_calls() == []


def test_tool_use_async_tool_function():
    async def hello():
        return "world"

    model = llm.get_model("echo")
    chain_response = model.chain(
        json.dumps({"tool_calls": [{"name": "hello"}]}), tools=[hello]
    )
    output = chain_response.text()
    # That's two JSON objects separated by '\n}{\n'
    bits = output.split("\n}{\n")
    assert len(bits) == 2
    objects = [json.loads(bits[0] + "}"), json.loads("{" + bits[1])]
    assert objects == [
        {"prompt": "", "system": "", "attachments": [], "stream": True, "previous": []},
        {
            "prompt": "",
            "system": "",
            "attachments": [],
            "stream": True,
            "previous": [{"prompt": '{"tool_calls": [{"name": "hello"}]}'}],
            "tool_results": [
                {"name": "hello", "output": "world", "tool_call_id": None}
            ],
        },
    ]
@pytest.mark.asyncio
async def test_async_tools_run_tools_in_parallel():
    start_timestamps = []
    start_ns = time.monotonic_ns()

    async def hello():
        start_timestamps.append(("hello", time.monotonic_ns() - start_ns))
        await asyncio.sleep(0.2)
        return "world"

    async def hello2():
        start_timestamps.append(("hello2", time.monotonic_ns() - start_ns))
        await asyncio.sleep(0.2)
        return "world2"

    model = llm.get_async_model("echo")
    chain_response = model.chain(
        json.dumps({"tool_calls": [{"name": "hello"}, {"name": "hello2"}]}),
        tools=[hello, hello2],
    )
    output = await chain_response.text()
    # That's two JSON objects separated by '\n}{\n'
    bits = output.split("\n}{\n")
    assert len(bits) == 2
    objects = [json.loads(bits[0] + "}"), json.loads("{" + bits[1])]
    assert objects == [
        {"prompt": "", "system": "", "attachments": [], "stream": True, "previous": []},
        {
            "prompt": "",
            "system": "",
            "attachments": [],
            "stream": True,
            "previous": [
                {"prompt": '{"tool_calls": [{"name": "hello"}, {"name": "hello2"}]}'}
            ],
            "tool_results": [
                {"name": "hello", "output": "world", "tool_call_id": None},
                {"name": "hello2", "output": "world2", "tool_call_id": None},
            ],
        },
    ]
    delta_ns = start_timestamps[1][1] - start_timestamps[0][1]
    # They should have run in parallel, so the difference should be under 0.02s
    assert delta_ns < (100_000_000 * 0.2)