From 388a4ea9a2f7ff2787ab40471470419433857391 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Mon, 7 Apr 2025 07:20:49 -0700 Subject: [PATCH] Attachment and attachment type support for templates, closes #826 --- llm/cli.py | 126 ++++++++++++++++++++++++++++------------ llm/templates.py | 7 +++ tests/test_templates.py | 33 ++++++++--- 3 files changed, 121 insertions(+), 45 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index 56eecc2..dfb5290 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -151,54 +151,84 @@ def resolve_fragments( return resolved +class AttachmentError(Exception): + """Exception raised for errors in attachment resolution.""" + + pass + + +def resolve_attachment(value): + """ + Resolve an attachment from a string value which could be: + - "-" for stdin + - A URL + - A file path + + Returns an Attachment object. + Raises AttachmentError if the attachment cannot be resolved. + """ + if value == "-": + content = sys.stdin.buffer.read() + # Try to guess type + mimetype = mimetype_from_string(content) + if mimetype is None: + raise AttachmentError("Could not determine mimetype of stdin") + return Attachment(type=mimetype, path=None, url=None, content=content) + + if "://" in value: + # Confirm URL exists and try to guess type + try: + response = httpx.head(value) + response.raise_for_status() + mimetype = response.headers.get("content-type") + except httpx.HTTPError as ex: + raise AttachmentError(str(ex)) + return Attachment(type=mimetype, path=None, url=value, content=None) + + # Check that the file exists + path = pathlib.Path(value) + if not path.exists(): + raise AttachmentError(f"File {value} does not exist") + path = path.resolve() + + # Try to guess type + mimetype = mimetype_from_path(str(path)) + if mimetype is None: + raise AttachmentError(f"Could not determine mimetype of {value}") + + return Attachment(type=mimetype, path=str(path), url=None, content=None) + + class AttachmentType(click.ParamType): name = "attachment" def convert(self, value, param, ctx): - if value == "-": - content = sys.stdin.buffer.read() - # Try to guess type - mimetype = mimetype_from_string(content) - if mimetype is None: - raise click.BadParameter("Could not determine mimetype of stdin") - return Attachment(type=mimetype, path=None, url=None, content=content) - if "://" in value: - # Confirm URL exists and try to guess type - try: - response = httpx.head(value) - response.raise_for_status() - mimetype = response.headers.get("content-type") - except httpx.HTTPError as ex: - raise click.BadParameter(str(ex)) - return Attachment(mimetype, None, value, None) - # Check that the file exists + try: + return resolve_attachment(value) + except AttachmentError as e: + self.fail(str(e), param, ctx) + + +def resolve_attachment_with_type(value: str, mimetype: str) -> Attachment: + if "://" in value: + attachment = Attachment(mimetype, None, value, None) + elif value == "-": + content = sys.stdin.buffer.read() + attachment = Attachment(mimetype, None, None, content) + else: + # Look for file path = pathlib.Path(value) if not path.exists(): - self.fail(f"File {value} does not exist", param, ctx) + raise click.BadParameter(f"File {value} does not exist") path = path.resolve() - # Try to guess type - mimetype = mimetype_from_path(str(path)) - if mimetype is None: - raise click.BadParameter(f"Could not determine mimetype of {value}") - return Attachment(type=mimetype, path=str(path), url=None, content=None) + attachment = Attachment(mimetype, str(path), None, None) + return attachment -def attachment_types_callback(ctx, param, values): +def attachment_types_callback(ctx, param, values) -> List[Attachment]: collected = [] for value, mimetype in values: - if "://" in value: - attachment = Attachment(mimetype, None, value, None) - elif value == "-": - content = sys.stdin.buffer.read() - attachment = Attachment(mimetype, None, None, content) - else: - # Look for file - path = pathlib.Path(value) - if not path.exists(): - raise click.BadParameter(f"File {value} does not exist") - path = path.resolve() - attachment = Attachment(mimetype, str(path), None, None) - collected.append(attachment) + collected.append(resolve_attachment_with_type(value, mimetype)) return collected @@ -508,6 +538,17 @@ def prompt( to_save["fragments"] = list(fragments) if system_fragments: to_save["system_fragments"] = list(system_fragments) + if attachments: + # Only works for attachments with a path or url + to_save["attachments"] = [ + (a.path or a.url) for a in attachments if (a.path or a.url) + ] + if attachment_types: + to_save["attachment_types"] = [ + {"type": a.type, "value": a.path or a.url} + for a in attachment_types + if (a.path or a.url) + ] if options: # Need to validate and convert their types first model = get_model(model_id or get_default_model()) @@ -568,7 +609,16 @@ def prompt( raise click.ClickException(str(ex)) if model_id is None and template_obj.model: model_id = template_obj.model - + # Merge in any attachments + if template_obj.attachments: + attachments = [ + resolve_attachment(a) for a in template_obj.attachments + ] + list(attachments) + if template_obj.attachment_types: + attachment_types = [ + resolve_attachment_with_type(at.value, at.type) + for at in template_obj.attachment_types + ] + list(attachment_types) if extract or extract_last: no_stream = True diff --git a/llm/templates.py b/llm/templates.py index 0544408..9291847 100644 --- a/llm/templates.py +++ b/llm/templates.py @@ -3,10 +3,17 @@ import string from typing import Optional, Any, Dict, List, Tuple +class AttachmentType(BaseModel): + type: str + value: str + + class Template(BaseModel): name: str prompt: Optional[str] = None system: Optional[str] = None + attachments: Optional[List[str]] = None + attachment_types: Optional[List[AttachmentType]] = None model: Optional[str] = None defaults: Optional[Dict[str, Any]] = None options: Optional[Dict[str, Any]] = None diff --git a/tests/test_templates.py b/tests/test_templates.py index f16c102..76972db 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -4,6 +4,7 @@ from llm import Template from llm.cli import cli import os from unittest import mock +import pathlib import pytest import yaml @@ -78,7 +79,7 @@ def test_templates_list(templates_path, args): @pytest.mark.parametrize( - "args,expected_prompt,expected_error", + "args,expected,expected_error", ( (["-m", "gpt4", "hello"], {"model": "gpt-4", "prompt": "hello"}, None), (["hello $foo"], {"prompt": "hello $foo"}, None), @@ -126,18 +127,36 @@ def test_templates_list(templates_path, args): }, None, ), + # And attachments and attachment_types + ( + ["--attachment", "a.txt", "--attachment-type", "b.txt", "text/plain"], + { + "attachments": ["a.txt"], + "attachment_types": [{"type": "text/plain", "value": "b.txt"}], + }, + None, + ), ), ) -def test_templates_prompt_save(templates_path, args, expected_prompt, expected_error): +def test_templates_prompt_save(templates_path, args, expected, expected_error): assert not (templates_path / "saved.yaml").exists() runner = CliRunner() - result = runner.invoke(cli, args + ["--save", "saved"], catch_exceptions=False) + with runner.isolated_filesystem(): + # Create a file to test attachment + pathlib.Path("a.txt").write_text("attachment", "utf-8") + pathlib.Path("b.txt").write_text("attachment type", "utf-8") + result = runner.invoke(cli, args + ["--save", "saved"], catch_exceptions=False) if not expected_error: assert result.exit_code == 0 - assert ( - yaml.safe_load((templates_path / "saved.yaml").read_text("utf-8")) - == expected_prompt - ) + yaml_data = yaml.safe_load((templates_path / "saved.yaml").read_text("utf-8")) + # Adjust attachment and attachment_types paths to be just the filename + if "attachments" in yaml_data: + yaml_data["attachments"] = [ + os.path.basename(path) for path in yaml_data["attachments"] + ] + for item in yaml_data.get("attachment_types", []): + item["value"] = os.path.basename(item["value"]) + assert yaml_data == expected else: assert result.exit_code == 1 assert expected_error in result.output