From c0fe719df65232665b071abd29f2df26e99c8b9d Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 26 Oct 2024 19:48:57 -0700 Subject: [PATCH] Store prompt attachments in attachments and prompt_attachments tables Refs https://github.com/simonw/llm/issues/587#issuecomment-2439791231 --- llm/cli.py | 2 ++ llm/migrations.py | 26 ++++++++++++++++++++++++++ llm/models.py | 35 ++++++++++++++++++++++++++++++++++- 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/llm/cli.py b/llm/cli.py index 33e14f0..454e56c 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -76,6 +76,7 @@ class AttachmentType(click.ParamType): path = pathlib.Path(value) if not path.exists(): self.fail(f"File {value} does not exist", param, ctx) + path = path.resolve() # Try to guess type mimetype = puremagic.from_file(str(path), mime=True) return Attachment(mimetype, str(path), None, None) @@ -94,6 +95,7 @@ def attachment_types_callback(ctx, param, values): path = pathlib.Path(value) if not path.exists(): raise click.BadParameter(f"File {value} does not exist") + path = path.resolve() attachment = Attachment(mimetype, str(path), None, None) collected.append(attachment) return collected diff --git a/llm/migrations.py b/llm/migrations.py index 008ae97..91da642 100644 --- a/llm/migrations.py +++ b/llm/migrations.py @@ -201,3 +201,29 @@ def m010_create_new_log_tables(db): @migration def m011_fts_for_responses(db): db["responses"].enable_fts(["prompt", "response"], create_triggers=True) + + +@migration +def m012_attachments_tables(db): + db["attachments"].create( + { + "id": str, + "type": str, + "path": str, + "url": str, + "content": bytes, + }, + pk="id", + ) + db["prompt_attachments"].create( + { + "response_id": str, + "attachment_id": str, + "order": int, + }, + foreign_keys=( + ("response_id", "responses", "id"), + ("attachment_id", "attachments", "id"), + ), + pk=("response_id", "attachment_id"), + ) diff --git a/llm/models.py b/llm/models.py index 77bdb8e..3b7d4da 100644 --- a/llm/models.py +++ b/llm/models.py @@ -2,6 +2,7 @@ import base64 from dataclasses import dataclass, field import datetime from .errors import NeedsKeyException +import hashlib import httpx from itertools import islice import puremagic @@ -23,6 +24,17 @@ class Attachment: url: Optional[str] = None content: Optional[bytes] = None + def hash_id(self): + # Hash of the binary content, or of '{"url": "https://..."}' for URL attachments + if self.content: + return hashlib.sha256(self.content).hexdigest() + elif self.path: + return hashlib.sha256(open(self.path, "rb").read()).hexdigest() + else: + return hashlib.sha256( + json.dumps({"url": self.url}).encode("utf-8") + ).hexdigest() + def resolve_type(self): if self.type: return self.type @@ -178,8 +190,9 @@ class Response(ABC): }, ignore=True, ) + response_id = str(ULID()).lower() response = { - "id": str(ULID()).lower(), + "id": response_id, "model": self.model.model_id, "prompt": self.prompt.prompt, "system": self.prompt.system, @@ -196,6 +209,26 @@ class Response(ABC): "datetime_utc": self.datetime_utc(), } db["responses"].insert(response) + # Persist any attachments - loop through with index + for index, attachment in enumerate(self.prompt.attachments): + attachment_id = attachment.hash_id() + db["attachments"].insert( + { + "id": attachment_id, + "type": attachment.resolve_type(), + "path": attachment.path, + "url": attachment.url, + "content": attachment.content, + }, + replace=True, + ) + db["prompt_attachments"].insert( + { + "response_id": response_id, + "attachment_id": attachment_id, + "order": index, + }, + ) @classmethod def fake(