llm.get_async_model(), llm.AsyncModel base class and OpenAI async models (#613)

- https://github.com/simonw/llm/issues/507#issuecomment-2458639308

* register_model is now async aware

Refs https://github.com/simonw/llm/issues/507#issuecomment-2458658134

* Refactor Chat and AsyncChat to use _Shared base class

Refs https://github.com/simonw/llm/issues/507#issuecomment-2458692338

* Fixed function name

* Fix for infinite loop

* Applied Black

* Ran cog

* Applied Black

* Add Response.from_row() classmethod back again

It does not matter that this is a blocking call, since it is a classmethod

* Made mypy happy with llm/models.py

* mypy fixes for openai_models.py

I am unhappy with this: I had to duplicate some code.

* First test for AsyncModel

* Still have not quite got this working

* Fix for not loading plugins during tests, refs #626

* audio/wav not audio/wave, refs #603

* Black and mypy and ruff all happy

* Refactor to avoid generics

* Removed obsolete response() method

* Support text = await async_mock_model.prompt("hello")

* Initial docs for llm.get_async_model() and await model.prompt()

Refs #507

* Initial async model plugin creation docs

* duration_ms ANY to pass test

* llm models --async option

Refs https://github.com/simonw/llm/pull/613#issuecomment-2474724406

* Removed obsolete TypeVars

* Expanded register_models() docs for async

* await model.prompt() now returns AsyncResponse

Refs https://github.com/simonw/llm/pull/613#issuecomment-2475157822

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Simon Willison, 2024-11-13 17:51:00 -08:00 (committed via GitHub)
commit ba75c674cb, parent 5a984d0c87
14 changed files with 688 additions and 219 deletions

View file

@@ -121,6 +121,7 @@ Options:
--cid, --conversation TEXT Continue the conversation with the given ID.
--key TEXT API key to use
--save TEXT Save prompt with this template name
--async Run prompt asynchronously
--help Show this message and exit.
```
@@ -322,6 +323,7 @@ Usage: llm models list [OPTIONS]
Options:
--options Show options for each model, if available
--async List async models
--help Show this message and exit.
```
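Together these two flags expose the new async path from the command-line. For example (a sketch of the intended invocations, assuming a default model is configured):
```
llm 'Five names for a pet walrus' --async
llm models list --async
```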

View file

@@ -5,13 +5,64 @@ The {ref}`model plugin tutorial <tutorial-model-plugin>` covers the basics of de
This document covers more advanced topics.
(advanced-model-plugins-async)=
## Async models
Plugins can optionally provide an asynchronous version of their model, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html). This is particularly useful for remote models accessible via an HTTP API.
The async version of a model subclasses `llm.AsyncModel` instead of `llm.Model`. It must implement an `async def execute()` async generator method instead of `def execute()`.
This example shows a subset of the OpenAI default plugin illustrating how this method might work:
```python
from typing import AsyncGenerator
import llm


class MyAsyncModel(llm.AsyncModel):
    # This can duplicate the model_id of the sync model:
    model_id = "my-model-id"

    async def execute(
        self, prompt, stream, response, conversation=None
    ) -> AsyncGenerator[str, None]:
        # `client` is an AsyncOpenAI-style client and `messages` is the
        # list built from the prompt; both are constructed elsewhere in
        # the full plugin and omitted from this excerpt.
        if stream:
            completion = await client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                stream=True,
            )
            async for chunk in completion:
                yield chunk.choices[0].delta.content
        else:
            completion = await client.chat.completions.create(
                model=self.model_name or self.model_id,
                messages=messages,
                stream=False,
            )
            yield completion.choices[0].message.content
```
This async model instance should then be passed to the `register()` method in the `register_models()` plugin hook:
```python
@hookimpl
def register_models(register):
register(
MyModel(), MyAsyncModel(), aliases=("my-model-aliases",)
)
```
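Once registered, the async model can be fetched and exercised like this (a minimal sketch, reusing the hypothetical `my-model-id` from above):
```python
import asyncio

import llm


async def main():
    model = llm.get_async_model("my-model-id")
    response = await model.prompt("Say hello")
    print(await response.text())


asyncio.run(main())
```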
(advanced-model-plugins-attachments)=
## Attachments for multi-modal models
Models such as GPT-4o, Claude 3.5 Sonnet and Google's Gemini 1.5 are multi-modal: they accept input in the form of images, and in some cases audio, video and other formats.
LLM calls these **attachments**. Models can specify the types of attachments they accept and then implement special code in the `.execute()` method to handle them.
See {ref}`the Python attachments documentation <python-api-attachments>` for details on using attachments in the Python API.
### Specifying attachment types
A `Model` subclass can list the types of attachments it accepts by defining an `attachment_types` class attribute:
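For illustration, such an attribute might look like this (a sketch; the exact MIME types depend on what the model supports):
```python
class NewModel(llm.Model):
    model_id = "new-model"
    # Attachments with these content types will be accepted:
    attachment_types = {
        "image/png",
        "image/jpeg",
        "image/webp",
        "image/gif",
    }
```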

View file

@@ -42,5 +42,20 @@ class HelloWorld(llm.Model):
def execute(self, prompt, stream, response):
return ["hello world"]
```
If your model includes an async version, you can register that too:
```python
class AsyncHelloWorld(llm.AsyncModel):
    model_id = "helloworld"

    async def execute(self, prompt, stream, response):
        # An async model's execute() must be an async generator -
        # yield chunks rather than returning a list:
        yield "hello world"


@llm.hookimpl
def register_models(register):
    register(HelloWorld(), AsyncHelloWorld(), aliases=("hw",))
```
This demonstrates how to register a model with both sync and async versions, and how to specify an alias for that model.
The {ref}`model plugin tutorial <tutorial-model-plugin>` describes how to use this hook in detail. Asynchronous models {ref}`are described here <advanced-model-plugins-async>`.
{ref}`tutorial-model-plugin` describes how to use this hook in detail.
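A quick way to exercise both registrations (a sketch, assuming the plugin above is installed):
```python
import asyncio

import llm

# Sync version:
print(llm.get_model("hw").prompt("hi").text())


# Async version, registered alongside it under the same "hw" alias:
async def main():
    response = await llm.get_async_model("hw").prompt("hi")
    print(await response.text())


asyncio.run(main())
```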

View file

@@ -99,7 +99,7 @@ print(response.text())
```
Some models do not use API keys at all.
## Streaming responses
### Streaming responses
For models that support it you can stream responses as they are generated, like this:
@@ -112,6 +112,34 @@ The `response.text()` method described earlier does this for you - it runs throu
If a response has been evaluated, `response.text()` will continue to return the same string.
(python-api-async)=
## Async models
Some plugins provide async versions of their supported models, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html).
To use an async model, use the `llm.get_async_model()` function instead of `llm.get_model()`:
```python
import llm
model = llm.get_async_model("gpt-4o")
```
You can then run a prompt using `await model.prompt(...)`:
```python
response = await model.prompt(
"Five surprising names for a pet pelican"
)
print(await response.text())
```
Or use `async for chunk in ...` to stream the response as it is generated:
```python
async for chunk in model.prompt(
"Five surprising names for a pet pelican"
):
print(chunk, end="", flush=True)
```
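Since `await` only works inside a coroutine, a complete script wraps this in `asyncio.run()`; a minimal sketch (assuming an OpenAI key is configured):
```python
import asyncio

import llm


async def main():
    model = llm.get_async_model("gpt-4o")
    # Stream chunks to stdout as they arrive:
    async for chunk in model.prompt("Five surprising names for a pet pelican"):
        print(chunk, end="", flush=True)
    print()


asyncio.run(main())
```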
## Conversations
LLM supports *conversations*, where you ask follow-up questions of a model as part of an ongoing conversation.

View file

@@ -4,6 +4,8 @@ from .errors import (
NeedsKeyException,
)
from .models import (
AsyncModel,
AsyncResponse,
Attachment,
Conversation,
Model,
@@ -26,9 +28,11 @@ import struct
__all__ = [
"hookimpl",
"get_async_model",
"get_model",
"get_key",
"user_dir",
"AsyncResponse",
"Attachment",
"Collection",
"Conversation",
@@ -74,11 +78,11 @@ def get_models_with_aliases() -> List["ModelWithAliases"]:
for alias, model_id in configured_aliases.items():
extra_model_aliases.setdefault(model_id, []).append(alias)
def register(model, aliases=None):
def register(model, async_model=None, aliases=None):
alias_list = list(aliases or [])
if model.model_id in extra_model_aliases:
alias_list.extend(extra_model_aliases[model.model_id])
model_aliases.append(ModelWithAliases(model, alias_list))
model_aliases.append(ModelWithAliases(model, async_model, alias_list))
load_plugins()
pm.hook.register_models(register=register)
@@ -137,12 +141,25 @@ def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:
return model_aliases
def get_async_model_aliases() -> Dict[str, AsyncModel]:
async_model_aliases = {}
for model_with_aliases in get_models_with_aliases():
if model_with_aliases.async_model:
for alias in model_with_aliases.aliases:
async_model_aliases[alias] = model_with_aliases.async_model
async_model_aliases[model_with_aliases.model.model_id] = (
model_with_aliases.async_model
)
return async_model_aliases
def get_model_aliases() -> Dict[str, Model]:
model_aliases = {}
for model_with_aliases in get_models_with_aliases():
for alias in model_with_aliases.aliases:
model_aliases[alias] = model_with_aliases.model
model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
if model_with_aliases.model:
for alias in model_with_aliases.aliases:
model_aliases[alias] = model_with_aliases.model
model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
return model_aliases
@@ -150,13 +167,42 @@ class UnknownModelError(KeyError):
pass
def get_model(name: Optional[str] = None) -> Model:
def get_async_model(name: Optional[str] = None) -> AsyncModel:
aliases = get_async_model_aliases()
name = name or get_default_model()
try:
return aliases[name]
except KeyError:
# Does a sync model exist?
sync_model = None
try:
sync_model = get_model(name, _skip_async=True)
except UnknownModelError:
pass
if sync_model:
raise UnknownModelError("Unknown async model (sync model exists): " + name)
else:
raise UnknownModelError("Unknown model: " + name)
def get_model(name: Optional[str] = None, _skip_async: bool = False) -> Model:
aliases = get_model_aliases()
name = name or get_default_model()
try:
return aliases[name]
except KeyError:
raise UnknownModelError("Unknown model: " + name)
# Does an async model exist?
if _skip_async:
raise UnknownModelError("Unknown model: " + name)
async_model = None
try:
async_model = get_async_model(name)
except UnknownModelError:
pass
if async_model:
raise UnknownModelError("Unknown model (async model exists): " + name)
else:
raise UnknownModelError("Unknown model: " + name)
def get_key(
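
The cross-lookup between the two functions exists to produce clearer error messages: asking for a sync model that only has an async registration (or vice versa) names the mismatch instead of reporting a plain unknown model. A sketch of the behaviour, using a hypothetical ID registered only as an async model:
```python
import llm

try:
    llm.get_model("async-only-model")  # hypothetical async-only model ID
except llm.UnknownModelError as ex:
    print(ex)  # "Unknown model (async model exists): async-only-model"
```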

View file

@@ -1,3 +1,4 @@
import asyncio
import click
from click_default_group import DefaultGroup
from dataclasses import asdict
@@ -11,6 +12,7 @@ from llm import (
Template,
UnknownModelError,
encode,
get_async_model,
get_default_model,
get_default_embedding_model,
get_embedding_models_with_aliases,
@@ -199,6 +201,7 @@ def cli():
)
@click.option("--key", help="API key to use")
@click.option("--save", help="Save prompt with this template name")
@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously")
def prompt(
prompt,
system,
@@ -215,6 +218,7 @@ def prompt(
conversation_id,
key,
save,
async_,
):
"""
Execute a prompt
@@ -337,9 +341,12 @@ def prompt(
# Now resolve the model
try:
model = model_aliases[model_id]
except KeyError:
raise click.ClickException("'{}' is not a known model".format(model_id))
if async_:
model = get_async_model(model_id)
else:
model = get_model(model_id)
except UnknownModelError as ex:
raise click.ClickException(ex)
# Provide the API key, if one is needed and has been provided
if model.needs_key:
@@ -375,21 +382,48 @@ def prompt(
prompt_method = conversation.prompt
try:
response = prompt_method(
prompt, attachments=resolved_attachments, system=system, **validated_options
)
if should_stream:
for chunk in response:
print(chunk, end="")
sys.stdout.flush()
print("")
if async_:
async def inner():
if should_stream:
async for chunk in prompt_method(
prompt,
attachments=resolved_attachments,
system=system,
**validated_options,
):
print(chunk, end="")
sys.stdout.flush()
print("")
else:
response = prompt_method(
prompt,
attachments=resolved_attachments,
system=system,
**validated_options,
)
print(await response.text())
asyncio.run(inner())
else:
print(response.text())
response = prompt_method(
prompt,
attachments=resolved_attachments,
system=system,
**validated_options,
)
if should_stream:
for chunk in response:
print(chunk, end="")
sys.stdout.flush()
print("")
else:
print(response.text())
except Exception as ex:
raise click.ClickException(str(ex))
# Log to the database
if (logs_on() or log) and not no_log:
if (logs_on() or log) and not no_log and not async_:
log_path = logs_db_path()
(log_path.parent).mkdir(parents=True, exist_ok=True)
db = sqlite_utils.Database(log_path)
@@ -981,14 +1015,19 @@ _type_lookup = {
@click.option(
"--options", is_flag=True, help="Show options for each model, if available"
)
def models_list(options):
@click.option("async_", "--async", is_flag=True, help="List async models")
def models_list(options, async_):
"List available models"
models_that_have_shown_options = set()
for model_with_aliases in get_models_with_aliases():
if async_ and not model_with_aliases.async_model:
continue
extra = ""
if model_with_aliases.aliases:
extra = " (aliases: {})".format(", ".join(model_with_aliases.aliases))
model = model_with_aliases.model
model = (
model_with_aliases.model if not async_ else model_with_aliases.async_model
)
output = str(model) + extra
if options and model.Options.schema()["properties"]:
output += "\n Options:"

View file

@@ -1,4 +1,4 @@
from llm import EmbeddingModel, Model, hookimpl
from llm import AsyncModel, EmbeddingModel, Model, hookimpl
import llm
from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
import click
@@ -16,7 +16,7 @@ except ImportError:
from pydantic.fields import Field
from pydantic.class_validators import validator as field_validator # type: ignore [no-redef]
from typing import List, Iterable, Iterator, Optional, Union
from typing import AsyncGenerator, List, Iterable, Iterator, Optional, Union
import json
import yaml
@@ -24,22 +24,47 @@ import yaml
@hookimpl
def register_models(register):
# GPT-4o
register(Chat("gpt-4o", vision=True), aliases=("4o",))
register(Chat("gpt-4o-mini", vision=True), aliases=("4o-mini",))
register(Chat("gpt-4o-audio-preview", audio=True))
register(
Chat("gpt-4o", vision=True), AsyncChat("gpt-4o", vision=True), aliases=("4o",)
)
register(
Chat("gpt-4o-mini", vision=True),
AsyncChat("gpt-4o-mini", vision=True),
aliases=("4o-mini",),
)
register(
Chat("gpt-4o-audio-preview", audio=True),
AsyncChat("gpt-4o-audio-preview", audio=True),
)
# 3.5 and 4
register(Chat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt"))
register(Chat("gpt-3.5-turbo-16k"), aliases=("chatgpt-16k", "3.5-16k"))
register(Chat("gpt-4"), aliases=("4", "gpt4"))
register(Chat("gpt-4-32k"), aliases=("4-32k",))
register(
Chat("gpt-3.5-turbo"), AsyncChat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt")
)
register(
Chat("gpt-3.5-turbo-16k"),
AsyncChat("gpt-3.5-turbo-16k"),
aliases=("chatgpt-16k", "3.5-16k"),
)
register(Chat("gpt-4"), AsyncChat("gpt-4"), aliases=("4", "gpt4"))
register(Chat("gpt-4-32k"), AsyncChat("gpt-4-32k"), aliases=("4-32k",))
# GPT-4 Turbo models
register(Chat("gpt-4-1106-preview"))
register(Chat("gpt-4-0125-preview"))
register(Chat("gpt-4-turbo-2024-04-09"))
register(Chat("gpt-4-turbo"), aliases=("gpt-4-turbo-preview", "4-turbo", "4t"))
register(Chat("gpt-4-1106-preview"), AsyncChat("gpt-4-1106-preview"))
register(Chat("gpt-4-0125-preview"), AsyncChat("gpt-4-0125-preview"))
register(Chat("gpt-4-turbo-2024-04-09"), AsyncChat("gpt-4-turbo-2024-04-09"))
register(
Chat("gpt-4-turbo"),
AsyncChat("gpt-4-turbo"),
aliases=("gpt-4-turbo-preview", "4-turbo", "4t"),
)
# o1
register(Chat("o1-preview", can_stream=False, allows_system_prompt=False))
register(Chat("o1-mini", can_stream=False, allows_system_prompt=False))
register(
Chat("o1-preview", can_stream=False, allows_system_prompt=False),
AsyncChat("o1-preview", can_stream=False, allows_system_prompt=False),
)
register(
Chat("o1-mini", can_stream=False, allows_system_prompt=False),
AsyncChat("o1-mini", can_stream=False, allows_system_prompt=False),
)
# The -instruct completion model
register(
Completion("gpt-3.5-turbo-instruct", default_max_tokens=256),
@@ -273,18 +298,7 @@ def _attachment(attachment):
}
class Chat(Model):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
default_max_tokens = None
class Options(SharedOptions):
json_object: Optional[bool] = Field(
description="Output a valid JSON object {...}. Prompt must mention JSON.",
default=None,
)
class _Shared:
def __init__(
self,
model_id,
@@ -335,10 +349,8 @@ class Chat(Model):
def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)
def execute(self, prompt, stream, response, conversation=None):
def build_messages(self, prompt, conversation):
messages = []
if prompt.system and not self.allows_system_prompt:
raise NotImplementedError("Model does not support system prompts")
current_system = None
if conversation is not None:
for prev_response in conversation.responses:
@@ -375,7 +387,60 @@ class Chat(Model):
for attachment in prompt.attachments:
attachment_message.append(_attachment(attachment))
messages.append({"role": "user", "content": attachment_message})
return messages
def get_client(self, async_=False):
kwargs = {}
if self.api_base:
kwargs["base_url"] = self.api_base
if self.api_type:
kwargs["api_type"] = self.api_type
if self.api_version:
kwargs["api_version"] = self.api_version
if self.api_engine:
kwargs["engine"] = self.api_engine
if self.needs_key:
kwargs["api_key"] = self.get_key()
else:
# OpenAI-compatible models don't need a key, but the
# openai client library requires one
kwargs["api_key"] = "DUMMY_KEY"
if self.headers:
kwargs["default_headers"] = self.headers
if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
kwargs["http_client"] = logging_client()
if async_:
return openai.AsyncOpenAI(**kwargs)
else:
return openai.OpenAI(**kwargs)
def build_kwargs(self, prompt, stream):
kwargs = dict(not_nulls(prompt.options))
json_object = kwargs.pop("json_object", None)
if "max_tokens" not in kwargs and self.default_max_tokens is not None:
kwargs["max_tokens"] = self.default_max_tokens
if json_object:
kwargs["response_format"] = {"type": "json_object"}
if stream:
kwargs["stream_options"] = {"include_usage": True}
return kwargs
class Chat(_Shared, Model):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
default_max_tokens = None
class Options(SharedOptions):
json_object: Optional[bool] = Field(
description="Output a valid JSON object {...}. Prompt must mention JSON.",
default=None,
)
def execute(self, prompt, stream, response, conversation=None):
if prompt.system and not self.allows_system_prompt:
raise NotImplementedError("Model does not support system prompts")
messages = self.build_messages(prompt, conversation)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client()
if stream:
@@ -406,38 +471,53 @@ class Chat(Model):
yield completion.choices[0].message.content
response._prompt_json = redact_data({"messages": messages})
    def get_client(self):
        kwargs = {}
        if self.api_base:
            kwargs["base_url"] = self.api_base
        if self.api_type:
            kwargs["api_type"] = self.api_type
        if self.api_version:
            kwargs["api_version"] = self.api_version
        if self.api_engine:
            kwargs["engine"] = self.api_engine
        if self.needs_key:
            kwargs["api_key"] = self.get_key()
        else:
            # OpenAI-compatible models don't need a key, but the
            # openai client library requires one
            kwargs["api_key"] = "DUMMY_KEY"
        if self.headers:
            kwargs["default_headers"] = self.headers
        if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
            kwargs["http_client"] = logging_client()
        return openai.OpenAI(**kwargs)

    def build_kwargs(self, prompt, stream):
        kwargs = dict(not_nulls(prompt.options))
        json_object = kwargs.pop("json_object", None)
        if "max_tokens" not in kwargs and self.default_max_tokens is not None:
            kwargs["max_tokens"] = self.default_max_tokens
        if json_object:
            kwargs["response_format"] = {"type": "json_object"}
        if stream:
            kwargs["stream_options"] = {"include_usage": True}
        return kwargs


class AsyncChat(_Shared, AsyncModel):
    needs_key = "openai"
    key_env_var = "OPENAI_API_KEY"
    default_max_tokens = None

    class Options(SharedOptions):
        json_object: Optional[bool] = Field(
            description="Output a valid JSON object {...}. Prompt must mention JSON.",
            default=None,
        )

    async def execute(
        self, prompt, stream, response, conversation=None
    ) -> AsyncGenerator[str, None]:
        if prompt.system and not self.allows_system_prompt:
            raise NotImplementedError("Model does not support system prompts")
        messages = self.build_messages(prompt, conversation)
        kwargs = self.build_kwargs(prompt, stream)
        client = self.get_client(async_=True)
        if stream:
            completion = await client.chat.completions.create(
                model=self.model_name or self.model_id,
                messages=messages,
                stream=True,
                **kwargs,
            )
            chunks = []
            async for chunk in completion:
                chunks.append(chunk)
                try:
                    content = chunk.choices[0].delta.content
                except IndexError:
                    content = None
                if content is not None:
                    yield content
            response.response_json = remove_dict_none_values(combine_chunks(chunks))
        else:
            completion = await client.chat.completions.create(
                model=self.model_name or self.model_id,
                messages=messages,
                stream=False,
                **kwargs,
            )
            response.response_json = remove_dict_none_values(completion.model_dump())
            yield completion.choices[0].message.content
        response._prompt_json = redact_data({"messages": messages})
class Completion(Chat):

View file

@@ -7,7 +7,17 @@ import httpx
from itertools import islice
import re
import time
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
from typing import (
Any,
AsyncGenerator,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Union,
)
from .utils import mimetype_from_path, mimetype_from_string
from abc import ABC, abstractmethod
import json
@@ -94,7 +104,7 @@ class Prompt:
attachments=None,
system=None,
prompt_json=None,
options=None
options=None,
):
self.prompt = prompt
self.model = model
@@ -105,12 +115,25 @@
@dataclass
class Conversation:
model: "Model"
class _BaseConversation:
model: "_BaseModel"
id: str = field(default_factory=lambda: str(ULID()).lower())
name: Optional[str] = None
responses: List["Response"] = field(default_factory=list)
responses: List["_BaseResponse"] = field(default_factory=list)
@classmethod
def from_row(cls, row):
from llm import get_model
return cls(
model=get_model(row["model"]),
id=row["id"],
name=row["name"],
)
@dataclass
class Conversation(_BaseConversation):
def prompt(
self,
prompt: Optional[str],
@@ -118,8 +141,8 @@ class Conversation:
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
stream: bool = True,
**options
):
**options,
) -> "Response":
return Response(
Prompt(
prompt,
@@ -133,24 +156,45 @@
conversation=self,
)
    @classmethod
    def from_row(cls, row):
        from llm import get_model

        return cls(
            model=get_model(row["model"]),
            id=row["id"],
            name=row["name"],
        )


@dataclass
class AsyncConversation(_BaseConversation):
def prompt(
self,
prompt: Optional[str],
*,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
stream: bool = True,
**options,
) -> "AsyncResponse":
return AsyncResponse(
Prompt(
prompt,
model=self.model,
attachments=attachments,
system=system,
options=self.model.Options(**options),
),
self.model,
stream,
conversation=self,
)
class Response(ABC):
class _BaseResponse:
"""Base response class shared between sync and async responses"""
prompt: "Prompt"
stream: bool
conversation: Optional["_BaseConversation"] = None
def __init__(
self,
prompt: Prompt,
model: "Model",
model: "_BaseModel",
stream: bool,
conversation: Optional[Conversation] = None,
conversation: Optional[_BaseConversation] = None,
):
self.prompt = prompt
self._prompt_json = None
@@ -161,47 +205,46 @@ class Response(ABC):
self.response_json = None
self.conversation = conversation
self.attachments: List[Attachment] = []
self._start: Optional[float] = None
self._end: Optional[float] = None
self._start_utcnow: Optional[datetime.datetime] = None
def __iter__(self) -> Iterator[str]:
self._start = time.monotonic()
self._start_utcnow = datetime.datetime.utcnow()
if self._done:
yield from self._chunks
for chunk in self.model.execute(
self.prompt,
stream=self.stream,
response=self,
conversation=self.conversation,
):
yield chunk
self._chunks.append(chunk)
if self.conversation:
self.conversation.responses.append(self)
self._end = time.monotonic()
self._done = True
    @classmethod
    def from_row(cls, db, row):
        from llm import get_model

        model = get_model(row["model"])
        response = cls(
            model=model,
            prompt=Prompt(
                prompt=row["prompt"],
                model=model,
                attachments=[],
                system=row["system"],
                options=model.Options(**json.loads(row["options_json"])),
            ),
            stream=False,
        )
        response.id = row["id"]
        response._prompt_json = json.loads(row["prompt_json"] or "null")
        response.response_json = json.loads(row["response_json"] or "null")
        response._done = True
        response._chunks = [row["response"]]
        # Attachments
        response.attachments = [
            Attachment.from_row(arow)
            for arow in db.query(
                """
                select attachments.* from attachments
                join prompt_attachments on attachments.id = prompt_attachments.attachment_id
                where prompt_attachments.response_id = ?
                order by prompt_attachments."order"
                """,
                [row["id"]],
            )
        ]
        return response

    def _force(self):
        if not self._done:
            list(self)

    def __str__(self) -> str:
        return self.text()

    def text(self) -> str:
        self._force()
        return "".join(self._chunks)

    def json(self) -> Optional[Dict[str, Any]]:
        self._force()
        return self.response_json

    def duration_ms(self) -> int:
        self._force()
        return int((self._end - self._start) * 1000)

    def datetime_utc(self) -> str:
        self._force()
        return self._start_utcnow.isoformat()
def log_to_db(self, db):
conversation = self.conversation
@@ -257,14 +300,126 @@
},
)
class Response(_BaseResponse):
model: "Model"
conversation: Optional["Conversation"] = None
def __str__(self) -> str:
return self.text()
def _force(self):
if not self._done:
list(self)
def text(self) -> str:
self._force()
return "".join(self._chunks)
def json(self) -> Optional[Dict[str, Any]]:
self._force()
return self.response_json
def duration_ms(self) -> int:
self._force()
return int(((self._end or 0) - (self._start or 0)) * 1000)
def datetime_utc(self) -> str:
self._force()
return self._start_utcnow.isoformat() if self._start_utcnow else ""
def __iter__(self) -> Iterator[str]:
self._start = time.monotonic()
self._start_utcnow = datetime.datetime.utcnow()
if self._done:
yield from self._chunks
return
for chunk in self.model.execute(
self.prompt,
stream=self.stream,
response=self,
conversation=self.conversation,
):
yield chunk
self._chunks.append(chunk)
if self.conversation:
self.conversation.responses.append(self)
self._end = time.monotonic()
self._done = True
class AsyncResponse(_BaseResponse):
model: "AsyncModel"
conversation: Optional["AsyncConversation"] = None
def __aiter__(self):
self._start = time.monotonic()
self._start_utcnow = datetime.datetime.utcnow()
return self
async def __anext__(self) -> str:
if self._done:
if not self._chunks:
raise StopAsyncIteration
chunk = self._chunks.pop(0)
if not self._chunks:
raise StopAsyncIteration
return chunk
if not hasattr(self, "_generator"):
self._generator = self.model.execute(
self.prompt,
stream=self.stream,
response=self,
conversation=self.conversation,
)
try:
chunk = await self._generator.__anext__()
self._chunks.append(chunk)
return chunk
except StopAsyncIteration:
if self.conversation:
self.conversation.responses.append(self)
self._end = time.monotonic()
self._done = True
raise
async def _force(self):
if not self._done:
async for _ in self:
pass
return self
async def text(self) -> str:
await self._force()
return "".join(self._chunks)
async def json(self) -> Optional[Dict[str, Any]]:
await self._force()
return self.response_json
async def duration_ms(self) -> int:
await self._force()
return int(((self._end or 0) - (self._start or 0)) * 1000)
async def datetime_utc(self) -> str:
await self._force()
return self._start_utcnow.isoformat() if self._start_utcnow else ""
def __await__(self):
return self._force().__await__()
@classmethod
def fake(
cls,
model: "Model",
model: "AsyncModel",
prompt: str,
*attachments: List[Attachment],
system: str,
response: str
response: str,
):
"Utility method to help with writing tests"
response_obj = cls(
@@ -281,47 +436,11 @@
response_obj._chunks = [response]
return response_obj
@classmethod
def from_row(cls, db, row):
from llm import get_model
model = get_model(row["model"])
response = cls(
model=model,
prompt=Prompt(
prompt=row["prompt"],
model=model,
attachments=[],
system=row["system"],
options=model.Options(**json.loads(row["options_json"])),
),
stream=False,
)
response.id = row["id"]
response._prompt_json = json.loads(row["prompt_json"] or "null")
response.response_json = json.loads(row["response_json"] or "null")
response._done = True
response._chunks = [row["response"]]
# Attachments
response.attachments = [
Attachment.from_row(arow)
for arow in db.query(
"""
select attachments.* from attachments
join prompt_attachments on attachments.id = prompt_attachments.attachment_id
where prompt_attachments.response_id = ?
order by prompt_attachments."order"
""",
[row["id"]],
)
]
return response
def __repr__(self):
return "<Response prompt='{}' text='{}'>".format(
self.prompt.prompt, self.text()
)
text = "... not yet awaited ..."
if self._done:
text = "".join(self._chunks)
return "<Response prompt='{}' text='{}'>".format(self.prompt.prompt, text)
class Options(BaseModel):
@@ -362,22 +481,39 @@ class _get_key_mixin:
raise NeedsKeyException(message)
class Model(ABC, _get_key_mixin):
class _BaseModel(ABC, _get_key_mixin):
model_id: str
# API key handling
key: Optional[str] = None
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
# Model characteristics
can_stream: bool = False
attachment_types: Set = set()
class Options(_Options):
pass
def conversation(self):
def _validate_attachments(
self, attachments: Optional[List[Attachment]] = None
) -> None:
if attachments and not self.attachment_types:
raise ValueError("This model does not support attachments")
for attachment in attachments or []:
attachment_type = attachment.resolve_type()
if attachment_type not in self.attachment_types:
raise ValueError(
f"This model does not support attachments of type '{attachment_type}', "
f"only {', '.join(self.attachment_types)}"
)
def __str__(self) -> str:
return "{}: {}".format(self.__class__.__name__, self.model_id)
def __repr__(self):
return "<{} '{}'>".format(self.__class__.__name__, self.model_id)
class Model(_BaseModel):
def conversation(self) -> Conversation:
return Conversation(model=self)
@abstractmethod
@@ -388,10 +524,6 @@ class Model(ABC, _get_key_mixin):
response: Response,
conversation: Optional[Conversation],
) -> Iterator[str]:
"""
Execute a prompt and yield chunks of text, or yield a single big chunk.
Any additional useful information about the execution should be assigned to the response.
"""
pass
def prompt(
@@ -401,22 +533,10 @@ class Model(ABC, _get_key_mixin):
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
stream: bool = True,
**options
):
# Validate attachments
if attachments and not self.attachment_types:
raise ValueError(
"This model does not support attachments, but some were provided"
)
for attachment in attachments or []:
attachment_type = attachment.resolve_type()
if attachment_type not in self.attachment_types:
raise ValueError(
"This model does not support attachments of type '{}', only {}".format(
attachment_type, ", ".join(self.attachment_types)
)
)
return self.response(
**options,
) -> Response:
self._validate_attachments(attachments)
return Response(
Prompt(
prompt,
attachments=attachments,
@@ -424,17 +544,46 @@
model=self,
options=self.Options(**options),
),
stream=stream,
self,
stream,
)
    def response(self, prompt: Prompt, stream: bool = True) -> Response:
        return Response(prompt, self, stream)

    def __str__(self) -> str:
        return "{}: {}".format(self.__class__.__name__, self.model_id)

    def __repr__(self):
        return "<Model '{}'>".format(self.model_id)


class AsyncModel(_BaseModel):
    def conversation(self) -> AsyncConversation:
        return AsyncConversation(model=self)
@abstractmethod
async def execute(
self,
prompt: Prompt,
stream: bool,
response: AsyncResponse,
conversation: Optional[AsyncConversation],
) -> AsyncGenerator[str, None]:
yield ""
def prompt(
self,
prompt: str,
*,
attachments: Optional[List[Attachment]] = None,
system: Optional[str] = None,
stream: bool = True,
**options,
) -> AsyncResponse:
self._validate_attachments(attachments)
return AsyncResponse(
Prompt(
prompt,
attachments=attachments,
system=system,
model=self,
options=self.Options(**options),
),
self,
stream,
)
class EmbeddingModel(ABC, _get_key_mixin):
@@ -495,6 +644,7 @@ class EmbeddingModel(ABC, _get_key_mixin):
@dataclass
class ModelWithAliases:
model: Model
async_model: AsyncModel
aliases: Set[str]
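
`AsyncResponse` supports both awaiting and async iteration. A sketch of the two consumption patterns against any `AsyncModel` instance:
```python
import llm


async def consume(model: llm.AsyncModel) -> None:
    # Awaiting the response forces it to completion, after which
    # accessors like .text() can be awaited:
    response = await model.prompt("hi")
    print(await response.text())

    # Or iterate chunk by chunk as execute() yields them:
    async for chunk in model.prompt("hi"):
        print(chunk, end="", flush=True)
```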

View file

@@ -1,4 +1,5 @@
[pytest]
filterwarnings =
ignore:The `schema` method is deprecated.*:DeprecationWarning
ignore:Support for class-based `config` is deprecated*:DeprecationWarning
ignore:Support for class-based `config` is deprecated*:DeprecationWarning
asyncio_default_fixture_loop_scope = function

View file

@@ -55,6 +55,7 @@ setup(
"pytest",
"numpy",
"pytest-httpx>=0.33.0",
"pytest-asyncio",
"cogapp",
"mypy>=1.10.0",
"black>=24.1.0",

View file

@@ -75,6 +75,29 @@ class MockModel(llm.Model):
break
class AsyncMockModel(llm.AsyncModel):
model_id = "mock"
def __init__(self):
self.history = []
self._queue = []
def enqueue(self, messages):
assert isinstance(messages, list)
self._queue.append(messages)
async def execute(self, prompt, stream, response, conversation):
self.history.append((prompt, stream, response, conversation))
while True:
try:
messages = self._queue.pop(0)
for message in messages:
yield message
break
except IndexError:
break
class EmbedDemo(llm.EmbeddingModel):
model_id = "embed-demo"
batch_size = 10
@@ -118,8 +141,13 @@ def mock_model():
return MockModel()
@pytest.fixture
def async_mock_model():
return AsyncMockModel()
@pytest.fixture(autouse=True)
def register_embed_demo_model(embed_demo, mock_model):
def register_embed_demo_model(embed_demo, mock_model, async_mock_model):
class MockModelsPlugin:
__name__ = "MockModelsPlugin"
@@ -131,7 +159,7 @@ def register_embed_demo_model(embed_demo, mock_model):
@llm.hookimpl
def register_models(self, register):
register(mock_model)
register(mock_model, async_model=async_mock_model)
pm.register(MockModelsPlugin(), name="undo-mock-models-plugin")
try:

tests/test_async.py (new file, 17 lines)
View file

@@ -0,0 +1,17 @@
import llm
import pytest
@pytest.mark.asyncio
async def test_async_model(async_mock_model):
gathered = []
async_mock_model.enqueue(["hello world"])
async for chunk in async_mock_model.prompt("hello"):
gathered.append(chunk)
assert gathered == ["hello world"]
# Not as an iterator
async_mock_model.enqueue(["hello world"])
response = await async_mock_model.prompt("hello")
text = await response.text()
assert text == "hello world"
assert isinstance(response, llm.AsyncResponse)

View file

@@ -80,7 +80,10 @@ def test_chat_basic(mock_model, logs_db):
# Now continue that conversation
mock_model.enqueue(["continued"])
result2 = runner.invoke(
llm.cli.cli, ["chat", "-m", "mock", "-c"], input="Continue\nquit\n"
llm.cli.cli,
["chat", "-m", "mock", "-c"],
input="Continue\nquit\n",
catch_exceptions=False,
)
assert result2.exit_code == 0
assert result2.output == (
@@ -176,7 +179,7 @@ def test_chat_options(mock_model, logs_db):
"response": "Some text",
"response_json": None,
"conversation_id": ANY,
"duration_ms": 0,
"duration_ms": ANY,
"datetime_utc": ANY,
}
]

View file

@@ -555,6 +555,14 @@ def test_llm_models_options(user_path):
result = runner.invoke(cli, ["models", "--options"], catch_exceptions=False)
assert result.exit_code == 0
assert EXPECTED_OPTIONS.strip() in result.output
assert "AsyncMockModel: mock" not in result.output
def test_llm_models_async(user_path):
runner = CliRunner()
result = runner.invoke(cli, ["models", "--async"], catch_exceptions=False)
assert result.exit_code == 0
assert "AsyncMockModel: mock" in result.output
def test_llm_user_dir(tmpdir, monkeypatch):