diff --git a/docs/python-api.md b/docs/python-api.md index 441d181..4493116 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -148,6 +148,30 @@ for response in chain.responses(): print(chunk, end="", flush=True) ``` +(python-api-tools-attachments)= + +#### Tools can return attachments + +Tools can return {ref}`attachments ` in addition to returning text. Attachments that are returned from a tool call will be passed to the model as attachments for the next prompt in the chain. + +To return one or more attachments, return a `llm.ToolOutput` instance from your tool function. This can have an `output=` string and an `attachments=` list of `llm.Attachment` instances. + +Here's an example: +```python +import llm + +def generate_image(prompt: str) -> llm.ToolOutput: + """Generate an image based on the prompt.""" + image_content = generate_image_from_prompt(prompt) + return llm.ToolOutput( + output="Image generated successfully", + attachments=[llm.Attachment( + content=image_content, + mimetype="image/png" + )], + ) +``` + (python-api-toolbox)= #### Toolbox classes diff --git a/llm/__init__.py b/llm/__init__.py index f35a8b4..7e0d3ad 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -22,6 +22,7 @@ from .models import ( Tool, Toolbox, ToolCall, + ToolOutput, ToolResult, ) from .utils import schema_dsl, Fragment @@ -60,6 +61,7 @@ __all__ = [ "Tool", "Toolbox", "ToolCall", + "ToolOutput", "ToolResult", "user_dir", "schema_dsl", diff --git a/llm/cli.py b/llm/cli.py index 5dd0ca9..0af4d0c 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1855,7 +1855,21 @@ def logs_list( 'tool_id', tr.tool_id, 'name', tr.name, 'output', tr.output, - 'tool_call_id', tr.tool_call_id + 'tool_call_id', tr.tool_call_id, + 'attachments', COALESCE( + (SELECT json_group_array(json_object( + 'id', a.id, + 'type', a.type, + 'path', a.path, + 'url', a.url, + 'content', a.content + )) + FROM tool_results_attachments tra + JOIN attachments a ON tra.attachment_id = a.id + WHERE tra.tool_result_id = tr.id + ), + '[]' + ) )) FROM tool_results tr WHERE tr.response_id = responses.id @@ -2066,11 +2080,24 @@ def logs_list( if row["tool_results"]: click.echo("\n### Tool results\n") for tool_result in row["tool_results"]: + attachments = "" + for attachment in tool_result["attachments"]: + desc = "" + if attachment.get("type"): + desc += attachment["type"] + ": " + if attachment.get("path"): + desc += attachment["path"] + elif attachment.get("url"): + desc += attachment["url"] + elif attachment.get("content"): + desc += f"<{attachment['content_length']:,} bytes>" + attachments += "\n - {}".format(desc) click.echo( - "- **{}**: `{}`
\n{}".format( + "- **{}**: `{}`
\n{}{}".format( tool_result["name"], tool_result["tool_call_id"], textwrap.indent(tool_result["output"], " "), + attachments, ) ) attachments = attachments_by_id.get(row["id"]) @@ -3885,10 +3912,17 @@ def _debug_tool_call(_, tool_call, tool_result): err=True, ) output = "" + attachments = "" + if tool_result.attachments: + attachments += "\nAttachments:\n" + for attachment in tool_result.attachments: + attachments += f" {repr(attachment)}\n" + try: output = json.dumps(json.loads(tool_result.output), indent=2) except ValueError: output = tool_result.output + output += attachments click.echo( click.style( textwrap.indent(output, " ") + "\n", diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 0005155..4ec7100 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -573,6 +573,14 @@ class _Shared: ) if prompt.system and prompt.system != current_system: messages.append({"role": "system", "content": prompt.system}) + for tool_result in prompt.tool_results: + messages.append( + { + "role": "tool", + "tool_call_id": tool_result.tool_call_id, + "content": tool_result.output, + } + ) if not prompt.attachments: if prompt.prompt: messages.append({"role": "user", "content": prompt.prompt or ""}) @@ -583,14 +591,6 @@ class _Shared: for attachment in prompt.attachments: attachment_message.append(_attachment(attachment)) messages.append({"role": "user", "content": attachment_message}) - for tool_result in prompt.tool_results: - messages.append( - { - "role": "tool", - "tool_call_id": tool_result.tool_call_id, - "content": tool_result.output, - } - ) return messages def set_usage(self, response, usage): diff --git a/llm/migrations.py b/llm/migrations.py index ccbd806..54cb777 100644 --- a/llm/migrations.py +++ b/llm/migrations.py @@ -397,3 +397,19 @@ def m019_resolved_model(db): # For models like gemini-1.5-flash-latest where we wish to record # the resolved model name in addition to the alias db["responses"].add_column("resolved_model", str) + + +@migration +def m020_tool_results_attachments(db): + db["tool_results_attachments"].create( + { + "tool_result_id": int, + "attachment_id": str, + "order": int, + }, + foreign_keys=( + ("tool_result_id", "tool_results", "id"), + ("attachment_id", "attachments", "id"), + ), + pk=("tool_result_id", "attachment_id"), + ) diff --git a/llm/models.py b/llm/models.py index a80b4db..9391ff4 100644 --- a/llm/models.py +++ b/llm/models.py @@ -97,6 +97,18 @@ class Attachment: def base64_content(self): return base64.b64encode(self.content_bytes()).decode("utf-8") + def __repr__(self): + info = [f"" + @classmethod def from_row(cls, row): return cls( @@ -261,10 +273,19 @@ class ToolCall: class ToolResult: name: str output: str + attachments: List[Attachment] = field(default_factory=list) tool_call_id: Optional[str] = None instance: Optional[Toolbox] = None +@dataclass +class ToolOutput: + "Tool functions can return output with extra attachments" + + output: Optional[Union[str, dict, list, bool, int, float]] = None + attachments: List[Attachment] = field(default_factory=list) + + class CancelToolCall(Exception): pass @@ -887,16 +908,40 @@ class _BaseResponse: instance_id = tool_result.instance.instance_id except AttributeError: pass - db["tool_results"].insert( - { - "response_id": response_id, - "tool_id": tool_ids_by_name.get(tool_result.name) or None, - "name": tool_result.name, - "output": tool_result.output, - "tool_call_id": tool_result.tool_call_id, - "instance_id": instance_id, - } + tool_result_id = ( + db["tool_results"] + .insert( + { + "response_id": response_id, + "tool_id": tool_ids_by_name.get(tool_result.name) or None, + "name": tool_result.name, + "output": tool_result.output, + "tool_call_id": tool_result.tool_call_id, + "instance_id": instance_id, + } + ) + .last_pk ) + # Persist attachments for tool results + for index, attachment in enumerate(tool_result.attachments): + attachment_id = attachment.id() + db["attachments"].insert( + { + "id": attachment_id, + "type": attachment.resolve_type(), + "path": attachment.path, + "url": attachment.url, + "content": attachment.content, + }, + replace=True, + ) + db["tool_results_attachments"].insert( + { + "tool_result_id": tool_result_id, + "attachment_id": attachment_id, + "order": index, + }, + ) class Response(_BaseResponse): @@ -964,12 +1009,18 @@ class Response(_BaseResponse): "No implementation available for tool: {}".format(tool_call.name) ) + attachments = [] + try: if asyncio.iscoroutinefunction(tool.implementation): result = asyncio.run(tool.implementation(**tool_call.arguments)) else: result = tool.implementation(**tool_call.arguments) + if isinstance(result, ToolOutput): + attachments = result.attachments + result = result.output + if not isinstance(result, str): result = json.dumps(result, default=repr) except Exception as ex: @@ -978,6 +1029,7 @@ class Response(_BaseResponse): tool_result_obj = ToolResult( name=tool_call.name, output=result, + attachments=attachments, tool_call_id=tool_call.tool_call_id, instance=_get_instance(tool.implementation), ) @@ -1125,8 +1177,12 @@ class AsyncResponse(_BaseResponse): if inspect.isawaitable(cb): await cb + attachments = [] try: result = await tool.implementation(**tc.arguments) + if isinstance(result, ToolOutput): + attachments.extend(result.attachments) + result = result.output output = ( result if isinstance(result, str) @@ -1138,6 +1194,7 @@ class AsyncResponse(_BaseResponse): tr = ToolResult( name=tc.name, output=output, + attachments=attachments, tool_call_id=tc.tool_call_id, instance=_get_instance(tool.implementation), ) @@ -1159,10 +1216,14 @@ class AsyncResponse(_BaseResponse): if inspect.isawaitable(cb): await cb + attachments = [] try: res = tool.implementation(**tc.arguments) if inspect.isawaitable(res): res = await res + if isinstance(res, ToolOutput): + attachments.extend(res.attachments) + res = res.output output = ( res if isinstance(res, str) else json.dumps(res, default=repr) ) @@ -1172,6 +1233,7 @@ class AsyncResponse(_BaseResponse): tr = ToolResult( name=tc.name, output=output, + attachments=attachments, tool_call_id=tc.tool_call_id, instance=_get_instance(tool.implementation), ) @@ -1427,6 +1489,9 @@ class ChainResponse(_BaseChainResponse): tool_results = current_response.execute_tool_calls( before_call=self.before_call, after_call=self.after_call ) + attachments = [] + for tool_result in tool_results: + attachments.extend(tool_result.attachments) if tool_results: current_response = Response( Prompt( @@ -1435,6 +1500,7 @@ class ChainResponse(_BaseChainResponse): tools=current_response.prompt.tools, tool_results=tool_results, options=self.prompt.options, + attachments=attachments, ), self.model, stream=self.stream, @@ -1479,12 +1545,16 @@ class AsyncChainResponse(_BaseChainResponse): before_call=self.before_call, after_call=self.after_call ) if tool_results: + attachments = [] + for tool_result in tool_results: + attachments.extend(tool_result.attachments) prompt = Prompt( "", self.model, tools=current_response.prompt.tools, tool_results=tool_results, options=self.prompt.options, + attachments=attachments, ) current_response = AsyncResponse( prompt, diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 7c5e4cf..c47495b 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -459,12 +459,12 @@ def test_register_tools(tmpdir, logs_db): ('{"tool_calls": [{"name": "upper", "arguments": {"text": "one"}}]}', "[]"), ( "", - '[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null}]', + '[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null, "attachments": []}]', ), ('{"tool_calls": [{"name": "upper", "arguments": {"text": "two"}}]}', "[]"), ( "", - '[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null}]', + '[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null, "attachments": []}]', ), ( '{"tool_calls": [{"name": "upper", "arguments": {"text": "three"}}]}', @@ -472,7 +472,7 @@ def test_register_tools(tmpdir, logs_db): ), ( "", - '[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null}]', + '[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null, "attachments": []}]', ), ) # Test the --td option diff --git a/tests/test_tools.py b/tests/test_tools.py index a22b9c7..ecb61b1 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -296,3 +296,50 @@ def test_incorrect_tool_usage(): ) output = chain_response.text() assert 'Error: tool \\"bad_tool\\" does not exist' in output + + +def test_tool_returning_attachment(): + model = llm.get_model("echo") + + def return_attachment() -> llm.Attachment: + return llm.ToolOutput( + "Output", + attachments=[ + llm.Attachment( + content=b"This is a test attachment", + type="image/png", + ) + ], + ) + + chain_response = model.chain( + json.dumps({"tool_calls": [{"name": "return_attachment"}]}), + tools=[return_attachment], + ) + output = chain_response.text() + assert '"type": "image/png"' in output + assert '"output": "Output"' in output + + +@pytest.mark.asyncio +async def test_async_tool_returning_attachment(): + model = llm.get_async_model("echo") + + async def return_attachment() -> llm.Attachment: + return llm.ToolOutput( + "Output", + attachments=[ + llm.Attachment( + content=b"This is a test attachment", + type="image/png", + ) + ], + ) + + chain_response = model.chain( + json.dumps({"tool_calls": [{"name": "return_attachment"}]}), + tools=[return_attachment], + ) + output = await chain_response.text() + assert '"type": "image/png"' in output + assert '"output": "Output"' in output