First working prototype of new attachments feature, refs #587

This commit is contained in:
Simon Willison 2024-10-26 17:40:23 -07:00
parent a466ddf3cd
commit 6df00f92ff
5 changed files with 193 additions and 14 deletions

View file

@ -4,6 +4,7 @@ from .errors import (
NeedsKeyException,
)
from .models import (
Attachment,
Conversation,
Model,
ModelWithAliases,

View file

@ -4,6 +4,7 @@ from dataclasses import asdict
import io
import json
from llm import (
Attachment,
Collection,
Conversation,
Response,
@ -30,7 +31,9 @@ from llm import (
from .migrations import migrate
from .plugins import pm
import base64
import httpx
import pathlib
import puremagic
import pydantic
import readline
from runpy import run_module
@ -48,6 +51,54 @@ warnings.simplefilter("ignore", ResourceWarning)
DEFAULT_TEMPLATE = "prompt: "
class AttachmentType(click.ParamType):
    """Click parameter type that converts an attachment argument to an Attachment.

    Used as the ``type=`` of the ``-a/--attachment`` option. The raw value is
    interpreted as one of:

    * ``-``: the attachment bytes are read from standard input
    * anything containing ``://``: treated as a URL and checked with a HEAD request
    * anything else: treated as a path to a local file
    """

    name = "attachment"

    def convert(self, value, param, ctx):
        """Build an Attachment from ``value``, guessing its mimetype.

        Raises click.BadParameter (or calls ``self.fail``) when the file is
        missing, the URL request fails, or the mimetype cannot be determined.
        """
        if value == "-":
            content = sys.stdin.buffer.read()
            # Try to guess type
            try:
                mimetype = puremagic.from_string(content, mime=True)
            except (puremagic.PureError, ValueError):
                # NOTE(review): ValueError is caught as well because puremagic
                # is understood to raise it for empty input - confirm against
                # the installed version.
                raise click.BadParameter("Could not determine mimetype of stdin")
            return Attachment(mimetype, None, None, content)

        if "://" in value:
            # Confirm URL exists and try to guess type
            try:
                response = httpx.head(value)
                response.raise_for_status()
                mimetype = response.headers.get("content-type")
            except httpx.HTTPError as ex:
                raise click.BadParameter(str(ex))
            if mimetype:
                # A Content-Type header may carry parameters such as
                # "; charset=binary"; keep only the media type itself so it
                # can be compared against plain mimetypes.
                mimetype = mimetype.split(";", 1)[0].strip()
            return Attachment(mimetype, None, value, None)

        # Check that the file exists
        path = pathlib.Path(value)
        if not path.exists():
            self.fail(f"File {value} does not exist", param, ctx)
        # Try to guess type; report a failure the same way the stdin branch
        # does instead of letting the puremagic exception escape.
        try:
            mimetype = puremagic.from_file(str(path), mime=True)
        except (puremagic.PureError, ValueError):
            raise click.BadParameter(f"Could not determine mimetype of {value}")
        return Attachment(mimetype, str(path), None, None)
def attachment_types_callback(ctx, param, values):
    """Click callback for the ``--at/--attachment-type`` option.

    ``values`` is an iterable of ``(value, mimetype)`` pairs. Each pair becomes
    an Attachment carrying the explicitly supplied mimetype: a value containing
    ``://`` is stored as a URL, ``-`` reads the content from standard input,
    and anything else must be the path of an existing file.

    Returns the list of Attachment objects, in the order given.
    Raises click.BadParameter if a file path does not exist.
    """

    def build(source, mimetype):
        # The URL check comes first, then stdin, then the local file lookup.
        if "://" in source:
            return Attachment(mimetype, None, source, None)
        if source == "-":
            data = sys.stdin.buffer.read()
            return Attachment(mimetype, None, None, data)
        # Look for file
        location = pathlib.Path(source)
        if not location.exists():
            raise click.BadParameter(f"File {source} does not exist")
        return Attachment(mimetype, str(location), None, None)

    return [build(source, mimetype) for source, mimetype in values]
def _validate_metadata_json(ctx, param, value):
if value is None:
return value
@ -88,6 +139,23 @@ def cli():
@click.argument("prompt", required=False)
@click.option("-s", "--system", help="System prompt to use")
@click.option("model_id", "-m", "--model", help="Model to use")
@click.option(
"attachments",
"-a",
"--attachment",
type=AttachmentType(),
multiple=True,
help="Attachment path or URL or -",
)
@click.option(
"attachment_types",
"--at",
"--attachment-type",
type=(str, str),
multiple=True,
callback=attachment_types_callback,
help="Attachment with explicit mimetype",
)
@click.option(
"options",
"-o",
@ -127,6 +195,8 @@ def prompt(
prompt,
system,
model_id,
attachments,
attachment_types,
options,
template,
param,
@ -142,6 +212,14 @@ def prompt(
Execute a prompt
Documentation: https://llm.datasette.io/en/stable/usage.html
Examples:
\b
llm 'Capital of France?'
llm 'Capital of France?' -m gpt-4o
llm 'Capital of France?' -s 'answer in Spanish'
llm 'Extract text from this image' -a image.jpg
"""
if log and no_log:
raise click.ClickException("--log and --no-log are mutually exclusive")
@ -262,6 +340,8 @@ def prompt(
except pydantic.ValidationError as ex:
raise click.ClickException(render_errors(ex.errors()))
resolved_attachments = [*attachments, *attachment_types]
should_stream = model.can_stream and not no_stream
if not should_stream:
validated_options["stream"] = False
@ -273,7 +353,9 @@ def prompt(
prompt_method = conversation.prompt
try:
response = prompt_method(prompt, system, **validated_options)
response = prompt_method(
prompt, *resolved_attachments, system=system, **validated_options
)
if should_stream:
for chunk in response:
print(chunk, end="")

View file

@ -33,8 +33,8 @@ def register_models(register):
register(Chat("gpt-4-turbo-2024-04-09"))
register(Chat("gpt-4-turbo"), aliases=("gpt-4-turbo-preview", "4-turbo", "4t"))
# GPT-4o
register(Chat("gpt-4o"), aliases=("4o",))
register(Chat("gpt-4o-mini"), aliases=("4o-mini",))
register(Chat("gpt-4o", vision=True), aliases=("4o",))
register(Chat("gpt-4o-mini", vision=True), aliases=("4o-mini",))
# o1
register(Chat("o1-preview", can_stream=False, allows_system_prompt=False))
register(Chat("o1-mini", can_stream=False, allows_system_prompt=False))
@ -271,6 +271,7 @@ class Chat(Model):
api_engine=None,
headers=None,
can_stream=True,
vision=False,
allows_system_prompt=True,
):
self.model_id = model_id
@ -282,8 +283,17 @@ class Chat(Model):
self.api_engine = api_engine
self.headers = headers
self.can_stream = can_stream
self.vision = vision
self.allows_system_prompt = allows_system_prompt
if vision:
self.attachment_types = {
"image/png",
"image/jpeg",
"image/webp",
"image/gif",
}
def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)
@ -308,7 +318,18 @@ class Chat(Model):
messages.append({"role": "assistant", "content": prev_response.text()})
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
if not prompt.attachments:
messages.append({"role": "user", "content": prompt.prompt})
else:
vision_message = [{"type": "text", "text": prompt.prompt}]
for attachment in prompt.attachments:
url = attachment.url
if not url:
base64_image = attachment.base64_content()
url = f"data:{attachment.resolve_type()};base64,{base64_image}"
vision_message.append({"type": "image_url", "image_url": {"url": url}})
messages.append({"role": "user", "content": vision_message})
response._prompt_json = {"messages": messages}
kwargs = self.build_kwargs(prompt)
client = self.get_client()

View file

@ -1,7 +1,10 @@
import base64
from dataclasses import dataclass, field
import datetime
from .errors import NeedsKeyException
import httpx
from itertools import islice
import puremagic
import re
import time
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
@ -13,17 +16,52 @@ from ulid import ULID
CONVERSATION_NAME_LENGTH = 32
@dataclass
class Attachment:
    """A piece of non-text input (for example an image) attached to a prompt.

    The data may be supplied as a local file ``path``, a ``url`` or raw
    ``content`` bytes; ``type`` is the mimetype, if already known.
    """

    type: Optional[str] = None
    path: Optional[str] = None
    url: Optional[str] = None
    content: Optional[bytes] = None

    def resolve_type(self):
        """Return the mimetype of this attachment.

        Uses ``type`` when set; otherwise derives it from ``path``, then
        ``url``, then ``content``.

        Raises ValueError if there is nothing to derive the type from.
        """
        if self.type:
            return self.type
        # Derive it from path or url or content
        if self.path:
            return puremagic.from_file(self.path, mime=True)
        if self.url:
            # puremagic has no URL helper, so ask the server for the
            # Content-Type with a HEAD request instead.
            response = httpx.head(self.url)
            response.raise_for_status()
            content_type = response.headers.get("content-type")
            if content_type:
                # Drop any parameters such as "; charset=..."
                return content_type.split(";", 1)[0].strip()
        if self.content:
            return puremagic.from_string(self.content, mime=True)
        raise ValueError("Attachment has no type and no content to derive it from")

    def base64_content(self):
        """Return the attachment data as a base64-encoded ``str``.

        Uses ``content`` when present; otherwise reads the file at ``path``,
        or downloads ``url``.

        Raises ValueError if none of content, path or url is available.
        """
        content = self.content
        if not content:
            if self.path:
                # Context manager so the file handle is closed promptly
                with open(self.path, "rb") as fp:
                    content = fp.read()
            elif self.url:
                response = httpx.get(self.url)
                response.raise_for_status()
                content = response.content
            elif content is None:
                # Explicit error instead of b64encode(None) raising TypeError;
                # empty bytes (b"") still encode to an empty string.
                raise ValueError("Attachment has no content, path or url to read")
        return base64.b64encode(content).decode("utf-8")
@dataclass
class Prompt:
prompt: str
model: "Model"
system: Optional[str]
prompt_json: Optional[str]
options: "Options"
attachments: Optional[List[Attachment]] = field(default_factory=list)
system: Optional[str] = None
prompt_json: Optional[str] = None
options: "Options" = field(default_factory=dict)
def __init__(self, prompt, model, system=None, prompt_json=None, options=None):
def __init__(
self, prompt, model, attachments, system=None, prompt_json=None, options=None
):
self.prompt = prompt
self.model = model
self.attachments = list(attachments)
self.system = system
self.prompt_json = prompt_json
self.options = options or {}
@ -39,6 +77,7 @@ class Conversation:
def prompt(
self,
prompt: Optional[str],
*attachments: Attachment,
system: Optional[str] = None,
stream: bool = True,
**options
@ -46,8 +85,9 @@ class Conversation:
return Response(
Prompt(
prompt,
system=system,
model=self.model,
attachments=attachments,
system=system,
options=self.model.Options(**options),
),
self.model,
@ -158,14 +198,22 @@ class Response(ABC):
db["responses"].insert(response)
@classmethod
def fake(cls, model: "Model", prompt: str, system: str, response: str):
def fake(
cls,
model: "Model",
prompt: str,
*attachments: List[Attachment],
system: str,
response: str
):
"Utility method to help with writing tests"
response_obj = cls(
model=model,
prompt=Prompt(
prompt,
system=system,
model=model,
attachments=attachments,
system=system,
),
stream=False,
)
@ -183,8 +231,9 @@ class Response(ABC):
model=model,
prompt=Prompt(
prompt=row["prompt"],
system=row["system"],
model=model,
attachments=[],
system=row["system"],
options=model.Options(**json.loads(row["options_json"])),
),
stream=False,
@ -242,10 +291,15 @@ class _get_key_mixin:
class Model(ABC, _get_key_mixin):
model_id: str
# API key handling
key: Optional[str] = None
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
# Model characteristics
can_stream: bool = False
attachment_types = set()
class Options(_Options):
pass
@ -269,13 +323,33 @@ class Model(ABC, _get_key_mixin):
def prompt(
self,
prompt: Optional[str],
prompt: str,
*attachments: Attachment,
system: Optional[str] = None,
stream: bool = True,
**options
):
# Validate attachments
if attachments and not self.attachment_types:
raise ValueError(
"This model does not support attachments, but some were provided"
)
for attachment in attachments:
attachment_type = attachment.resolve_type()
if attachment_type not in self.attachment_types:
raise ValueError(
"This model does not support attachments of type '{}', only {}".format(
attachment_type, ", ".join(self.attachment_types)
)
)
return self.response(
Prompt(prompt, system=system, model=self, options=self.Options(**options)),
Prompt(
prompt,
attachments=attachments,
system=system,
model=self,
options=self.Options(**options),
),
stream=stream,
)

View file

@ -48,6 +48,7 @@ setup(
"setuptools",
"pip",
"pyreadline3; sys_platform == 'win32'",
"puremagic",
],
extras_require={
"test": [