Tools can now return attachments

Closes #1014

- llm.ToolOutput(output='...', attachments=[...]) for tools to return attachments
- New table: `tool_results_attachments`
- Table is populated when tools return attachments
- llm --tools-debug shows attachments returned by tools
- llm logs shows attachments returned by tools
This commit is contained in:
Simon Willison 2025-06-01 10:08:36 -07:00 committed by GitHub
parent f74e242442
commit b5d1c5ee90
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 215 additions and 22 deletions

View file

@ -148,6 +148,30 @@ for response in chain.responses():
print(chunk, end="", flush=True)
```
(python-api-tools-attachments)=
#### Tools can return attachments
Tools can return {ref}`attachments <python-api-attachments>` in addition to returning text. Attachments that are returned from a tool call will be passed to the model as attachments for the next prompt in the chain.
To return one or more attachments, return a `llm.ToolOutput` instance from your tool function. This can have an `output=` string and an `attachments=` list of `llm.Attachment` instances.
Here's an example:
```python
import llm
def generate_image(prompt: str) -> llm.ToolOutput:
    """Generate an image based on the prompt."""
    image_content = generate_image_from_prompt(prompt)
    return llm.ToolOutput(
        output="Image generated successfully",
        # The attachment's MIME type is passed as type=, not mimetype=
        attachments=[llm.Attachment(
            content=image_content,
            type="image/png"
        )],
    )
