diff --git a/docs/python-api.md b/docs/python-api.md index 44c62e6..9665281 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -89,6 +89,19 @@ for chunk in model.chain( ``` This will stream each of the chain of responses in turn as they are generated. +You can access the individual responses that make up the chain using `chain.responses()`. This can be iterated over as the chain executes like this: + +```python +chain = model.chain( + "Convert panda to upper", + tools=[upper], +) +for response in chain.responses(): + print(response.prompt) + for chunk in response: + print(chunk, end="", flush=True) +``` + (python-api-system-prompts)= ### System prompts @@ -351,10 +364,9 @@ model = llm.get_async_model("gpt-4o") You can then run a prompt using `await model.prompt(...)`: ```python -response = await model.prompt( +print(await model.prompt( "Five surprising names for a pet pelican" -) -print(await response.text()) +).text()) ``` Or use `async for chunk in ...` to stream the response as it is generated: ```python @@ -365,6 +377,56 @@ async for chunk in model.prompt( ``` This `await model.prompt()` method takes the same arguments as the synchronous `model.prompt()` method, for options and attachments and `key=` and suchlike. +(python-api-async-tools)= + +### Tool functions can be sync or async + +{ref}`Tool functions ` can be both synchronous or asynchronous. The latter are defined using `async def tool_name(...)`. Either kind of function can be passed to the `tools=[...]` parameter. + +If an `async def` function is used in a synchronous context LLM will automatically execute it in a thread pool using `asyncio.run()`. This means the following will work even in non-asynchronous Python scripts: + +```python +async def hello(name: str) -> str: + "Say hello to name" + return "Hello there " + name + +model = llm.get_model("gpt-4.1-mini") +chain_response = model.chain( + "Say hello to Percival", tools=[hello] +) +print(chain_response.text()) +``` + +### Tool use for async models + +Tool use is also supported for async models, using either synchronous or asynchronous tool functions. Synchronous functions will block the event loop so only use those in asynchronous context if you are certain they are extremely fast. + +The `response.execute_tool_calls()` and `chain_response.text()` and `chain_response.responses()` methods must all be awaited when run against asynchronous models: + +```python +import llm +model = llm.get_async_model("gpt-4.1") + +def upper(string): + "Converts string to uppercase" + return string.upper() + +chain = model.chain( + "Convert panda to uppercase then pelican to uppercase", + tools=[upper], + after_call=print +) +print(await chain.text()) +``` + +To iterate over the chained response output as it arrives use `async for`: +```python +async for chunk in model.chain( + "Convert panda to uppercase then pelican to uppercase", + tools=[upper] +): + print(chunk, end="", flush=True) +``` (python-api-conversations)= ## Conversations diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 178740c..fc02011 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -763,16 +763,38 @@ class AsyncChat(_Shared, AsyncKeyModel): **kwargs, ) chunks = [] + tool_calls = {} async for chunk in completion: if chunk.usage: usage = chunk.usage.model_dump() chunks.append(chunk) + if chunk.usage: + usage = chunk.usage.model_dump() + if chunk.choices and chunk.choices[0].delta: + for tool_call in chunk.choices[0].delta.tool_calls or []: + index = tool_call.index + if index not in tool_calls: + tool_calls[index] = tool_call + tool_calls[ + index + ].function.arguments += tool_call.function.arguments try: content = chunk.choices[0].delta.content except IndexError: content = None if content is not None: yield content + if tool_calls: + for value in tool_calls.values(): + # value.function looks like this: + # ChoiceDeltaToolCallFunction(arguments='{"city":"San Francisco"}', name='get_weather') + response.add_tool_call( + llm.ToolCall( + tool_call_id=value.id, + name=value.function.name, + arguments=json.loads(value.function.arguments), + ) + ) response.response_json = remove_dict_none_values(combine_chunks(chunks)) else: completion = await client.chat.completions.create( @@ -783,6 +805,14 @@ class AsyncChat(_Shared, AsyncKeyModel): ) response.response_json = remove_dict_none_values(completion.model_dump()) usage = completion.usage.model_dump() + for tool_call in completion.choices[0].message.tool_calls or []: + response.add_tool_call( + llm.ToolCall( + tool_call_id=tool_call.id, + name=tool_call.function.name, + arguments=json.loads(tool_call.function.arguments), + ) + ) if completion.choices[0].message.content is not None: yield completion.choices[0].message.content self.set_usage(response, usage) diff --git a/llm/models.py b/llm/models.py index cd7c1ec..06393fb 100644 --- a/llm/models.py +++ b/llm/models.py @@ -12,6 +12,8 @@ import time from typing import ( Any, AsyncGenerator, + AsyncIterator, + Awaitable, Callable, Dict, Iterable, @@ -375,6 +377,53 @@ class Conversation(_BaseConversation): @dataclass class AsyncConversation(_BaseConversation): + def chain( + self, + prompt: Optional[str] = None, + *, + fragments: Optional[List[str]] = None, + attachments: Optional[List[Attachment]] = None, + system: Optional[str] = None, + system_fragments: Optional[List[str]] = None, + stream: bool = True, + schema: Optional[Union[dict, type[BaseModel]]] = None, + tools: Optional[List[Tool]] = None, + tool_results: Optional[List[ToolResult]] = None, + chain_limit: Optional[int] = None, + before_call: Optional[ + Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] + ] = None, + after_call: Optional[ + Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] + ] = None, + details: bool = False, + key: Optional[str] = None, + options: Optional[dict] = None, + ) -> "AsyncChainResponse": + self.model._validate_attachments(attachments) + return AsyncChainResponse( + Prompt( + prompt, + fragments=fragments, + attachments=attachments, + system=system, + schema=schema, + tools=tools, + tool_results=tool_results, + system_fragments=system_fragments, + model=self.model, + options=self.model.Options(**(options or {})), + ), + model=self.model, + stream=stream, + conversation=self, + key=key, + details=details, + before_call=before_call, + after_call=after_call, + chain_limit=chain_limit, + ) + def prompt( self, prompt: Optional[str] = None, @@ -784,37 +833,59 @@ class Response(_BaseResponse): def execute_tool_calls( self, *, - before_call: Optional[Callable[[Tool, ToolCall], None]] = None, - after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None, + before_call: Optional[ + Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] + ] = None, + after_call: Optional[ + Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] + ] = None, ) -> List[ToolResult]: tool_results = [] tools_by_name = {tool.name: tool for tool in self.prompt.tools} - # TODO: make this work async for tool_call in self.tool_calls(): tool = tools_by_name.get(tool_call.name) if tool is None: raise CancelToolCall("Unknown tool: {}".format(tool_call.name)) + if before_call: # This may raise CancelToolCall: - before_call(tool, tool_call) + cb_result = before_call(tool, tool_call) + if inspect.isawaitable(cb_result): + raise TypeError( + "Asynchronous 'before_call' callback provided to a synchronous tool execution context. " + "Please use an async chain/response or a synchronous callback." + ) + if not tool.implementation: raise CancelToolCall( "No implementation available for tool: {}".format(tool_call.name) ) + try: - result = tool.implementation(**tool_call.arguments) + if asyncio.iscoroutinefunction(tool.implementation): + result = asyncio.run(tool.implementation(**tool_call.arguments)) + else: + result = tool.implementation(**tool_call.arguments) + if not isinstance(result, str): result = json.dumps(result, default=repr) except Exception as ex: result = f"Error: {ex}" - tool_result = ToolResult( + + tool_result_obj = ToolResult( name=tool_call.name, output=result, tool_call_id=tool_call.tool_call_id, ) + if after_call: - after_call(tool, tool_call, tool_result) - tool_results.append(tool_result) + cb_result = after_call(tool, tool_call, tool_result_obj) + if inspect.isawaitable(cb_result): + raise TypeError( + "Asynchronous 'after_call' callback provided to a synchronous tool execution context. " + "Please use an async chain/response or a synchronous callback." + ) + tool_results.append(tool_result_obj) return tool_results def tool_calls(self) -> List[ToolCall]: @@ -901,30 +972,131 @@ class AsyncResponse(_BaseResponse): self.done_callbacks.append(callback) else: if callable(callback): - callback = callback(self) - if asyncio.iscoroutine(callback): + # Ensure we handle both sync and async callbacks correctly + processed_callback = callback(self) + if inspect.isawaitable(processed_callback): + await processed_callback + elif inspect.isawaitable(callback): await callback async def _on_done(self): - for callback in self.done_callbacks: - if callable(callback): - callback = callback(self) - if asyncio.iscoroutine(callback): - await callback + for callback_func in self.done_callbacks: + if callable(callback_func): + processed_callback = callback_func(self) + if inspect.isawaitable(processed_callback): + await processed_callback + elif inspect.isawaitable(callback_func): + await callback_func + + async def execute_tool_calls( + self, + *, + before_call: Optional[ + Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] + ] = None, + after_call: Optional[ + Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] + ] = None, + ) -> List[ToolResult]: + tool_calls_list = await self.tool_calls() + tools_by_name = {tool.name: tool for tool in self.prompt.tools} + + indexed_results: List[tuple[int, ToolResult]] = [] + async_tasks: List[asyncio.Task] = [] + + for idx, tc in enumerate(tool_calls_list): + tool = tools_by_name.get(tc.name) + if tool is None: + raise CancelToolCall(f"Unknown tool: {tc.name}") + if not tool.implementation: + raise CancelToolCall(f"No implementation for tool: {tc.name}") + + # If it's an async implementation, wrap it + if inspect.iscoroutinefunction(tool.implementation): + + async def run_async(tc=tc, tool=tool, idx=idx): + # before_call inside the task + if before_call: + cb = before_call(tool, tc) + if inspect.isawaitable(cb): + await cb + + try: + result = await tool.implementation(**tc.arguments) + output = ( + result + if isinstance(result, str) + else json.dumps(result, default=repr) + ) + except Exception as ex: + output = f"Error: {ex}" + + tr = ToolResult( + name=tc.name, + output=output, + tool_call_id=tc.tool_call_id, + ) + + # after_call inside the task + if after_call: + cb2 = after_call(tool, tc, tr) + if inspect.isawaitable(cb2): + await cb2 + + return idx, tr + + async_tasks.append(asyncio.create_task(run_async())) + + else: + # Sync implementation: do hooks and call inline + if before_call: + cb = before_call(tool, tc) + if inspect.isawaitable(cb): + await cb + + try: + res = tool.implementation(**tc.arguments) + if inspect.isawaitable(res): + res = await res + output = ( + res if isinstance(res, str) else json.dumps(res, default=repr) + ) + except Exception as ex: + output = f"Error: {ex}" + + tr = ToolResult( + name=tc.name, + output=output, + tool_call_id=tc.tool_call_id, + ) + + if after_call: + cb2 = after_call(tool, tc, tr) + if inspect.isawaitable(cb2): + await cb2 + + indexed_results.append((idx, tr)) + + # Await all async tasks in parallel + if async_tasks: + indexed_results.extend(await asyncio.gather(*async_tasks)) + + # Reorder by original index + indexed_results.sort(key=lambda x: x[0]) + return [tr for _, tr in indexed_results] def __aiter__(self): self._start = time.monotonic() self._start_utcnow = datetime.datetime.now(datetime.timezone.utc) + if self._done: + self._iter_chunks = list(self._chunks) # Make a copy for iteration return self async def __anext__(self) -> str: if self._done: - if not self._chunks: - raise StopAsyncIteration - chunk = self._chunks.pop(0) - if not self._chunks: - raise StopAsyncIteration - return chunk + if hasattr(self, "_iter_chunks") and self._iter_chunks: + return self._iter_chunks.pop(0) + raise StopAsyncIteration if not hasattr(self, "_generator"): if isinstance(self.model, AsyncModel): @@ -955,13 +1127,17 @@ class AsyncResponse(_BaseResponse): self.conversation.responses.append(self) self._end = time.monotonic() self._done = True + if hasattr(self, "_generator"): + del self._generator await self._on_done() raise async def _force(self): if not self._done: - async for _ in self: - pass + temp_chunks = [] + async for chunk in self: + temp_chunks.append(chunk) + # This should populate self._chunks return self def text_or_raise(self) -> str: @@ -1007,20 +1183,44 @@ class AsyncResponse(_BaseResponse): async def to_sync_response(self) -> Response: await self._force() + # This conversion might be tricky if the model is AsyncModel, + # as Response expects a sync Model. For simplicity, we'll assume + # the primary use case is data transfer after completion. + # The model type on the new Response might need careful handling + # if it's intended for further execution. + # For now, let's assume self.model can be cast or is compatible. + sync_model = self.model + if not isinstance(self.model, (Model, KeyModel)): + # This is a placeholder. A proper conversion or shared base might be needed + # if the sync_response needs to be fully functional with its model. + # For now, we pass the async model, which might limit what sync_response can do. + pass + response = Response( self.prompt, - self.model, + sync_model, # This might need adjustment based on how Model/AsyncModel relate self.stream, - conversation=self.conversation, + # conversation type needs to be compatible too. + conversation=( + self.conversation.to_sync_conversation() + if self.conversation + and hasattr(self.conversation, "to_sync_conversation") + else None + ), ) - response._chunks = self._chunks - response._done = True + response.id = self.id + response._chunks = list(self._chunks) # Copy chunks + response._done = self._done response._end = self._end response._start = self._start response._start_utcnow = self._start_utcnow response.input_tokens = self.input_tokens response.output_tokens = self.output_tokens response.token_details = self.token_details + response._prompt_json = self._prompt_json + response.response_json = self.response_json + response._tool_calls = list(self._tool_calls) + response.attachments = list(self.attachments) return response @classmethod @@ -1059,7 +1259,6 @@ class _BaseChainResponse: stream: bool conversation: Optional["_BaseConversation"] = None _key: Optional[str] = None - _responses: List[Union["Response", "AsyncResponse"]] def __init__( self, @@ -1070,51 +1269,65 @@ class _BaseChainResponse: key: Optional[str] = None, details: bool = False, chain_limit: Optional[int] = 10, - before_call: Optional[Callable[[Tool, ToolCall], None]] = None, - after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None, + before_call: Optional[ + Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] + ] = None, + after_call: Optional[ + Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] + ] = None, ): self.prompt = prompt self.model = model self.stream = stream self._key = key self._details = details - # self._responses: List[Union[Response, AsyncResponse]] = [] - self._responses: List[Response] = [] + self._responses: List[Any] = [] self.conversation = conversation self.chain_limit = chain_limit self.before_call = before_call self.after_call = after_call - def responses( - self, details=False - ) -> Iterator[ - Union[Response, str] - ]: # Iterator[Union[Response, AsyncResponse, str]]: + def log_to_db(self, db): + for response in self._responses: + if isinstance(response, AsyncResponse): + sync_response = asyncio.run(response.to_sync_response()) + elif isinstance(response, Response): + sync_response = response + else: + assert False, "Should have been a Response or AsyncResponse" + sync_response.log_to_db(db) + + +class ChainResponse(_BaseChainResponse): + _responses: List["Response"] + + def responses(self) -> Iterator[Response]: prompt = self.prompt count = 0 - response = Response( + current_response: Optional[Response] = Response( prompt, self.model, self.stream, key=self._key, conversation=self.conversation, ) - while response: + while current_response: count += 1 - yield response - self._responses.append(response) - if self.chain_limit and count > self.chain_limit: - raise ValueError(f"Chain limit of {self.chain_limit} exceeded. ") + yield current_response + self._responses.append(current_response) + if self.chain_limit and count >= self.chain_limit: + raise ValueError(f"Chain limit of {self.chain_limit} exceeded.") + # This could raise llm.CancelToolCall: - tool_results = response.execute_tool_calls( + tool_results = current_response.execute_tool_calls( before_call=self.before_call, after_call=self.after_call ) if tool_results: - response = Response( + current_response = Response( Prompt( - "", + "", # Next prompt is empty, tools drive it self.model, - tools=response.prompt.tools, + tools=current_response.prompt.tools, tool_results=tool_results, options=self.prompt.options, ), @@ -1124,27 +1337,71 @@ class _BaseChainResponse: conversation=self.conversation, ) else: + current_response = None break def __iter__(self) -> Iterator[str]: - for response in self.responses(): - if isinstance(response, str): - yield response - else: - yield from response + for response_item in self.responses(): + yield from response_item def text(self) -> str: return "".join(self) - def log_to_db(self, db): - for response in self._responses: - if isinstance(response, AsyncResponse): - response = asyncio.run(response.to_sync_response()) - response.log_to_db(db) +class AsyncChainResponse(_BaseChainResponse): + _responses: List["AsyncResponse"] -class ChainResponse(_BaseChainResponse): - "Know how to chain multiple responses e.g. for tool calls" + async def responses(self) -> AsyncIterator[AsyncResponse]: + prompt = self.prompt + count = 0 + current_response: Optional[AsyncResponse] = AsyncResponse( + prompt, + self.model, + self.stream, + key=self._key, + conversation=self.conversation, + ) + while current_response: + count += 1 + yield current_response + self._responses.append(current_response) + + if self.chain_limit and count >= self.chain_limit: + raise ValueError(f"Chain limit of {self.chain_limit} exceeded.") + + # This could raise llm.CancelToolCall: + tool_results = await current_response.execute_tool_calls( + before_call=self.before_call, after_call=self.after_call + ) + if tool_results: + prompt = Prompt( + "", + self.model, + tools=current_response.prompt.tools, + tool_results=tool_results, + options=self.prompt.options, + ) + current_response = AsyncResponse( + prompt, + self.model, + stream=self.stream, + key=self._key, + conversation=self.conversation, + ) + else: + current_response = None + break + + async def __aiter__(self) -> AsyncIterator[str]: + async for response_item in self.responses(): + async for chunk in response_item: + yield chunk + + async def text(self) -> str: + all_chunks = [] + async for chunk in self: + all_chunks.append(chunk) + return "".join(all_chunks) class Options(BaseModel): @@ -1171,13 +1428,13 @@ class _get_key_mixin: return self.key # Attempt to load a key using llm.get_key() - key = get_key( + key_value = get_key( explicit_key=explicit_key, key_alias=self.needs_key, env_var=self.key_env_var, ) - if key: - return key + if key_value: + return key_value # Show a useful error message message = "No key found - add one using 'llm keys set {}'".format( @@ -1241,7 +1498,7 @@ class _Model(_BaseModel): tool_results: Optional[List[ToolResult]] = None, **options, ) -> Response: - key = options.pop("key", None) + key_value = options.pop("key", None) self._validate_attachments(attachments) return Response( Prompt( @@ -1258,7 +1515,7 @@ class _Model(_BaseModel): ), self, stream, - key=key, + key=key_value, ) def chain( @@ -1278,7 +1535,7 @@ class _Model(_BaseModel): details: bool = False, key: Optional[str] = None, options: Optional[dict] = None, - ): + ) -> ChainResponse: return self.conversation().chain( prompt=prompt, fragments=fragments, @@ -1340,7 +1597,7 @@ class _AsyncModel(_BaseModel): stream: bool = True, **options, ) -> AsyncResponse: - key = options.pop("key", None) + key_value = options.pop("key", None) self._validate_attachments(attachments) return AsyncResponse( Prompt( @@ -1357,11 +1614,47 @@ class _AsyncModel(_BaseModel): ), self, stream, - key=key, + key=key_value, ) - def chain(self, *args, **kwargs): - raise NotImplementedError("AsyncModel does not yet support tools") + def chain( + self, + prompt: Optional[str] = None, + *, + fragments: Optional[List[str]] = None, + attachments: Optional[List[Attachment]] = None, + system: Optional[str] = None, + system_fragments: Optional[List[str]] = None, + stream: bool = True, + schema: Optional[Union[dict, type[BaseModel]]] = None, + tools: Optional[List[Tool]] = None, + tool_results: Optional[List[ToolResult]] = None, + before_call: Optional[ + Callable[[Tool, ToolCall], Union[None, Awaitable[None]]] + ] = None, + after_call: Optional[ + Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]] + ] = None, + details: bool = False, + key: Optional[str] = None, + options: Optional[dict] = None, + ) -> AsyncChainResponse: + return self.conversation().chain( + prompt=prompt, + fragments=fragments, + attachments=attachments, + system=system, + system_fragments=system_fragments, + stream=stream, + schema=schema, + tools=tools, + tool_results=tool_results, + before_call=before_call, + after_call=after_call, + details=details, + key=key, + options=options, + ) class AsyncModel(_AsyncModel): @@ -1373,7 +1666,9 @@ class AsyncModel(_AsyncModel): response: AsyncResponse, conversation: Optional[AsyncConversation], ) -> AsyncGenerator[str, None]: - yield "" + if False: # Ensure it's a generator type, but don't actually yield. + yield "" + pass class AsyncKeyModel(_AsyncModel): @@ -1386,7 +1681,9 @@ class AsyncKeyModel(_AsyncModel): conversation: Optional[AsyncConversation], key: Optional[str], ) -> AsyncGenerator[str, None]: - yield "" + if False: # Ensure it's a generator type + yield "" + pass class EmbeddingModel(ABC, _get_key_mixin): @@ -1418,20 +1715,20 @@ class EmbeddingModel(ABC, _get_key_mixin): ) -> Iterator[List[float]]: "Embed multiple items in batches according to the model batch_size" iter_items = iter(items) - batch_size = self.batch_size if batch_size is None else batch_size + effective_batch_size = self.batch_size if batch_size is None else batch_size if (not self.supports_binary) or (not self.supports_text): - def checking_iter(items): - for item in items: - self._check(item) - yield item + def checking_iter(inner_items): + for item_to_check in inner_items: + self._check(item_to_check) + yield item_to_check iter_items = checking_iter(items) - if batch_size is None: + if effective_batch_size is None: yield from self.embed_batch(iter_items) return while True: - batch_items = list(islice(iter_items, batch_size)) + batch_items = list(islice(iter_items, effective_batch_size)) if not batch_items: break yield from self.embed_batch(batch_items) @@ -1457,14 +1754,14 @@ class ModelWithAliases: aliases: Set[str] def matches(self, query: str) -> bool: - query = query.lower() + query_lower = query.lower() all_strings: List[str] = [] all_strings.extend(self.aliases) if self.model: all_strings.append(str(self.model)) if self.async_model: all_strings.append(str(self.async_model.model_id)) - return any(query in alias.lower() for alias in all_strings) + return any(query_lower in alias.lower() for alias in all_strings) @dataclass @@ -1473,11 +1770,11 @@ class EmbeddingModelWithAliases: aliases: Set[str] def matches(self, query: str) -> bool: - query = query.lower() + query_lower = query.lower() all_strings: List[str] = [] all_strings.extend(self.aliases) all_strings.append(str(self.model)) - return any(query in alias.lower() for alias in all_strings) + return any(query_lower in alias.lower() for alias in all_strings) def _conversation_name(text): diff --git a/pyproject.toml b/pyproject.toml index 05512d0..4b0c058 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ test = [ "types-click", "types-PyYAML", "types-setuptools", - "llm-echo==0.3a2", + "llm-echo==0.3a3", ] [build-system] diff --git a/tests/conftest.py b/tests/conftest.py index 7453d51..d19f1f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -470,3 +470,11 @@ def collection(): @pytest.fixture(scope="module") def vcr_config(): return {"filter_headers": ["Authorization"]} + + +def extract_braces(s): + first = s.find("{") + last = s.rfind("}") + if first != -1 and last != -1 and first < last: + return s[first : last + 1] + return None diff --git a/tests/test_tools.py b/tests/test_tools.py index 3d0ffed..553b47a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,8 +1,11 @@ +import asyncio +import json import llm from llm.migrations import migrate import os import pytest import sqlite_utils +import time API_KEY = os.environ.get("PYTEST_OPENAI_API_KEY", None) or "badkey" @@ -92,3 +95,78 @@ def test_tool_use_chain_of_two_calls(vcr): assert second.tool_calls()[0].arguments == {"population": 123124} assert third.prompt.tool_results[0].output == "true" assert third.tool_calls() == [] + + +def test_tool_use_async_tool_function(): + async def hello(): + return "world" + + model = llm.get_model("echo") + chain_response = model.chain( + json.dumps({"tool_calls": [{"name": "hello"}]}), tools=[hello] + ) + output = chain_response.text() + # That's two JSON objects separated by '\n}{\n' + bits = output.split("\n}{\n") + assert len(bits) == 2 + objects = [json.loads(bits[0] + "}"), json.loads("{" + bits[1])] + assert objects == [ + {"prompt": "", "system": "", "attachments": [], "stream": True, "previous": []}, + { + "prompt": "", + "system": "", + "attachments": [], + "stream": True, + "previous": [{"prompt": '{"tool_calls": [{"name": "hello"}]}'}], + "tool_results": [ + {"name": "hello", "output": "world", "tool_call_id": None} + ], + }, + ] + + +@pytest.mark.asyncio +async def test_async_tools_run_tools_in_parallel(): + start_timestamps = [] + + start_ns = time.monotonic_ns() + + async def hello(): + start_timestamps.append(("hello", time.monotonic_ns() - start_ns)) + await asyncio.sleep(0.2) + return "world" + + async def hello2(): + start_timestamps.append(("hello2", time.monotonic_ns() - start_ns)) + await asyncio.sleep(0.2) + return "world2" + + model = llm.get_async_model("echo") + chain_response = model.chain( + json.dumps({"tool_calls": [{"name": "hello"}, {"name": "hello2"}]}), + tools=[hello, hello2], + ) + output = await chain_response.text() + # That's two JSON objects separated by '\n}{\n' + bits = output.split("\n}{\n") + assert len(bits) == 2 + objects = [json.loads(bits[0] + "}"), json.loads("{" + bits[1])] + assert objects == [ + {"prompt": "", "system": "", "attachments": [], "stream": True, "previous": []}, + { + "prompt": "", + "system": "", + "attachments": [], + "stream": True, + "previous": [ + {"prompt": '{"tool_calls": [{"name": "hello"}, {"name": "hello2"}]}'} + ], + "tool_results": [ + {"name": "hello", "output": "world", "tool_call_id": None}, + {"name": "hello2", "output": "world2", "tool_call_id": None}, + ], + }, + ] + delta_ns = start_timestamps[1][1] - start_timestamps[0][1] + # They should have run in parallel so it should be less than 0.02s difference + assert delta_ns < (100_000_000 * 0.2)