Attachment and attachment type support for templates, closes #826

This commit is contained in:
Simon Willison 2025-04-07 07:20:49 -07:00
parent d0255a1eda
commit 388a4ea9a2
3 changed files with 121 additions and 45 deletions

View file

@ -151,54 +151,84 @@ def resolve_fragments(
return resolved
class AttachmentError(Exception):
"""Exception raised for errors in attachment resolution."""
pass
def resolve_attachment(value):
"""
Resolve an attachment from a string value which could be:
- "-" for stdin
- A URL
- A file path
Returns an Attachment object.
Raises AttachmentError if the attachment cannot be resolved.
"""
if value == "-":
content = sys.stdin.buffer.read()
# Try to guess type
mimetype = mimetype_from_string(content)
if mimetype is None:
raise AttachmentError("Could not determine mimetype of stdin")
return Attachment(type=mimetype, path=None, url=None, content=content)
if "://" in value:
# Confirm URL exists and try to guess type
try:
response = httpx.head(value)
response.raise_for_status()
mimetype = response.headers.get("content-type")
except httpx.HTTPError as ex:
raise AttachmentError(str(ex))
return Attachment(type=mimetype, path=None, url=value, content=None)
# Check that the file exists
path = pathlib.Path(value)
if not path.exists():
raise AttachmentError(f"File {value} does not exist")
path = path.resolve()
# Try to guess type
mimetype = mimetype_from_path(str(path))
if mimetype is None:
raise AttachmentError(f"Could not determine mimetype of {value}")
return Attachment(type=mimetype, path=str(path), url=None, content=None)
class AttachmentType(click.ParamType):
name = "attachment"
def convert(self, value, param, ctx):
if value == "-":
content = sys.stdin.buffer.read()
# Try to guess type
mimetype = mimetype_from_string(content)
if mimetype is None:
raise click.BadParameter("Could not determine mimetype of stdin")
return Attachment(type=mimetype, path=None, url=None, content=content)
if "://" in value:
# Confirm URL exists and try to guess type
try:
response = httpx.head(value)
response.raise_for_status()
mimetype = response.headers.get("content-type")
except httpx.HTTPError as ex:
raise click.BadParameter(str(ex))
return Attachment(mimetype, None, value, None)
# Check that the file exists
try:
return resolve_attachment(value)
except AttachmentError as e:
self.fail(str(e), param, ctx)
def resolve_attachment_with_type(value: str, mimetype: str) -> Attachment:
if "://" in value:
attachment = Attachment(mimetype, None, value, None)
elif value == "-":
content = sys.stdin.buffer.read()
attachment = Attachment(mimetype, None, None, content)
else:
# Look for file
path = pathlib.Path(value)
if not path.exists():
self.fail(f"File {value} does not exist", param, ctx)
raise click.BadParameter(f"File {value} does not exist")
path = path.resolve()
# Try to guess type
mimetype = mimetype_from_path(str(path))
if mimetype is None:
raise click.BadParameter(f"Could not determine mimetype of {value}")
return Attachment(type=mimetype, path=str(path), url=None, content=None)
attachment = Attachment(mimetype, str(path), None, None)
return attachment
def attachment_types_callback(ctx, param, values):
def attachment_types_callback(ctx, param, values) -> List[Attachment]:
collected = []
for value, mimetype in values:
if "://" in value:
attachment = Attachment(mimetype, None, value, None)
elif value == "-":
content = sys.stdin.buffer.read()
attachment = Attachment(mimetype, None, None, content)
else:
# Look for file
path = pathlib.Path(value)
if not path.exists():
raise click.BadParameter(f"File {value} does not exist")
path = path.resolve()
attachment = Attachment(mimetype, str(path), None, None)
collected.append(attachment)
collected.append(resolve_attachment_with_type(value, mimetype))
return collected
@ -508,6 +538,17 @@ def prompt(
to_save["fragments"] = list(fragments)
if system_fragments:
to_save["system_fragments"] = list(system_fragments)
if attachments:
# Only works for attachments with a path or url
to_save["attachments"] = [
(a.path or a.url) for a in attachments if (a.path or a.url)
]
if attachment_types:
to_save["attachment_types"] = [
{"type": a.type, "value": a.path or a.url}
for a in attachment_types
if (a.path or a.url)
]
if options:
# Need to validate and convert their types first
model = get_model(model_id or get_default_model())
@ -568,7 +609,16 @@ def prompt(
raise click.ClickException(str(ex))
if model_id is None and template_obj.model:
model_id = template_obj.model
# Merge in any attachments
if template_obj.attachments:
attachments = [
resolve_attachment(a) for a in template_obj.attachments
] + list(attachments)
if template_obj.attachment_types:
attachment_types = [
resolve_attachment_with_type(at.value, at.type)
for at in template_obj.attachment_types
] + list(attachment_types)
if extract or extract_last:
no_stream = True

View file

@ -3,10 +3,17 @@ import string
from typing import Optional, Any, Dict, List, Tuple
class AttachmentType(BaseModel):
type: str
value: str
class Template(BaseModel):
name: str
prompt: Optional[str] = None
system: Optional[str] = None
attachments: Optional[List[str]] = None
attachment_types: Optional[List[AttachmentType]] = None
model: Optional[str] = None
defaults: Optional[Dict[str, Any]] = None
options: Optional[Dict[str, Any]] = None

View file

@ -4,6 +4,7 @@ from llm import Template
from llm.cli import cli
import os
from unittest import mock
import pathlib
import pytest
import yaml
@ -78,7 +79,7 @@ def test_templates_list(templates_path, args):
@pytest.mark.parametrize(
"args,expected_prompt,expected_error",
"args,expected,expected_error",
(
(["-m", "gpt4", "hello"], {"model": "gpt-4", "prompt": "hello"}, None),
(["hello $foo"], {"prompt": "hello $foo"}, None),
@ -126,18 +127,36 @@ def test_templates_list(templates_path, args):
},
None,
),
# And attachments and attachment_types
(
["--attachment", "a.txt", "--attachment-type", "b.txt", "text/plain"],
{
"attachments": ["a.txt"],
"attachment_types": [{"type": "text/plain", "value": "b.txt"}],
},
None,
),
),
)
def test_templates_prompt_save(templates_path, args, expected_prompt, expected_error):
def test_templates_prompt_save(templates_path, args, expected, expected_error):
assert not (templates_path / "saved.yaml").exists()
runner = CliRunner()
result = runner.invoke(cli, args + ["--save", "saved"], catch_exceptions=False)
with runner.isolated_filesystem():
# Create a file to test attachment
pathlib.Path("a.txt").write_text("attachment", "utf-8")
pathlib.Path("b.txt").write_text("attachment type", "utf-8")
result = runner.invoke(cli, args + ["--save", "saved"], catch_exceptions=False)
if not expected_error:
assert result.exit_code == 0
assert (
yaml.safe_load((templates_path / "saved.yaml").read_text("utf-8"))
== expected_prompt
)
yaml_data = yaml.safe_load((templates_path / "saved.yaml").read_text("utf-8"))
# Adjust attachment and attachment_types paths to be just the filename
if "attachments" in yaml_data:
yaml_data["attachments"] = [
os.path.basename(path) for path in yaml_data["attachments"]
]
for item in yaml_data.get("attachment_types", []):
item["value"] = os.path.basename(item["value"])
assert yaml_data == expected
else:
assert result.exit_code == 1
assert expected_error in result.output