```
(python-api-toolbox)=
#### Toolbox classes

View file

@ -22,6 +22,7 @@ from .models import (
Tool,
Toolbox,
ToolCall,
ToolOutput,
ToolResult,
)
from .utils import schema_dsl, Fragment
@ -60,6 +61,7 @@ __all__ = [
"Tool",
"Toolbox",
"ToolCall",
"ToolOutput",
"ToolResult",
"user_dir",
"schema_dsl",

View file

@ -1855,7 +1855,21 @@ def logs_list(
'tool_id', tr.tool_id,
'name', tr.name,
'output', tr.output,
'tool_call_id', tr.tool_call_id
'tool_call_id', tr.tool_call_id,
'attachments', COALESCE(
(SELECT json_group_array(json_object(
'id', a.id,
'type', a.type,
'path', a.path,
'url', a.url,
'content', a.content
))
FROM tool_results_attachments tra
JOIN attachments a ON tra.attachment_id = a.id
WHERE tra.tool_result_id = tr.id
),
'[]'
)
))
FROM tool_results tr
WHERE tr.response_id = responses.id
@ -2066,11 +2080,24 @@ def logs_list(
if row["tool_results"]:
click.echo("\n### Tool results\n")
for tool_result in row["tool_results"]:
attachments = ""
for attachment in tool_result["attachments"]:
desc = ""
if attachment.get("type"):
desc += attachment["type"] + ": "
if attachment.get("path"):
desc += attachment["path"]
elif attachment.get("url"):
desc += attachment["url"]
elif attachment.get("content"):
desc += f"<{attachment['content_length']:,} bytes>"
attachments += "\n - {}".format(desc)
click.echo(
"- **{}**: `{}`<br>\n{}".format(
"- **{}**: `{}`<br>\n{}{}".format(
tool_result["name"],
tool_result["tool_call_id"],
textwrap.indent(tool_result["output"], " "),
attachments,
)
)
attachments = attachments_by_id.get(row["id"])
@ -3885,10 +3912,17 @@ def _debug_tool_call(_, tool_call, tool_result):
err=True,
)
output = ""
attachments = ""
if tool_result.attachments:
attachments += "\nAttachments:\n"
for attachment in tool_result.attachments:
attachments += f" {repr(attachment)}\n"
try:
output = json.dumps(json.loads(tool_result.output), indent=2)
except ValueError:
output = tool_result.output
output += attachments
click.echo(
click.style(
textwrap.indent(output, " ") + "\n",

View file

@ -573,6 +573,14 @@ class _Shared:
)
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
for tool_result in prompt.tool_results:
messages.append(
{
"role": "tool",
"tool_call_id": tool_result.tool_call_id,
"content": tool_result.output,
}
)
if not prompt.attachments:
if prompt.prompt:
messages.append({"role": "user", "content": prompt.prompt or ""})
@ -583,14 +591,6 @@ class _Shared:
for attachment in prompt.attachments:
attachment_message.append(_attachment(attachment))
messages.append({"role": "user", "content": attachment_message})
for tool_result in prompt.tool_results:
messages.append(
{
"role": "tool",
"tool_call_id": tool_result.tool_call_id,
"content": tool_result.output,
}
)
return messages
def set_usage(self, response, usage):

View file

@ -397,3 +397,19 @@ def m019_resolved_model(db):
# For models like gemini-1.5-flash-latest where we wish to record
# the resolved model name in addition to the alias
db["responses"].add_column("resolved_model", str)
@migration
def m020_tool_results_attachments(db):
    # Join table linking a tool result to the attachments it returned.
    # "order" holds the attachment's position within that tool result.
    columns = {
        "tool_result_id": int,
        "attachment_id": str,
        "order": int,
    }
    references = (
        ("tool_result_id", "tool_results", "id"),
        ("attachment_id", "attachments", "id"),
    )
    # Compound primary key: one row per (tool result, attachment) pair
    db["tool_results_attachments"].create(
        columns,
        foreign_keys=references,
        pk=("tool_result_id", "attachment_id"),
    )

View file

@ -97,6 +97,18 @@ class Attachment:
def base64_content(self):
return base64.b64encode(self.content_bytes()).decode("utf-8")
def __repr__(self):
    # Builds '<Attachment: ID key="value" ... content=N bytes>', including
    # only the fields that are set (truthy), always in the same order.
    # NOTE(review): self.id() is defined outside this excerpt; confirm it is
    # cheap and side-effect free enough to call from a repr.
    parts = [f"<Attachment: {self.id()}"]
    for label in ("type", "path", "url"):
        value = getattr(self, label)
        if value:
            parts.append(f'{label}="{value}"')
    if self.content:
        parts.append(f"content={len(self.content)} bytes")
    return " ".join(parts) + ">"
@classmethod
def from_row(cls, row):
return cls(
@ -261,10 +273,19 @@ class ToolCall:
class ToolResult:
name: str
output: str
attachments: List[Attachment] = field(default_factory=list)
tool_call_id: Optional[str] = None
instance: Optional[Toolbox] = None
@dataclass
class ToolOutput:
    """Return value for tool functions that need to send attachments back.

    A tool can return an instance of this class instead of a bare value to
    pair its output with a list of attachments.
    """

    # The tool's output; may be left as None
    output: Union[str, dict, list, bool, int, float, None] = None
    # Attachments returned alongside the output; a fresh empty list by default
    attachments: List[Attachment] = field(default_factory=list)
class CancelToolCall(Exception):
    """Exception type for cancelling a tool call.

    NOTE(review): where this is raised and how it is handled is outside
    this excerpt; confirm against the callers.
    """
@ -887,16 +908,40 @@ class _BaseResponse:
instance_id = tool_result.instance.instance_id
except AttributeError:
pass
db["tool_results"].insert(
{
"response_id": response_id,
"tool_id": tool_ids_by_name.get(tool_result.name) or None,
"name": tool_result.name,
"output": tool_result.output,
"tool_call_id": tool_result.tool_call_id,
"instance_id": instance_id,
}
tool_result_id = (
db["tool_results"]
.insert(
{
"response_id": response_id,
"tool_id": tool_ids_by_name.get(tool_result.name) or None,
"name": tool_result.name,
"output": tool_result.output,
"tool_call_id": tool_result.tool_call_id,
"instance_id": instance_id,
}
)
.last_pk
)
# Persist attachments for tool results
for index, attachment in enumerate(tool_result.attachments):
attachment_id = attachment.id()
db["attachments"].insert(
{
"id": attachment_id,
"type": attachment.resolve_type(),
"path": attachment.path,
"url": attachment.url,
"content": attachment.content,
},
replace=True,
)
db["tool_results_attachments"].insert(
{
"tool_result_id": tool_result_id,
"attachment_id": attachment_id,
"order": index,
},
)
class Response(_BaseResponse):
@ -964,12 +1009,18 @@ class Response(_BaseResponse):
"No implementation available for tool: {}".format(tool_call.name)
)
attachments = []
try:
if asyncio.iscoroutinefunction(tool.implementation):
result = asyncio.run(tool.implementation(**tool_call.arguments))
else:
result = tool.implementation(**tool_call.arguments)
if isinstance(result, ToolOutput):
attachments = result.attachments
result = result.output
if not isinstance(result, str):
result = json.dumps(result, default=repr)
except Exception as ex:
@ -978,6 +1029,7 @@ class Response(_BaseResponse):
tool_result_obj = ToolResult(
name=tool_call.name,
output=result,
attachments=attachments,
tool_call_id=tool_call.tool_call_id,
instance=_get_instance(tool.implementation),
)
@ -1125,8 +1177,12 @@ class AsyncResponse(_BaseResponse):
if inspect.isawaitable(cb):
await cb
attachments = []
try:
result = await tool.implementation(**tc.arguments)
if isinstance(result, ToolOutput):
attachments.extend(result.attachments)
result = result.output
output = (
result
if isinstance(result, str)
@ -1138,6 +1194,7 @@ class AsyncResponse(_BaseResponse):
tr = ToolResult(
name=tc.name,
output=output,
attachments=attachments,
tool_call_id=tc.tool_call_id,
instance=_get_instance(tool.implementation),
)
@ -1159,10 +1216,14 @@ class AsyncResponse(_BaseResponse):
if inspect.isawaitable(cb):
await cb
attachments = []
try:
res = tool.implementation(**tc.arguments)
if inspect.isawaitable(res):
res = await res
if isinstance(res, ToolOutput):
attachments.extend(res.attachments)
res = res.output
output = (
res if isinstance(res, str) else json.dumps(res, default=repr)
)
@ -1172,6 +1233,7 @@ class AsyncResponse(_BaseResponse):
tr = ToolResult(
name=tc.name,
output=output,
attachments=attachments,
tool_call_id=tc.tool_call_id,
instance=_get_instance(tool.implementation),
)
@ -1427,6 +1489,9 @@ class ChainResponse(_BaseChainResponse):
tool_results = current_response.execute_tool_calls(
before_call=self.before_call, after_call=self.after_call
)
attachments = []
for tool_result in tool_results:
attachments.extend(tool_result.attachments)
if tool_results:
current_response = Response(
Prompt(
@ -1435,6 +1500,7 @@ class ChainResponse(_BaseChainResponse):
tools=current_response.prompt.tools,
tool_results=tool_results,
options=self.prompt.options,
attachments=attachments,
),
self.model,
stream=self.stream,
@ -1479,12 +1545,16 @@ class AsyncChainResponse(_BaseChainResponse):
before_call=self.before_call, after_call=self.after_call
)
if tool_results:
attachments = []
for tool_result in tool_results:
attachments.extend(tool_result.attachments)
prompt = Prompt(
"",
self.model,
tools=current_response.prompt.tools,
tool_results=tool_results,
options=self.prompt.options,
attachments=attachments,
)
current_response = AsyncResponse(
prompt,

View file

@ -459,12 +459,12 @@ def test_register_tools(tmpdir, logs_db):
('{"tool_calls": [{"name": "upper", "arguments": {"text": "one"}}]}', "[]"),
(
"",
'[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null}]',
'[{"id": 2, "tool_id": 1, "name": "upper", "output": "ONE", "tool_call_id": null, "attachments": []}]',
),
('{"tool_calls": [{"name": "upper", "arguments": {"text": "two"}}]}', "[]"),
(
"",
'[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null}]',
'[{"id": 3, "tool_id": 1, "name": "upper", "output": "TWO", "tool_call_id": null, "attachments": []}]',
),
(
'{"tool_calls": [{"name": "upper", "arguments": {"text": "three"}}]}',
@ -472,7 +472,7 @@ def test_register_tools(tmpdir, logs_db):
),
(
"",
'[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null}]',
'[{"id": 4, "tool_id": 1, "name": "upper", "output": "THREE", "tool_call_id": null, "attachments": []}]',
),
)
# Test the --td option

View file

@ -296,3 +296,50 @@ def test_incorrect_tool_usage():
)
output = chain_response.text()
assert 'Error: tool \\"bad_tool\\" does not exist' in output
def test_tool_returning_attachment():
    """A sync tool can return llm.ToolOutput carrying an attachment.

    Runs a one-tool chain and checks that the chain's text output mentions
    both the attachment type and the tool's output string.
    """
    model = llm.get_model("echo")

    # Annotated with the type actually returned (ToolOutput, not Attachment).
    # No docstring here on purpose, so the tool definition stays unchanged.
    def return_attachment() -> llm.ToolOutput:
        return llm.ToolOutput(
            "Output",
            attachments=[
                llm.Attachment(
                    content=b"This is a test attachment",
                    type="image/png",
                )
            ],
        )

    chain_response = model.chain(
        json.dumps({"tool_calls": [{"name": "return_attachment"}]}),
        tools=[return_attachment],
    )
    output = chain_response.text()
    assert '"type": "image/png"' in output
    assert '"output": "Output"' in output
@pytest.mark.asyncio
async def test_async_tool_returning_attachment():
    """An async tool can return llm.ToolOutput carrying an attachment.

    Runs a one-tool chain on the async model and checks that the chain's
    text output mentions both the attachment type and the tool's output
    string.
    """
    model = llm.get_async_model("echo")

    # Annotated with the type actually returned (ToolOutput, not Attachment).
    # No docstring here on purpose, so the tool definition stays unchanged.
    async def return_attachment() -> llm.ToolOutput:
        return llm.ToolOutput(
            "Output",
            attachments=[
                llm.Attachment(
                    content=b"This is a test attachment",
                    type="image/png",
                )
            ],
        )

    chain_response = model.chain(
        json.dumps({"tool_calls": [{"name": "return_attachment"}]}),
        tools=[return_attachment],
    )
    output = await chain_response.text()
    assert '"type": "image/png"' in output
    assert '"output": "Output"' in output