Got llm --continue to work with images, refs #587

This commit is contained in:
Simon Willison 2024-10-28 12:16:54 -07:00
parent c0fe719df6
commit dff5b456fd
3 changed files with 62 additions and 20 deletions

View file

@ -62,7 +62,7 @@ class AttachmentType(click.ParamType):
mimetype = puremagic.from_string(content, mime=True)
except puremagic.PureError:
raise click.BadParameter("Could not determine mimetype of stdin")
return Attachment(mimetype, None, None, content)
return Attachment(type=mimetype, path=None, url=None, content=content)
if "://" in value:
# Confirm URL exists and try to guess type
try:
@ -79,7 +79,7 @@ class AttachmentType(click.ParamType):
path = path.resolve()
# Try to guess type
mimetype = puremagic.from_file(str(path), mime=True)
return Attachment(mimetype, str(path), None, None)
return Attachment(type=mimetype, path=str(path), url=None, content=None)
def attachment_types_callback(ctx, param, values):
@ -552,7 +552,7 @@ def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]:
for response in db["responses"].rows_where(
"conversation_id = ?", [conversation_id]
):
conversation.responses.append(Response.from_row(response))
conversation.responses.append(Response.from_row(db, response))
return conversation

View file

@ -312,23 +312,39 @@ class Chat(Model):
{"role": "system", "content": prev_response.prompt.system}
)
current_system = prev_response.prompt.system
messages.append(
{"role": "user", "content": prev_response.prompt.prompt}
)
if prev_response.attachments:
attachment_message = [
{"type": "text", "text": prev_response.prompt.prompt}
]
for attachment in prev_response.attachments:
url = attachment.url
if not url:
base64_image = attachment.base64_content()
url = f"data:{attachment.resolve_type()};base64,{base64_image}"
attachment_message.append(
{"type": "image_url", "image_url": {"url": url}}
)
messages.append({"role": "user", "content": attachment_message})
else:
messages.append(
{"role": "user", "content": prev_response.prompt.prompt}
)
messages.append({"role": "assistant", "content": prev_response.text()})
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
if not prompt.attachments:
messages.append({"role": "user", "content": prompt.prompt})
else:
vision_message = [{"type": "text", "text": prompt.prompt}]
attachment_message = [{"type": "text", "text": prompt.prompt}]
for attachment in prompt.attachments:
url = attachment.url
if not url:
base64_image = attachment.base64_content()
url = f"data:{attachment.resolve_type()};base64,{base64_image}"
vision_message.append({"type": "image_url", "image_url": {"url": url}})
messages.append({"role": "user", "content": vision_message})
attachment_message.append(
{"type": "image_url", "image_url": {"url": url}}
)
messages.append({"role": "user", "content": attachment_message})
response._prompt_json = {"messages": messages}
kwargs = self.build_kwargs(prompt)

View file

@ -23,17 +23,20 @@ class Attachment:
path: Optional[str] = None
url: Optional[str] = None
content: Optional[bytes] = None
_id: Optional[str] = None
def hash_id(self):
def id(self):
# Hash of the binary content, or of '{"url": "https://..."}' for URL attachments
if self.content:
return hashlib.sha256(self.content).hexdigest()
elif self.path:
return hashlib.sha256(open(self.path, "rb").read()).hexdigest()
else:
return hashlib.sha256(
json.dumps({"url": self.url}).encode("utf-8")
).hexdigest()
if self._id is None:
if self.content:
self._id = hashlib.sha256(self.content).hexdigest()
elif self.path:
self._id = hashlib.sha256(open(self.path, "rb").read()).hexdigest()
else:
self._id = hashlib.sha256(
json.dumps({"url": self.url}).encode("utf-8")
).hexdigest()
return self._id
def resolve_type(self):
if self.type:
@ -58,6 +61,16 @@ class Attachment:
content = response.content
return base64.b64encode(content).decode("utf-8")
@classmethod
def from_row(cls, row):
return cls(
_id=row["id"],
type=row["type"],
path=row["path"],
url=row["url"],
content=row["content"],
)
@dataclass
class Prompt:
@ -211,7 +224,7 @@ class Response(ABC):
db["responses"].insert(response)
# Persist any attachments - loop through with index
for index, attachment in enumerate(self.prompt.attachments):
attachment_id = attachment.hash_id()
attachment_id = attachment.id()
db["attachments"].insert(
{
"id": attachment_id,
@ -255,7 +268,7 @@ class Response(ABC):
return response_obj
@classmethod
def from_row(cls, row):
def from_row(cls, db, row):
from llm import get_model
model = get_model(row["model"])
@ -276,6 +289,19 @@ class Response(ABC):
response.response_json = json.loads(row["response_json"] or "null")
response._done = True
response._chunks = [row["response"]]
# Attachments
response.attachments = [
Attachment.from_row(arow)
for arow in db.query(
"""
select attachments.* from attachments
join prompt_attachments on attachments.id = prompt_attachments.attachment_id
where prompt_attachments.response_id = ?
order by prompt_attachments."order"
""",
[row["id"]],
)
]
return response
def __repr__(self):