mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-18 10:41:06 +00:00
Attachment and attachment type support for templates, closes #826
This commit is contained in:
parent
d0255a1eda
commit
388a4ea9a2
3 changed files with 121 additions and 45 deletions
126
llm/cli.py
126
llm/cli.py
|
|
@ -151,54 +151,84 @@ def resolve_fragments(
|
|||
return resolved
|
||||
|
||||
|
||||
class AttachmentError(Exception):
|
||||
"""Exception raised for errors in attachment resolution."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def resolve_attachment(value):
|
||||
"""
|
||||
Resolve an attachment from a string value which could be:
|
||||
- "-" for stdin
|
||||
- A URL
|
||||
- A file path
|
||||
|
||||
Returns an Attachment object.
|
||||
Raises AttachmentError if the attachment cannot be resolved.
|
||||
"""
|
||||
if value == "-":
|
||||
content = sys.stdin.buffer.read()
|
||||
# Try to guess type
|
||||
mimetype = mimetype_from_string(content)
|
||||
if mimetype is None:
|
||||
raise AttachmentError("Could not determine mimetype of stdin")
|
||||
return Attachment(type=mimetype, path=None, url=None, content=content)
|
||||
|
||||
if "://" in value:
|
||||
# Confirm URL exists and try to guess type
|
||||
try:
|
||||
response = httpx.head(value)
|
||||
response.raise_for_status()
|
||||
mimetype = response.headers.get("content-type")
|
||||
except httpx.HTTPError as ex:
|
||||
raise AttachmentError(str(ex))
|
||||
return Attachment(type=mimetype, path=None, url=value, content=None)
|
||||
|
||||
# Check that the file exists
|
||||
path = pathlib.Path(value)
|
||||
if not path.exists():
|
||||
raise AttachmentError(f"File {value} does not exist")
|
||||
path = path.resolve()
|
||||
|
||||
# Try to guess type
|
||||
mimetype = mimetype_from_path(str(path))
|
||||
if mimetype is None:
|
||||
raise AttachmentError(f"Could not determine mimetype of {value}")
|
||||
|
||||
return Attachment(type=mimetype, path=str(path), url=None, content=None)
|
||||
|
||||
|
||||
class AttachmentType(click.ParamType):
|
||||
name = "attachment"
|
||||
|
||||
def convert(self, value, param, ctx):
|
||||
if value == "-":
|
||||
content = sys.stdin.buffer.read()
|
||||
# Try to guess type
|
||||
mimetype = mimetype_from_string(content)
|
||||
if mimetype is None:
|
||||
raise click.BadParameter("Could not determine mimetype of stdin")
|
||||
return Attachment(type=mimetype, path=None, url=None, content=content)
|
||||
if "://" in value:
|
||||
# Confirm URL exists and try to guess type
|
||||
try:
|
||||
response = httpx.head(value)
|
||||
response.raise_for_status()
|
||||
mimetype = response.headers.get("content-type")
|
||||
except httpx.HTTPError as ex:
|
||||
raise click.BadParameter(str(ex))
|
||||
return Attachment(mimetype, None, value, None)
|
||||
# Check that the file exists
|
||||
try:
|
||||
return resolve_attachment(value)
|
||||
except AttachmentError as e:
|
||||
self.fail(str(e), param, ctx)
|
||||
|
||||
|
||||
def resolve_attachment_with_type(value: str, mimetype: str) -> Attachment:
|
||||
if "://" in value:
|
||||
attachment = Attachment(mimetype, None, value, None)
|
||||
elif value == "-":
|
||||
content = sys.stdin.buffer.read()
|
||||
attachment = Attachment(mimetype, None, None, content)
|
||||
else:
|
||||
# Look for file
|
||||
path = pathlib.Path(value)
|
||||
if not path.exists():
|
||||
self.fail(f"File {value} does not exist", param, ctx)
|
||||
raise click.BadParameter(f"File {value} does not exist")
|
||||
path = path.resolve()
|
||||
# Try to guess type
|
||||
mimetype = mimetype_from_path(str(path))
|
||||
if mimetype is None:
|
||||
raise click.BadParameter(f"Could not determine mimetype of {value}")
|
||||
return Attachment(type=mimetype, path=str(path), url=None, content=None)
|
||||
attachment = Attachment(mimetype, str(path), None, None)
|
||||
return attachment
|
||||
|
||||
|
||||
def attachment_types_callback(ctx, param, values):
|
||||
def attachment_types_callback(ctx, param, values) -> List[Attachment]:
|
||||
collected = []
|
||||
for value, mimetype in values:
|
||||
if "://" in value:
|
||||
attachment = Attachment(mimetype, None, value, None)
|
||||
elif value == "-":
|
||||
content = sys.stdin.buffer.read()
|
||||
attachment = Attachment(mimetype, None, None, content)
|
||||
else:
|
||||
# Look for file
|
||||
path = pathlib.Path(value)
|
||||
if not path.exists():
|
||||
raise click.BadParameter(f"File {value} does not exist")
|
||||
path = path.resolve()
|
||||
attachment = Attachment(mimetype, str(path), None, None)
|
||||
collected.append(attachment)
|
||||
collected.append(resolve_attachment_with_type(value, mimetype))
|
||||
return collected
|
||||
|
||||
|
||||
|
|
@ -508,6 +538,17 @@ def prompt(
|
|||
to_save["fragments"] = list(fragments)
|
||||
if system_fragments:
|
||||
to_save["system_fragments"] = list(system_fragments)
|
||||
if attachments:
|
||||
# Only works for attachments with a path or url
|
||||
to_save["attachments"] = [
|
||||
(a.path or a.url) for a in attachments if (a.path or a.url)
|
||||
]
|
||||
if attachment_types:
|
||||
to_save["attachment_types"] = [
|
||||
{"type": a.type, "value": a.path or a.url}
|
||||
for a in attachment_types
|
||||
if (a.path or a.url)
|
||||
]
|
||||
if options:
|
||||
# Need to validate and convert their types first
|
||||
model = get_model(model_id or get_default_model())
|
||||
|
|
@ -568,7 +609,16 @@ def prompt(
|
|||
raise click.ClickException(str(ex))
|
||||
if model_id is None and template_obj.model:
|
||||
model_id = template_obj.model
|
||||
|
||||
# Merge in any attachments
|
||||
if template_obj.attachments:
|
||||
attachments = [
|
||||
resolve_attachment(a) for a in template_obj.attachments
|
||||
] + list(attachments)
|
||||
if template_obj.attachment_types:
|
||||
attachment_types = [
|
||||
resolve_attachment_with_type(at.value, at.type)
|
||||
for at in template_obj.attachment_types
|
||||
] + list(attachment_types)
|
||||
if extract or extract_last:
|
||||
no_stream = True
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,17 @@ import string
|
|||
from typing import Optional, Any, Dict, List, Tuple
|
||||
|
||||
|
||||
class AttachmentType(BaseModel):
|
||||
type: str
|
||||
value: str
|
||||
|
||||
|
||||
class Template(BaseModel):
|
||||
name: str
|
||||
prompt: Optional[str] = None
|
||||
system: Optional[str] = None
|
||||
attachments: Optional[List[str]] = None
|
||||
attachment_types: Optional[List[AttachmentType]] = None
|
||||
model: Optional[str] = None
|
||||
defaults: Optional[Dict[str, Any]] = None
|
||||
options: Optional[Dict[str, Any]] = None
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from llm import Template
|
|||
from llm.cli import cli
|
||||
import os
|
||||
from unittest import mock
|
||||
import pathlib
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
|
@ -78,7 +79,7 @@ def test_templates_list(templates_path, args):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args,expected_prompt,expected_error",
|
||||
"args,expected,expected_error",
|
||||
(
|
||||
(["-m", "gpt4", "hello"], {"model": "gpt-4", "prompt": "hello"}, None),
|
||||
(["hello $foo"], {"prompt": "hello $foo"}, None),
|
||||
|
|
@ -126,18 +127,36 @@ def test_templates_list(templates_path, args):
|
|||
},
|
||||
None,
|
||||
),
|
||||
# And attachments and attachment_types
|
||||
(
|
||||
["--attachment", "a.txt", "--attachment-type", "b.txt", "text/plain"],
|
||||
{
|
||||
"attachments": ["a.txt"],
|
||||
"attachment_types": [{"type": "text/plain", "value": "b.txt"}],
|
||||
},
|
||||
None,
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_templates_prompt_save(templates_path, args, expected_prompt, expected_error):
|
||||
def test_templates_prompt_save(templates_path, args, expected, expected_error):
|
||||
assert not (templates_path / "saved.yaml").exists()
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args + ["--save", "saved"], catch_exceptions=False)
|
||||
with runner.isolated_filesystem():
|
||||
# Create a file to test attachment
|
||||
pathlib.Path("a.txt").write_text("attachment", "utf-8")
|
||||
pathlib.Path("b.txt").write_text("attachment type", "utf-8")
|
||||
result = runner.invoke(cli, args + ["--save", "saved"], catch_exceptions=False)
|
||||
if not expected_error:
|
||||
assert result.exit_code == 0
|
||||
assert (
|
||||
yaml.safe_load((templates_path / "saved.yaml").read_text("utf-8"))
|
||||
== expected_prompt
|
||||
)
|
||||
yaml_data = yaml.safe_load((templates_path / "saved.yaml").read_text("utf-8"))
|
||||
# Adjust attachment and attachment_types paths to be just the filename
|
||||
if "attachments" in yaml_data:
|
||||
yaml_data["attachments"] = [
|
||||
os.path.basename(path) for path in yaml_data["attachments"]
|
||||
]
|
||||
for item in yaml_data.get("attachment_types", []):
|
||||
item["value"] = os.path.basename(item["value"])
|
||||
assert yaml_data == expected
|
||||
else:
|
||||
assert result.exit_code == 1
|
||||
assert expected_error in result.output
|
||||
|
|
|
|||
Loading…
Reference in a new issue