diff --git a/llm/cli.py b/llm/cli.py index 454e56c..53087ad 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -62,7 +62,7 @@ class AttachmentType(click.ParamType): mimetype = puremagic.from_string(content, mime=True) except puremagic.PureError: raise click.BadParameter("Could not determine mimetype of stdin") - return Attachment(mimetype, None, None, content) + return Attachment(type=mimetype, path=None, url=None, content=content) if "://" in value: # Confirm URL exists and try to guess type try: @@ -79,7 +79,7 @@ class AttachmentType(click.ParamType): path = path.resolve() # Try to guess type mimetype = puremagic.from_file(str(path), mime=True) - return Attachment(mimetype, str(path), None, None) + return Attachment(type=mimetype, path=str(path), url=None, content=None) def attachment_types_callback(ctx, param, values): @@ -552,7 +552,7 @@ def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]: for response in db["responses"].rows_where( "conversation_id = ?", [conversation_id] ): - conversation.responses.append(Response.from_row(response)) + conversation.responses.append(Response.from_row(db, response)) return conversation diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 913e754..651d37f 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -312,23 +312,39 @@ class Chat(Model): {"role": "system", "content": prev_response.prompt.system} ) current_system = prev_response.prompt.system - messages.append( - {"role": "user", "content": prev_response.prompt.prompt} - ) + if prev_response.attachments: + attachment_message = [ + {"type": "text", "text": prev_response.prompt.prompt} + ] + for attachment in prev_response.attachments: + url = attachment.url + if not url: + base64_image = attachment.base64_content() + url = f"data:{attachment.resolve_type()};base64,{base64_image}" + attachment_message.append( + {"type": "image_url", "image_url": {"url": url}} + ) + messages.append({"role": "user", "content": attachment_message}) + else: + messages.append( + {"role": "user", "content": prev_response.prompt.prompt} + ) messages.append({"role": "assistant", "content": prev_response.text()}) if prompt.system and prompt.system != current_system: messages.append({"role": "system", "content": prompt.system}) if not prompt.attachments: messages.append({"role": "user", "content": prompt.prompt}) else: - vision_message = [{"type": "text", "text": prompt.prompt}] + attachment_message = [{"type": "text", "text": prompt.prompt}] for attachment in prompt.attachments: url = attachment.url if not url: base64_image = attachment.base64_content() url = f"data:{attachment.resolve_type()};base64,{base64_image}" - vision_message.append({"type": "image_url", "image_url": {"url": url}}) - messages.append({"role": "user", "content": vision_message}) + attachment_message.append( + {"type": "image_url", "image_url": {"url": url}} + ) + messages.append({"role": "user", "content": attachment_message}) response._prompt_json = {"messages": messages} kwargs = self.build_kwargs(prompt) diff --git a/llm/models.py b/llm/models.py index 3b7d4da..fc58a07 100644 --- a/llm/models.py +++ b/llm/models.py @@ -23,17 +23,20 @@ class Attachment: path: Optional[str] = None url: Optional[str] = None content: Optional[bytes] = None + _id: Optional[str] = None - def hash_id(self): + def id(self): # Hash of the binary content, or of '{"url": "https://..."}' for URL attachments - if self.content: - return hashlib.sha256(self.content).hexdigest() - elif self.path: - return hashlib.sha256(open(self.path, "rb").read()).hexdigest() - else: - return hashlib.sha256( - json.dumps({"url": self.url}).encode("utf-8") - ).hexdigest() + if self._id is None: + if self.content: + self._id = hashlib.sha256(self.content).hexdigest() + elif self.path: + self._id = hashlib.sha256(open(self.path, "rb").read()).hexdigest() + else: + self._id = hashlib.sha256( + json.dumps({"url": self.url}).encode("utf-8") + ).hexdigest() + return self._id def resolve_type(self): if self.type: @@ -58,6 +61,16 @@ class Attachment: content = response.content return base64.b64encode(content).decode("utf-8") + @classmethod + def from_row(cls, row): + return cls( + _id=row["id"], + type=row["type"], + path=row["path"], + url=row["url"], + content=row["content"], + ) + @dataclass class Prompt: @@ -211,7 +224,7 @@ class Response(ABC): db["responses"].insert(response) # Persist any attachments - loop through with index for index, attachment in enumerate(self.prompt.attachments): - attachment_id = attachment.hash_id() + attachment_id = attachment.id() db["attachments"].insert( { "id": attachment_id, @@ -255,7 +268,7 @@ class Response(ABC): return response_obj @classmethod - def from_row(cls, row): + def from_row(cls, db, row): from llm import get_model model = get_model(row["model"]) @@ -276,6 +289,19 @@ class Response(ABC): response.response_json = json.loads(row["response_json"] or "null") response._done = True response._chunks = [row["response"]] + # Attachments + response.attachments = [ + Attachment.from_row(arow) + for arow in db.query( + """ + select attachments.* from attachments + join prompt_attachments on attachments.id = prompt_attachments.attachment_id + where prompt_attachments.response_id = ? + order by prompt_attachments."order" + """, + [row["id"]], + ) + ] return response def __repr__(self):