Store prompt attachments in attachments and prompt_attachments tables

Refs https://github.com/simonw/llm/issues/587#issuecomment-2439791231
This commit is contained in:
Simon Willison 2024-10-26 19:48:57 -07:00
parent 6df00f92ff
commit c0fe719df6
3 changed files with 62 additions and 1 deletions

View file

@ -76,6 +76,7 @@ class AttachmentType(click.ParamType):
path = pathlib.Path(value)
if not path.exists():
self.fail(f"File {value} does not exist", param, ctx)
path = path.resolve()
# Try to guess type
mimetype = puremagic.from_file(str(path), mime=True)
return Attachment(mimetype, str(path), None, None)
@ -94,6 +95,7 @@ def attachment_types_callback(ctx, param, values):
path = pathlib.Path(value)
if not path.exists():
raise click.BadParameter(f"File {value} does not exist")
path = path.resolve()
attachment = Attachment(mimetype, str(path), None, None)
collected.append(attachment)
return collected

View file

@ -201,3 +201,29 @@ def m010_create_new_log_tables(db):
@migration
def m011_fts_for_responses(db):
    """Enable full-text search over the responses table.

    Indexes the ``prompt`` and ``response`` columns and asks for triggers
    to be created alongside the index (``create_triggers=True``).
    """
    searchable_columns = ["prompt", "response"]
    responses_table = db["responses"]
    responses_table.enable_fts(searchable_columns, create_triggers=True)
@migration
def m012_attachments_tables(db):
    """Create the ``attachments`` table and the ``prompt_attachments`` link table."""
    # One row per attachment, keyed by a string id.
    attachment_columns = {
        "id": str,
        "type": str,
        "path": str,
        "url": str,
        "content": bytes,
    }
    db["attachments"].create(attachment_columns, pk="id")

    # Link rows associating a response with an attachment; "order" is an
    # integer column stored with each link.
    link_columns = {
        "response_id": str,
        "attachment_id": str,
        "order": int,
    }
    link_foreign_keys = (
        ("response_id", "responses", "id"),
        ("attachment_id", "attachments", "id"),
    )
    # NOTE(review): with this composite primary key a given attachment can be
    # linked to the same response only once.
    db["prompt_attachments"].create(
        link_columns,
        foreign_keys=link_foreign_keys,
        pk=("response_id", "attachment_id"),
    )

View file

@ -2,6 +2,7 @@ import base64
from dataclasses import dataclass, field
import datetime
from .errors import NeedsKeyException
import hashlib
import httpx
from itertools import islice
import puremagic
@ -23,6 +24,17 @@ class Attachment:
url: Optional[str] = None
content: Optional[bytes] = None
def hash_id(self):
    """Return a SHA-256 hex digest identifying this attachment.

    The digest is computed from, in order of preference:

    1. ``self.content`` when it is truthy (non-empty bytes);
    2. the bytes of the file at ``self.path`` when a path is set;
    3. the JSON text ``{"url": "..."}`` built from ``self.url`` otherwise.

    Raises:
        OSError: if ``self.path`` is used and the file cannot be read.
    """
    if self.content:
        return hashlib.sha256(self.content).hexdigest()
    if self.path:
        # Use a context manager so the file handle is closed promptly
        # instead of being left for the garbage collector.
        with open(self.path, "rb") as fp:
            return hashlib.sha256(fp.read()).hexdigest()
    # URL attachments have no local bytes; hash a stable JSON encoding
    # of the URL instead.
    url_payload = json.dumps({"url": self.url}).encode("utf-8")
    return hashlib.sha256(url_payload).hexdigest()
def resolve_type(self):
if self.type:
return self.type
@ -178,8 +190,9 @@ class Response(ABC):
},
ignore=True,
)
response_id = str(ULID()).lower()
response = {
"id": str(ULID()).lower(),
"id": response_id,
"model": self.model.model_id,
"prompt": self.prompt.prompt,
"system": self.prompt.system,
@ -196,6 +209,26 @@ class Response(ABC):
"datetime_utc": self.datetime_utc(),
}
db["responses"].insert(response)
# Persist any attachments - loop through with index
for index, attachment in enumerate(self.prompt.attachments):
attachment_id = attachment.hash_id()
db["attachments"].insert(
{
"id": attachment_id,
"type": attachment.resolve_type(),
"path": attachment.path,
"url": attachment.url,
"content": attachment.content,
},
replace=True,
)
db["prompt_attachments"].insert(
{
"response_id": response_id,
"attachment_id": attachment_id,
"order": index,
},
)
@classmethod
def fake(