mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-30 18:04:45 +00:00
First working prototype of new attachments feature, refs #587
This commit is contained in:
parent
a466ddf3cd
commit
6df00f92ff
5 changed files with 193 additions and 14 deletions
|
|
@ -4,6 +4,7 @@ from .errors import (
|
|||
NeedsKeyException,
|
||||
)
|
||||
from .models import (
|
||||
Attachment,
|
||||
Conversation,
|
||||
Model,
|
||||
ModelWithAliases,
|
||||
|
|
|
|||
84
llm/cli.py
84
llm/cli.py
|
|
@ -4,6 +4,7 @@ from dataclasses import asdict
|
|||
import io
|
||||
import json
|
||||
from llm import (
|
||||
Attachment,
|
||||
Collection,
|
||||
Conversation,
|
||||
Response,
|
||||
|
|
@ -30,7 +31,9 @@ from llm import (
|
|||
from .migrations import migrate
|
||||
from .plugins import pm
|
||||
import base64
|
||||
import httpx
|
||||
import pathlib
|
||||
import puremagic
|
||||
import pydantic
|
||||
import readline
|
||||
from runpy import run_module
|
||||
|
|
@ -48,6 +51,54 @@ warnings.simplefilter("ignore", ResourceWarning)
|
|||
DEFAULT_TEMPLATE = "prompt: "
|
||||
|
||||
|
||||
class AttachmentType(click.ParamType):
    """Click parameter type that converts an attachment argument into an Attachment.

    Three kinds of value are accepted: "-" (bytes are read from stdin),
    anything containing "://" (treated as a URL) and everything else
    (treated as a filesystem path).
    """

    name = "attachment"

    def convert(self, value, param, ctx):
        """Return an Attachment for value, guessing its MIME type.

        Raises click.BadParameter if the source does not exist, the URL
        cannot be reached, or the MIME type of stdin/file content cannot
        be determined.
        """
        if value == "-":
            content = sys.stdin.buffer.read()
            # Try to guess type. PureError means the content was not
            # recognised; ValueError is caught as well because empty input
            # is rejected before any signature matching takes place.
            try:
                mimetype = puremagic.from_string(content, mime=True)
            except (puremagic.PureError, ValueError) as ex:
                raise click.BadParameter(
                    "Could not determine mimetype of stdin", ctx=ctx, param=param
                ) from ex
            return Attachment(mimetype, None, None, content)

        if "://" in value:
            # Confirm URL exists and try to guess type
            try:
                response = httpx.head(value)
                response.raise_for_status()
            except httpx.HTTPError as ex:
                raise click.BadParameter(str(ex), ctx=ctx, param=param) from ex
            mimetype = response.headers.get("content-type")
            if mimetype:
                # A Content-Type header may carry parameters, for example
                # "image/png; charset=binary". Keep only the media type so
                # it can be compared against a model's supported types.
                mimetype = mimetype.split(";", 1)[0].strip().lower()
            # A missing or empty header leaves the type unset (None)
            return Attachment(mimetype or None, None, value, None)

        # Check that the file exists
        path = pathlib.Path(value)
        if not path.exists():
            self.fail(f"File {value} does not exist", param, ctx)
        # Try to guess type; report failures (unknown signature, empty file,
        # unreadable path or directory) as a usage error, not a traceback
        try:
            mimetype = puremagic.from_file(str(path), mime=True)
        except (puremagic.PureError, ValueError, OSError) as ex:
            raise click.BadParameter(
                f"Could not determine mimetype of {value}", ctx=ctx, param=param
            ) from ex
        return Attachment(mimetype, str(path), None, None)
|
||||
|
||||
|
||||
def attachment_types_callback(ctx, param, values):
    """Click callback turning (value, mimetype) pairs into Attachment objects.

    Each value is interpreted as a URL when it contains "://", as stdin
    when it is "-", and as a file path otherwise. The mimetype given by
    the user is used as-is. Raises click.BadParameter for a file path
    that does not exist.
    """

    def build(source, mimetype):
        # URL: only the location is recorded
        if "://" in source:
            return Attachment(mimetype, None, source, None)
        # "-": the bytes are read from stdin
        if source == "-":
            return Attachment(mimetype, None, None, sys.stdin.buffer.read())
        # Look for file
        file_path = pathlib.Path(source)
        if not file_path.exists():
            raise click.BadParameter(f"File {source} does not exist")
        return Attachment(mimetype, str(file_path), None, None)

    return [build(source, mimetype) for source, mimetype in values]
|
||||
|
||||
|
||||
def _validate_metadata_json(ctx, param, value):
|
||||
if value is None:
|
||||
return value
|
||||
|
|
@ -88,6 +139,23 @@ def cli():
|
|||
@click.argument("prompt", required=False)
|
||||
@click.option("-s", "--system", help="System prompt to use")
|
||||
@click.option("model_id", "-m", "--model", help="Model to use")
|
||||
@click.option(
|
||||
"attachments",
|
||||
"-a",
|
||||
"--attachment",
|
||||
type=AttachmentType(),
|
||||
multiple=True,
|
||||
help="Attachment path or URL or -",
|
||||
)
|
||||
@click.option(
|
||||
"attachment_types",
|
||||
"--at",
|
||||
"--attachment-type",
|
||||
type=(str, str),
|
||||
multiple=True,
|
||||
callback=attachment_types_callback,
|
||||
help="Attachment with explicit mimetype",
|
||||
)
|
||||
@click.option(
|
||||
"options",
|
||||
"-o",
|
||||
|
|
@ -127,6 +195,8 @@ def prompt(
|
|||
prompt,
|
||||
system,
|
||||
model_id,
|
||||
attachments,
|
||||
attachment_types,
|
||||
options,
|
||||
template,
|
||||
param,
|
||||
|
|
@ -142,6 +212,14 @@ def prompt(
|
|||
Execute a prompt
|
||||
|
||||
Documentation: https://llm.datasette.io/en/stable/usage.html
|
||||
|
||||
Examples:
|
||||
|
||||
\b
|
||||
llm 'Capital of France?'
|
||||
llm 'Capital of France?' -m gpt-4o
|
||||
llm 'Capital of France?' -s 'answer in Spanish'
|
||||
llm 'Extract text from this image' -a image.jpg
|
||||
"""
|
||||
if log and no_log:
|
||||
raise click.ClickException("--log and --no-log are mutually exclusive")
|
||||
|
|
@ -262,6 +340,8 @@ def prompt(
|
|||
except pydantic.ValidationError as ex:
|
||||
raise click.ClickException(render_errors(ex.errors()))
|
||||
|
||||
resolved_attachments = [*attachments, *attachment_types]
|
||||
|
||||
should_stream = model.can_stream and not no_stream
|
||||
if not should_stream:
|
||||
validated_options["stream"] = False
|
||||
|
|
@ -273,7 +353,9 @@ def prompt(
|
|||
prompt_method = conversation.prompt
|
||||
|
||||
try:
|
||||
response = prompt_method(prompt, system, **validated_options)
|
||||
response = prompt_method(
|
||||
prompt, *resolved_attachments, system=system, **validated_options
|
||||
)
|
||||
if should_stream:
|
||||
for chunk in response:
|
||||
print(chunk, end="")
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ def register_models(register):
|
|||
register(Chat("gpt-4-turbo-2024-04-09"))
|
||||
register(Chat("gpt-4-turbo"), aliases=("gpt-4-turbo-preview", "4-turbo", "4t"))
|
||||
# GPT-4o
|
||||
register(Chat("gpt-4o"), aliases=("4o",))
|
||||
register(Chat("gpt-4o-mini"), aliases=("4o-mini",))
|
||||
register(Chat("gpt-4o", vision=True), aliases=("4o",))
|
||||
register(Chat("gpt-4o-mini", vision=True), aliases=("4o-mini",))
|
||||
# o1
|
||||
register(Chat("o1-preview", can_stream=False, allows_system_prompt=False))
|
||||
register(Chat("o1-mini", can_stream=False, allows_system_prompt=False))
|
||||
|
|
@ -271,6 +271,7 @@ class Chat(Model):
|
|||
api_engine=None,
|
||||
headers=None,
|
||||
can_stream=True,
|
||||
vision=False,
|
||||
allows_system_prompt=True,
|
||||
):
|
||||
self.model_id = model_id
|
||||
|
|
@ -282,8 +283,17 @@ class Chat(Model):
|
|||
self.api_engine = api_engine
|
||||
self.headers = headers
|
||||
self.can_stream = can_stream
|
||||
self.vision = vision
|
||||
self.allows_system_prompt = allows_system_prompt
|
||||
|
||||
if vision:
|
||||
self.attachment_types = {
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
return "OpenAI Chat: {}".format(self.model_id)
|
||||
|
||||
|
|
@ -308,7 +318,18 @@ class Chat(Model):
|
|||
messages.append({"role": "assistant", "content": prev_response.text()})
|
||||
if prompt.system and prompt.system != current_system:
|
||||
messages.append({"role": "system", "content": prompt.system})
|
||||
messages.append({"role": "user", "content": prompt.prompt})
|
||||
if not prompt.attachments:
|
||||
messages.append({"role": "user", "content": prompt.prompt})
|
||||
else:
|
||||
vision_message = [{"type": "text", "text": prompt.prompt}]
|
||||
for attachment in prompt.attachments:
|
||||
url = attachment.url
|
||||
if not url:
|
||||
base64_image = attachment.base64_content()
|
||||
url = f"data:{attachment.resolve_type()};base64,{base64_image}"
|
||||
vision_message.append({"type": "image_url", "image_url": {"url": url}})
|
||||
messages.append({"role": "user", "content": vision_message})
|
||||
|
||||
response._prompt_json = {"messages": messages}
|
||||
kwargs = self.build_kwargs(prompt)
|
||||
client = self.get_client()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
import base64
|
||||
from dataclasses import dataclass, field
|
||||
import datetime
|
||||
from .errors import NeedsKeyException
|
||||
import httpx
|
||||
from itertools import islice
|
||||
import puremagic
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
|
||||
|
|
@ -13,17 +16,52 @@ from ulid import ULID
|
|||
CONVERSATION_NAME_LENGTH = 32
|
||||
|
||||
|
||||
@dataclass
class Attachment:
    """Non-text content passed to a model alongside a prompt.

    The data is located through path, url or content; type holds the MIME
    type when it is already known.
    """

    # MIME type, e.g. "image/png"; resolve_type() derives it when unset
    type: Optional[str] = None
    # Path to a local file holding the data
    path: Optional[str] = None
    # URL the data can be fetched from
    url: Optional[str] = None
    # Raw bytes of the data
    content: Optional[bytes] = None

    def resolve_type(self):
        """Return the MIME type, deriving it from path, url or content if unset.

        Raises ValueError when there is nothing to derive the type from,
        or when the server at url does not report a Content-Type.
        """
        if self.type:
            return self.type
        # Derive it from path or url or content
        if self.path:
            return puremagic.from_file(self.path, mime=True)
        if self.url:
            # puremagic is not known to offer a URL helper, so ask the
            # server instead
            response = httpx.head(self.url)
            response.raise_for_status()
            content_type = response.headers.get("content-type")
            if not content_type:
                raise ValueError(
                    "Could not determine type of attachment at {}".format(self.url)
                )
            # Drop parameters such as "; charset=binary"
            return content_type.split(";", 1)[0].strip().lower()
        if self.content:
            return puremagic.from_string(self.content, mime=True)
        raise ValueError("Attachment has no type and no content to derive it from")

    def base64_content(self):
        """Return the attachment data as a base64-encoded str.

        In-memory content is used when present; otherwise the data is read
        from path, or downloaded from url. Raises ValueError when none of
        the three is available.
        """
        content = self.content
        if not content:
            if self.path:
                # Context manager so the file handle is closed promptly
                with open(self.path, "rb") as fp:
                    content = fp.read()
            elif self.url:
                response = httpx.get(self.url)
                response.raise_for_status()
                content = response.content
        if content is None:
            raise ValueError("Attachment has no content, path or url to read from")
        return base64.b64encode(content).decode("utf-8")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Prompt:
|
||||
prompt: str
|
||||
model: "Model"
|
||||
system: Optional[str]
|
||||
prompt_json: Optional[str]
|
||||
options: "Options"
|
||||
attachments: Optional[List[Attachment]] = field(default_factory=list)
|
||||
system: Optional[str] = None
|
||||
prompt_json: Optional[str] = None
|
||||
options: "Options" = field(default_factory=dict)
|
||||
|
||||
def __init__(self, prompt, model, system=None, prompt_json=None, options=None):
|
||||
def __init__(
|
||||
self, prompt, model, attachments, system=None, prompt_json=None, options=None
|
||||
):
|
||||
self.prompt = prompt
|
||||
self.model = model
|
||||
self.attachments = list(attachments)
|
||||
self.system = system
|
||||
self.prompt_json = prompt_json
|
||||
self.options = options or {}
|
||||
|
|
@ -39,6 +77,7 @@ class Conversation:
|
|||
def prompt(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
*attachments: Attachment,
|
||||
system: Optional[str] = None,
|
||||
stream: bool = True,
|
||||
**options
|
||||
|
|
@ -46,8 +85,9 @@ class Conversation:
|
|||
return Response(
|
||||
Prompt(
|
||||
prompt,
|
||||
system=system,
|
||||
model=self.model,
|
||||
attachments=attachments,
|
||||
system=system,
|
||||
options=self.model.Options(**options),
|
||||
),
|
||||
self.model,
|
||||
|
|
@ -158,14 +198,22 @@ class Response(ABC):
|
|||
db["responses"].insert(response)
|
||||
|
||||
@classmethod
|
||||
def fake(cls, model: "Model", prompt: str, system: str, response: str):
|
||||
def fake(
|
||||
cls,
|
||||
model: "Model",
|
||||
prompt: str,
|
||||
*attachments: List[Attachment],
|
||||
system: str,
|
||||
response: str
|
||||
):
|
||||
"Utility method to help with writing tests"
|
||||
response_obj = cls(
|
||||
model=model,
|
||||
prompt=Prompt(
|
||||
prompt,
|
||||
system=system,
|
||||
model=model,
|
||||
attachments=attachments,
|
||||
system=system,
|
||||
),
|
||||
stream=False,
|
||||
)
|
||||
|
|
@ -183,8 +231,9 @@ class Response(ABC):
|
|||
model=model,
|
||||
prompt=Prompt(
|
||||
prompt=row["prompt"],
|
||||
system=row["system"],
|
||||
model=model,
|
||||
attachments=[],
|
||||
system=row["system"],
|
||||
options=model.Options(**json.loads(row["options_json"])),
|
||||
),
|
||||
stream=False,
|
||||
|
|
@ -242,10 +291,15 @@ class _get_key_mixin:
|
|||
|
||||
class Model(ABC, _get_key_mixin):
|
||||
model_id: str
|
||||
|
||||
# API key handling
|
||||
key: Optional[str] = None
|
||||
needs_key: Optional[str] = None
|
||||
key_env_var: Optional[str] = None
|
||||
|
||||
# Model characteristics
|
||||
can_stream: bool = False
|
||||
attachment_types = set()
|
||||
|
||||
class Options(_Options):
|
||||
pass
|
||||
|
|
@ -269,13 +323,33 @@ class Model(ABC, _get_key_mixin):
|
|||
|
||||
def prompt(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
prompt: str,
|
||||
*attachments: Attachment,
|
||||
system: Optional[str] = None,
|
||||
stream: bool = True,
|
||||
**options
|
||||
):
|
||||
# Validate attachments
|
||||
if attachments and not self.attachment_types:
|
||||
raise ValueError(
|
||||
"This model does not support attachments, but some were provided"
|
||||
)
|
||||
for attachment in attachments:
|
||||
attachment_type = attachment.resolve_type()
|
||||
if attachment_type not in self.attachment_types:
|
||||
raise ValueError(
|
||||
"This model does not support attachments of type '{}', only {}".format(
|
||||
attachment_type, ", ".join(self.attachment_types)
|
||||
)
|
||||
)
|
||||
return self.response(
|
||||
Prompt(prompt, system=system, model=self, options=self.Options(**options)),
|
||||
Prompt(
|
||||
prompt,
|
||||
attachments=attachments,
|
||||
system=system,
|
||||
model=self,
|
||||
options=self.Options(**options),
|
||||
),
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
|
|
|||
1
setup.py
1
setup.py
|
|
@ -48,6 +48,7 @@ setup(
|
|||
"setuptools",
|
||||
"pip",
|
||||
"pyreadline3; sys_platform == 'win32'",
|
||||
"puremagic",
|
||||
],
|
||||
extras_require={
|
||||
"test": [
|
||||
|
|
|
|||
Loading…
Reference in a new issue