mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
Special case treat audio/wave as audio/wav, closes #603
This commit is contained in:
parent
febbc04fb6
commit
5d1d723d4b
5 changed files with 52 additions and 19 deletions
11
llm/cli.py
11
llm/cli.py
|
|
@ -30,10 +30,10 @@ from llm import (
|
|||
|
||||
from .migrations import migrate
|
||||
from .plugins import pm
|
||||
from .utils import mimetype_from_path, mimetype_from_string
|
||||
import base64
|
||||
import httpx
|
||||
import pathlib
|
||||
import puremagic
|
||||
import pydantic
|
||||
import readline
|
||||
from runpy import run_module
|
||||
|
|
@ -58,9 +58,8 @@ class AttachmentType(click.ParamType):
|
|||
if value == "-":
|
||||
content = sys.stdin.buffer.read()
|
||||
# Try to guess type
|
||||
try:
|
||||
mimetype = puremagic.from_string(content, mime=True)
|
||||
except puremagic.PureError:
|
||||
mimetype = mimetype_from_string(content)
|
||||
if mimetype is None:
|
||||
raise click.BadParameter("Could not determine mimetype of stdin")
|
||||
return Attachment(type=mimetype, path=None, url=None, content=content)
|
||||
if "://" in value:
|
||||
|
|
@ -78,7 +77,9 @@ class AttachmentType(click.ParamType):
|
|||
self.fail(f"File {value} does not exist", param, ctx)
|
||||
path = path.resolve()
|
||||
# Try to guess type
|
||||
mimetype = puremagic.from_file(str(path), mime=True)
|
||||
mimetype = mimetype_from_path(str(path))
|
||||
if mimetype is None:
|
||||
raise click.BadParameter(f"Could not determine mimetype of {value}")
|
||||
return Attachment(type=mimetype, path=str(path), url=None, content=None)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ from .errors import NeedsKeyException
|
|||
import hashlib
|
||||
import httpx
|
||||
from itertools import islice
|
||||
import puremagic
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
|
||||
from .utils import mimetype_from_path, mimetype_from_string
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -43,13 +43,13 @@ class Attachment:
|
|||
return self.type
|
||||
# Derive it from path or url or content
|
||||
if self.path:
|
||||
return puremagic.from_file(self.path, mime=True)
|
||||
return mimetype_from_path(self.path)
|
||||
if self.url:
|
||||
response = httpx.head(self.url)
|
||||
response.raise_for_status()
|
||||
return response.headers.get("content-type")
|
||||
if self.content:
|
||||
return puremagic.from_string(self.content, mime=True)
|
||||
return mimetype_from_string(self.content)
|
||||
raise ValueError("Attachment has no type and no content to derive it from")
|
||||
|
||||
def content_bytes(self):
|
||||
|
|
|
|||
23
llm/utils.py
23
llm/utils.py
|
|
@ -1,8 +1,29 @@
|
|||
import click
|
||||
import httpx
|
||||
import json
|
||||
import puremagic
|
||||
import textwrap
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
MIME_TYPE_FIXES = {
|
||||
"audio/wave": "audio/wav",
|
||||
}
|
||||
|
||||
|
||||
def mimetype_from_string(content) -> Optional[str]:
|
||||
try:
|
||||
type_ = puremagic.from_string(content, mime=True)
|
||||
return MIME_TYPE_FIXES.get(type_, type_)
|
||||
except puremagic.PureError:
|
||||
return None
|
||||
|
||||
|
||||
def mimetype_from_path(path) -> Optional[str]:
|
||||
try:
|
||||
type_ = puremagic.from_file(path, mime=True)
|
||||
return MIME_TYPE_FIXES.get(type_, type_)
|
||||
except puremagic.PureError:
|
||||
return None
|
||||
|
||||
|
||||
def dicts_to_table_string(
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def env_setup(monkeypatch, user_path):
|
|||
|
||||
class MockModel(llm.Model):
|
||||
model_id = "mock"
|
||||
attachment_types = {"image/png"}
|
||||
attachment_types = {"image/png", "audio/wav"}
|
||||
|
||||
class Options(llm.Options):
|
||||
max_tokens: Optional[int] = Field(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from click.testing import CliRunner
|
||||
from unittest.mock import ANY
|
||||
import llm
|
||||
from llm import cli
|
||||
import pytest
|
||||
|
||||
TINY_PNG = (
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\xa6\x00\x00\x01\x1a"
|
||||
|
|
@ -12,20 +14,29 @@ TINY_PNG = (
|
|||
b"\x82"
|
||||
)
|
||||
|
||||
TINY_WAV = b"RIFF$\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00D\xac\x00\x00"
|
||||
|
||||
def test_prompt_image(mock_model, logs_db):
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attachment_type,attachment_content",
|
||||
[
|
||||
("image/png", TINY_PNG),
|
||||
("audio/wav", TINY_WAV),
|
||||
],
|
||||
)
|
||||
def test_prompt_attachment(mock_model, logs_db, attachment_type, attachment_content):
|
||||
runner = CliRunner()
|
||||
mock_model.enqueue(["two boxes"])
|
||||
result = runner.invoke(
|
||||
llm.cli.cli,
|
||||
["prompt", "-m", "mock", "describe image", "-a", "-"],
|
||||
input=TINY_PNG,
|
||||
cli.cli,
|
||||
["prompt", "-m", "mock", "describe file", "-a", "-"],
|
||||
input=attachment_content,
|
||||
catch_exceptions=False,
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert result.exit_code == 0, result.output
|
||||
assert result.output == "two boxes\n"
|
||||
assert mock_model.history[0][0].attachments[0] == llm.Attachment(
|
||||
type="image/png", path=None, url=None, content=TINY_PNG, _id=ANY
|
||||
type=attachment_type, path=None, url=None, content=attachment_content, _id=ANY
|
||||
)
|
||||
|
||||
# Check it was logged correctly
|
||||
|
|
@ -33,15 +44,15 @@ def test_prompt_image(mock_model, logs_db):
|
|||
assert len(conversations) == 1
|
||||
conversation = conversations[0]
|
||||
assert conversation["model"] == "mock"
|
||||
assert conversation["name"] == "describe image"
|
||||
assert conversation["name"] == "describe file"
|
||||
response = list(logs_db["responses"].rows)[0]
|
||||
attachment = list(logs_db["attachments"].rows)[0]
|
||||
assert attachment == {
|
||||
"id": ANY,
|
||||
"type": "image/png",
|
||||
"type": attachment_type,
|
||||
"path": None,
|
||||
"url": None,
|
||||
"content": TINY_PNG,
|
||||
"content": attachment_content,
|
||||
}
|
||||
prompt_attachment = list(logs_db["prompt_attachments"].rows)[0]
|
||||
assert prompt_attachment["attachment_id"] == attachment["id"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue