Special-case audio/wave, treating it as audio/wav, closes #603

This commit is contained in:
Simon Willison 2024-11-07 17:13:54 -08:00
parent febbc04fb6
commit 5d1d723d4b
5 changed files with 52 additions and 19 deletions

View file

@ -30,10 +30,10 @@ from llm import (
from .migrations import migrate
from .plugins import pm
from .utils import mimetype_from_path, mimetype_from_string
import base64
import httpx
import pathlib
import puremagic
import pydantic
import readline
from runpy import run_module
@ -58,9 +58,8 @@ class AttachmentType(click.ParamType):
if value == "-":
content = sys.stdin.buffer.read()
# Try to guess type
try:
mimetype = puremagic.from_string(content, mime=True)
except puremagic.PureError:
mimetype = mimetype_from_string(content)
if mimetype is None:
raise click.BadParameter("Could not determine mimetype of stdin")
return Attachment(type=mimetype, path=None, url=None, content=content)
if "://" in value:
@ -78,7 +77,9 @@ class AttachmentType(click.ParamType):
self.fail(f"File {value} does not exist", param, ctx)
path = path.resolve()
# Try to guess type
mimetype = puremagic.from_file(str(path), mime=True)
mimetype = mimetype_from_path(str(path))
if mimetype is None:
raise click.BadParameter(f"Could not determine mimetype of {value}")
return Attachment(type=mimetype, path=str(path), url=None, content=None)

View file

@ -5,10 +5,10 @@ from .errors import NeedsKeyException
import hashlib
import httpx
from itertools import islice
import puremagic
import re
import time
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
from .utils import mimetype_from_path, mimetype_from_string
from abc import ABC, abstractmethod
import json
from pydantic import BaseModel
@ -43,13 +43,13 @@ class Attachment:
return self.type
# Derive it from path or url or content
if self.path:
return puremagic.from_file(self.path, mime=True)
return mimetype_from_path(self.path)
if self.url:
response = httpx.head(self.url)
response.raise_for_status()
return response.headers.get("content-type")
if self.content:
return puremagic.from_string(self.content, mime=True)
return mimetype_from_string(self.content)
raise ValueError("Attachment has no type and no content to derive it from")
def content_bytes(self):

View file

@ -1,8 +1,29 @@
import click
import httpx
import json
import puremagic
import textwrap
from typing import List, Dict
from typing import List, Dict, Optional
# Maps a MIME type as reported by puremagic to the spelling this package
# uses instead; types that are not listed here are passed through unchanged
# by the mimetype_from_* helpers.
MIME_TYPE_FIXES = {"audio/wave": "audio/wav"}
def mimetype_from_string(content) -> Optional[str]:
    """Guess the MIME type of in-memory content using puremagic.

    Args:
        content: The data to inspect; passed straight to
            ``puremagic.from_string``.

    Returns:
        The detected MIME type, with any replacement listed in
        ``MIME_TYPE_FIXES`` applied, or ``None`` when no type could be
        determined.
    """
    try:
        type_ = puremagic.from_string(content, mime=True)
    except (puremagic.PureError, ValueError):
        # PureError: no known signature matched.
        # ValueError: puremagic rejects empty input. Callers test for None,
        # so report both cases as "unknown" instead of letting one escape.
        return None
    # Normalise aliases (e.g. audio/wave) to the preferred spelling.
    return MIME_TYPE_FIXES.get(type_, type_)
def mimetype_from_path(path) -> Optional[str]:
    """Guess the MIME type of a file on disk using puremagic.

    Args:
        path: Location of the file to inspect; passed straight to
            ``puremagic.from_file``.

    Returns:
        The detected MIME type, with any replacement listed in
        ``MIME_TYPE_FIXES`` applied, or ``None`` when no type could be
        determined.
    """
    try:
        type_ = puremagic.from_file(path, mime=True)
    except (puremagic.PureError, ValueError):
        # PureError: no known signature matched.
        # ValueError: puremagic rejects an empty file. Callers test for None,
        # so report both cases as "unknown" instead of letting one escape.
        return None
    # Normalise aliases (e.g. audio/wave) to the preferred spelling.
    return MIME_TYPE_FIXES.get(type_, type_)
def dicts_to_table_string(

View file

@ -49,7 +49,7 @@ def env_setup(monkeypatch, user_path):
class MockModel(llm.Model):
model_id = "mock"
attachment_types = {"image/png"}
attachment_types = {"image/png", "audio/wav"}
class Options(llm.Options):
max_tokens: Optional[int] = Field(

View file

@ -1,6 +1,8 @@
from click.testing import CliRunner
from unittest.mock import ANY
import llm
from llm import cli
import pytest
TINY_PNG = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xa6\x00\x00\x01\x1a"
@ -12,20 +14,29 @@ TINY_PNG = (
b"\x82"
)
# Short byte string that starts like a RIFF/WAVE file; used as the WAV
# attachment fixture in the parametrized attachment test.
TINY_WAV = (
    b"RIFF$\x00\x00\x00"  # "RIFF" tag followed by four more bytes
    b"WAVEfmt "  # "WAVE" tag and the "fmt " tag
    b"\x10\x00\x00\x00"
    b"\x01\x00\x01\x00"
    b"D\xac\x00\x00"
)
def test_prompt_image(mock_model, logs_db):
@pytest.mark.parametrize(
"attachment_type,attachment_content",
[
("image/png", TINY_PNG),
("audio/wav", TINY_WAV),
],
)
def test_prompt_attachment(mock_model, logs_db, attachment_type, attachment_content):
runner = CliRunner()
mock_model.enqueue(["two boxes"])
result = runner.invoke(
llm.cli.cli,
["prompt", "-m", "mock", "describe image", "-a", "-"],
input=TINY_PNG,
cli.cli,
["prompt", "-m", "mock", "describe file", "-a", "-"],
input=attachment_content,
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.exit_code == 0, result.output
assert result.output == "two boxes\n"
assert mock_model.history[0][0].attachments[0] == llm.Attachment(
type="image/png", path=None, url=None, content=TINY_PNG, _id=ANY
type=attachment_type, path=None, url=None, content=attachment_content, _id=ANY
)
# Check it was logged correctly
@ -33,15 +44,15 @@ def test_prompt_image(mock_model, logs_db):
assert len(conversations) == 1
conversation = conversations[0]
assert conversation["model"] == "mock"
assert conversation["name"] == "describe image"
assert conversation["name"] == "describe file"
response = list(logs_db["responses"].rows)[0]
attachment = list(logs_db["attachments"].rows)[0]
assert attachment == {
"id": ANY,
"type": "image/png",
"type": attachment_type,
"path": None,
"url": None,
"content": TINY_PNG,
"content": attachment_content,
}
prompt_attachment = list(logs_db["prompt_attachments"].rows)[0]
assert prompt_attachment["attachment_id"] == attachment["id"]