mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-17 05:00:25 +00:00
llm.get_async_model(), llm.AsyncModel base class and OpenAI async models (#613)
- https://github.com/simonw/llm/issues/507#issuecomment-2458639308 * register_model is now async aware Refs https://github.com/simonw/llm/issues/507#issuecomment-2458658134 * Refactor Chat and AsyncChat to use _Shared base class Refs https://github.com/simonw/llm/issues/507#issuecomment-2458692338 * fixed function name * Fix for infinite loop * Applied Black * Ran cog * Applied Black * Add Response.from_row() classmethod back again It does not matter that this is a blocking call, since it is a classmethod * Made mypy happy with llm/models.py * mypy fixes for openai_models.py I am unhappy with this, had to duplicate some code. * First test for AsyncModel * Still have not quite got this working * Fix for not loading plugins during tests, refs #626 * audio/wav not audio/wave, refs #603 * Black and mypy and ruff all happy * Refactor to avoid generics * Removed obsolete response() method * Support text = await async_mock_model.prompt("hello") * Initial docs for llm.get_async_model() and await model.prompt() Refs #507 * Initial async model plugin creation docs * duration_ms ANY to pass test * llm models --async option Refs https://github.com/simonw/llm/pull/613#issuecomment-2474724406 * Removed obsolete TypeVars * Expanded register_models() docs for async * await model.prompt() now returns AsyncResponse Refs https://github.com/simonw/llm/pull/613#issuecomment-2475157822 --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
5a984d0c87
commit
ba75c674cb
14 changed files with 688 additions and 219 deletions
|
|
@ -121,6 +121,7 @@ Options:
|
|||
--cid, --conversation TEXT Continue the conversation with the given ID.
|
||||
--key TEXT API key to use
|
||||
--save TEXT Save prompt with this template name
|
||||
--async Run prompt asynchronously
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
|
|
@ -322,6 +323,7 @@ Usage: llm models list [OPTIONS]
|
|||
|
||||
Options:
|
||||
--options Show options for each model, if available
|
||||
--async List async models
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -5,13 +5,64 @@ The {ref}`model plugin tutorial <tutorial-model-plugin>` covers the basics of de
|
|||
|
||||
This document covers more advanced topics.
|
||||
|
||||
(advanced-model-plugins-async)=
|
||||
|
||||
## Async models
|
||||
|
||||
Plugins can optionally provide an asynchronous version of their model, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html). This is particularly useful for remote models accessible by an HTTP API.
|
||||
|
||||
The async version of a model subclasses `llm.AsyncModel` instead of `llm.Model`. It must implement an `async def execute()` async generator method instead of `def execute()`.
|
||||
|
||||
This example shows a subset of the OpenAI default plugin illustrating how this method might work:
|
||||
|
||||
|
||||
```python
|
||||
from typing import AsyncGenerator
|
||||
import llm
|
||||
|
||||
class MyAsyncModel(llm.AsyncModel):
    # This can duplicate the model_id of the sync model:
    model_id = "my-model-id"

    async def execute(
        self, prompt, stream, response, conversation=None
    ) -> AsyncGenerator[str, None]:
        """Yield the response text, chunk by chunk when streaming."""
        if stream:
            completion = await client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                stream=True,
            )
            async for chunk in completion:
                content = chunk.choices[0].delta.content
                # A streamed delta may carry no text; skip it instead of
                # yielding None to the caller
                if content is not None:
                    yield content
        else:
            completion = await client.chat.completions.create(
                # Same identifier as the streaming branch - this example
                # class defines model_id only
                model=self.model_id,
                messages=messages,
                stream=False,
            )
            yield completion.choices[0].message.content
|
||||
```
|
||||
This async model instance should then be passed to the `register()` method in the `register_models()` plugin hook:
|
||||
|
||||
```python
|
||||
@hookimpl
|
||||
def register_models(register):
|
||||
register(
|
||||
MyModel(), MyAsyncModel(), aliases=("my-model-aliases",)
|
||||
)
|
||||
```
|
||||
|
||||
(advanced-model-plugins-attachments)=
|
||||
|
||||
## Attachments for multi-modal models
|
||||
|
||||
Models such as GPT-4o, Claude 3.5 Sonnet and Google's Gemini 1.5 are multi-modal: they accept input in the form of images and maybe even audio, video and other formats.
|
||||
|
||||
LLM calls these **attachments**. Models can specify the types of attachments they accept and then implement special code in the `.execute()` method to handle them.
|
||||
|
||||
See {ref}`the Python attachments documentation <python-api-attachments>` for details on using attachments in the Python API.
|
||||
|
||||
### Specifying attachment types
|
||||
|
||||
A `Model` subclass can list the types of attachments it accepts by defining an `attachment_types` class attribute:
|
||||
|
|
|
|||
|
|
@ -42,5 +42,20 @@ class HelloWorld(llm.Model):
|
|||
def execute(self, prompt, stream, response):
|
||||
return ["hello world"]
|
||||
```
|
||||
If your model includes an async version, you can register that too:
|
||||
|
||||
```python
|
||||
class AsyncHelloWorld(llm.AsyncModel):
    model_id = "helloworld"

    async def execute(self, prompt, stream, response):
        # execute() has to be an async generator, so the text is yielded
        # rather than returned as a list
        yield "hello world"
|
||||
|
||||
@llm.hookimpl
|
||||
def register_models(register):
|
||||
register(HelloWorld(), AsyncHelloWorld(), aliases=("hw",))
|
||||
```
|
||||
This demonstrates how to register a model with both sync and async versions, and how to specify an alias for that model.
|
||||
|
||||
The {ref}`model plugin tutorial <tutorial-model-plugin>` describes how to use this hook in detail. Asynchronous models {ref}`are described here <advanced-model-plugins-async>`.
|
||||
|
||||
{ref}`tutorial-model-plugin` describes how to use this hook in detail.
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ print(response.text())
|
|||
```
|
||||
Some models do not use API keys at all.
|
||||
|
||||
## Streaming responses
|
||||
### Streaming responses
|
||||
|
||||
For models that support it you can stream responses as they are generated, like this:
|
||||
|
||||
|
|
@ -112,6 +112,34 @@ The `response.text()` method described earlier does this for you - it runs throu
|
|||
|
||||
If a response has been evaluated, `response.text()` will continue to return the same string.
|
||||
|
||||
(python-api-async)=
|
||||
|
||||
## Async models
|
||||
|
||||
Some plugins provide async versions of their supported models, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html).
|
||||
|
||||
To use an async model, use the `llm.get_async_model()` function instead of `llm.get_model()`:
|
||||
|
||||
```python
|
||||
import llm
|
||||
model = llm.get_async_model("gpt-4o")
|
||||
```
|
||||
You can then run a prompt using `await model.prompt(...)`:
|
||||
|
||||
```python
|
||||
response = await model.prompt(
|
||||
"Five surprising names for a pet pelican"
|
||||
)
|
||||
print(await response.text())
|
||||
```
|
||||
Or use `async for chunk in ...` to stream the response as it is generated:
|
||||
```python
|
||||
async for chunk in model.prompt(
|
||||
"Five surprising names for a pet pelican"
|
||||
):
|
||||
print(chunk, end="", flush=True)
|
||||
```
|
||||
|
||||
## Conversations
|
||||
|
||||
LLM supports *conversations*, where you ask follow-up questions of a model as part of an ongoing conversation.
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from .errors import (
|
|||
NeedsKeyException,
|
||||
)
|
||||
from .models import (
|
||||
AsyncModel,
|
||||
AsyncResponse,
|
||||
Attachment,
|
||||
Conversation,
|
||||
Model,
|
||||
|
|
@ -26,9 +28,11 @@ import struct
|
|||
|
||||
__all__ = [
|
||||
"hookimpl",
|
||||
"get_async_model",
|
||||
"get_model",
|
||||
"get_key",
|
||||
"user_dir",
|
||||
"AsyncResponse",
|
||||
"Attachment",
|
||||
"Collection",
|
||||
"Conversation",
|
||||
|
|
@ -74,11 +78,11 @@ def get_models_with_aliases() -> List["ModelWithAliases"]:
|
|||
for alias, model_id in configured_aliases.items():
|
||||
extra_model_aliases.setdefault(model_id, []).append(alias)
|
||||
|
||||
def register(model, aliases=None):
|
||||
def register(model, async_model=None, aliases=None):
|
||||
alias_list = list(aliases or [])
|
||||
if model.model_id in extra_model_aliases:
|
||||
alias_list.extend(extra_model_aliases[model.model_id])
|
||||
model_aliases.append(ModelWithAliases(model, alias_list))
|
||||
model_aliases.append(ModelWithAliases(model, async_model, alias_list))
|
||||
|
||||
load_plugins()
|
||||
pm.hook.register_models(register=register)
|
||||
|
|
@ -137,12 +141,25 @@ def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:
|
|||
return model_aliases
|
||||
|
||||
|
||||
def get_async_model_aliases() -> Dict[str, AsyncModel]:
    """Return a lookup from every alias and model ID to its async model.

    Registrations that have no async model contribute nothing.
    """
    lookup = {}
    for entry in get_models_with_aliases():
        if not entry.async_model:
            continue
        for alias in entry.aliases:
            lookup[alias] = entry.async_model
        # The canonical key is the model_id of the paired sync model
        lookup[entry.model.model_id] = entry.async_model
    return lookup
|
||||
|
||||
|
||||
def get_model_aliases() -> Dict[str, Model]:
|
||||
model_aliases = {}
|
||||
for model_with_aliases in get_models_with_aliases():
|
||||
for alias in model_with_aliases.aliases:
|
||||
model_aliases[alias] = model_with_aliases.model
|
||||
model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
|
||||
if model_with_aliases.model:
|
||||
for alias in model_with_aliases.aliases:
|
||||
model_aliases[alias] = model_with_aliases.model
|
||||
model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
|
||||
return model_aliases
|
||||
|
||||
|
||||
|
|
@ -150,13 +167,42 @@ class UnknownModelError(KeyError):
|
|||
pass
|
||||
|
||||
|
||||
def get_model(name: Optional[str] = None) -> Model:
|
||||
def get_async_model(name: Optional[str] = None) -> AsyncModel:
    """Look up an async model by ID or alias.

    Falls back to the default model name when ``name`` is empty.

    Raises:
        UnknownModelError: if no async model is registered under that name.
            The message says so when a sync model of that name does exist.
    """
    registry = get_async_model_aliases()
    name = name or get_default_model()
    try:
        return registry[name]
    except KeyError:
        # Probe the sync registry purely to pick the more helpful message;
        # _skip_async stops get_model() from calling back into this function
        try:
            sync_counterpart = get_model(name, _skip_async=True)
        except UnknownModelError:
            sync_counterpart = None
        prefix = (
            "Unknown async model (sync model exists): "
            if sync_counterpart
            else "Unknown model: "
        )
        raise UnknownModelError(prefix + name)
|
||||
|
||||
|
||||
def get_model(name: Optional[str] = None, _skip_async: bool = False) -> Model:
|
||||
aliases = get_model_aliases()
|
||||
name = name or get_default_model()
|
||||
try:
|
||||
return aliases[name]
|
||||
except KeyError:
|
||||
raise UnknownModelError("Unknown model: " + name)
|
||||
# Does an async model exist?
|
||||
if _skip_async:
|
||||
raise UnknownModelError("Unknown model: " + name)
|
||||
async_model = None
|
||||
try:
|
||||
async_model = get_async_model(name)
|
||||
except UnknownModelError:
|
||||
pass
|
||||
if async_model:
|
||||
raise UnknownModelError("Unknown model (async model exists): " + name)
|
||||
else:
|
||||
raise UnknownModelError("Unknown model: " + name)
|
||||
|
||||
|
||||
def get_key(
|
||||
|
|
|
|||
69
llm/cli.py
69
llm/cli.py
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import click
|
||||
from click_default_group import DefaultGroup
|
||||
from dataclasses import asdict
|
||||
|
|
@ -11,6 +12,7 @@ from llm import (
|
|||
Template,
|
||||
UnknownModelError,
|
||||
encode,
|
||||
get_async_model,
|
||||
get_default_model,
|
||||
get_default_embedding_model,
|
||||
get_embedding_models_with_aliases,
|
||||
|
|
@ -199,6 +201,7 @@ def cli():
|
|||
)
|
||||
@click.option("--key", help="API key to use")
|
||||
@click.option("--save", help="Save prompt with this template name")
|
||||
@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously")
|
||||
def prompt(
|
||||
prompt,
|
||||
system,
|
||||
|
|
@ -215,6 +218,7 @@ def prompt(
|
|||
conversation_id,
|
||||
key,
|
||||
save,
|
||||
async_,
|
||||
):
|
||||
"""
|
||||
Execute a prompt
|
||||
|
|
@ -337,9 +341,12 @@ def prompt(
|
|||
|
||||
# Now resolve the model
|
||||
try:
|
||||
model = model_aliases[model_id]
|
||||
except KeyError:
|
||||
raise click.ClickException("'{}' is not a known model".format(model_id))
|
||||
if async_:
|
||||
model = get_async_model(model_id)
|
||||
else:
|
||||
model = get_model(model_id)
|
||||
except UnknownModelError as ex:
|
||||
raise click.ClickException(ex)
|
||||
|
||||
# Provide the API key, if one is needed and has been provided
|
||||
if model.needs_key:
|
||||
|
|
@ -375,21 +382,48 @@ def prompt(
|
|||
prompt_method = conversation.prompt
|
||||
|
||||
try:
|
||||
response = prompt_method(
|
||||
prompt, attachments=resolved_attachments, system=system, **validated_options
|
||||
)
|
||||
if should_stream:
|
||||
for chunk in response:
|
||||
print(chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print("")
|
||||
if async_:
|
||||
|
||||
async def inner():
|
||||
if should_stream:
|
||||
async for chunk in prompt_method(
|
||||
prompt,
|
||||
attachments=resolved_attachments,
|
||||
system=system,
|
||||
**validated_options,
|
||||
):
|
||||
print(chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print("")
|
||||
else:
|
||||
response = prompt_method(
|
||||
prompt,
|
||||
attachments=resolved_attachments,
|
||||
system=system,
|
||||
**validated_options,
|
||||
)
|
||||
print(await response.text())
|
||||
|
||||
asyncio.run(inner())
|
||||
else:
|
||||
print(response.text())
|
||||
response = prompt_method(
|
||||
prompt,
|
||||
attachments=resolved_attachments,
|
||||
system=system,
|
||||
**validated_options,
|
||||
)
|
||||
if should_stream:
|
||||
for chunk in response:
|
||||
print(chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print("")
|
||||
else:
|
||||
print(response.text())
|
||||
except Exception as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
# Log to the database
|
||||
if (logs_on() or log) and not no_log:
|
||||
if (logs_on() or log) and not no_log and not async_:
|
||||
log_path = logs_db_path()
|
||||
(log_path.parent).mkdir(parents=True, exist_ok=True)
|
||||
db = sqlite_utils.Database(log_path)
|
||||
|
|
@ -981,14 +1015,19 @@ _type_lookup = {
|
|||
@click.option(
|
||||
"--options", is_flag=True, help="Show options for each model, if available"
|
||||
)
|
||||
def models_list(options):
|
||||
@click.option("async_", "--async", is_flag=True, help="List async models")
|
||||
def models_list(options, async_):
|
||||
"List available models"
|
||||
models_that_have_shown_options = set()
|
||||
for model_with_aliases in get_models_with_aliases():
|
||||
if async_ and not model_with_aliases.async_model:
|
||||
continue
|
||||
extra = ""
|
||||
if model_with_aliases.aliases:
|
||||
extra = " (aliases: {})".format(", ".join(model_with_aliases.aliases))
|
||||
model = model_with_aliases.model
|
||||
model = (
|
||||
model_with_aliases.model if not async_ else model_with_aliases.async_model
|
||||
)
|
||||
output = str(model) + extra
|
||||
if options and model.Options.schema()["properties"]:
|
||||
output += "\n Options:"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from llm import EmbeddingModel, Model, hookimpl
|
||||
from llm import AsyncModel, EmbeddingModel, Model, hookimpl
|
||||
import llm
|
||||
from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
|
||||
import click
|
||||
|
|
@ -16,7 +16,7 @@ except ImportError:
|
|||
from pydantic.fields import Field
|
||||
from pydantic.class_validators import validator as field_validator # type: ignore [no-redef]
|
||||
|
||||
from typing import List, Iterable, Iterator, Optional, Union
|
||||
from typing import AsyncGenerator, List, Iterable, Iterator, Optional, Union
|
||||
import json
|
||||
import yaml
|
||||
|
||||
|
|
@ -24,22 +24,47 @@ import yaml
|
|||
@hookimpl
|
||||
def register_models(register):
|
||||
# GPT-4o
|
||||
register(Chat("gpt-4o", vision=True), aliases=("4o",))
|
||||
register(Chat("gpt-4o-mini", vision=True), aliases=("4o-mini",))
|
||||
register(Chat("gpt-4o-audio-preview", audio=True))
|
||||
register(
|
||||
Chat("gpt-4o", vision=True), AsyncChat("gpt-4o", vision=True), aliases=("4o",)
|
||||
)
|
||||
register(
|
||||
Chat("gpt-4o-mini", vision=True),
|
||||
AsyncChat("gpt-4o-mini", vision=True),
|
||||
aliases=("4o-mini",),
|
||||
)
|
||||
register(
|
||||
Chat("gpt-4o-audio-preview", audio=True),
|
||||
AsyncChat("gpt-4o-audio-preview", audio=True),
|
||||
)
|
||||
# 3.5 and 4
|
||||
register(Chat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt"))
|
||||
register(Chat("gpt-3.5-turbo-16k"), aliases=("chatgpt-16k", "3.5-16k"))
|
||||
register(Chat("gpt-4"), aliases=("4", "gpt4"))
|
||||
register(Chat("gpt-4-32k"), aliases=("4-32k",))
|
||||
register(
|
||||
Chat("gpt-3.5-turbo"), AsyncChat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt")
|
||||
)
|
||||
register(
|
||||
Chat("gpt-3.5-turbo-16k"),
|
||||
AsyncChat("gpt-3.5-turbo-16k"),
|
||||
aliases=("chatgpt-16k", "3.5-16k"),
|
||||
)
|
||||
register(Chat("gpt-4"), AsyncChat("gpt-4"), aliases=("4", "gpt4"))
|
||||
register(Chat("gpt-4-32k"), AsyncChat("gpt-4-32k"), aliases=("4-32k",))
|
||||
# GPT-4 Turbo models
|
||||
register(Chat("gpt-4-1106-preview"))
|
||||
register(Chat("gpt-4-0125-preview"))
|
||||
register(Chat("gpt-4-turbo-2024-04-09"))
|
||||
register(Chat("gpt-4-turbo"), aliases=("gpt-4-turbo-preview", "4-turbo", "4t"))
|
||||
register(Chat("gpt-4-1106-preview"), AsyncChat("gpt-4-1106-preview"))
|
||||
register(Chat("gpt-4-0125-preview"), AsyncChat("gpt-4-0125-preview"))
|
||||
register(Chat("gpt-4-turbo-2024-04-09"), AsyncChat("gpt-4-turbo-2024-04-09"))
|
||||
register(
|
||||
Chat("gpt-4-turbo"),
|
||||
AsyncChat("gpt-4-turbo"),
|
||||
aliases=("gpt-4-turbo-preview", "4-turbo", "4t"),
|
||||
)
|
||||
# o1
|
||||
register(Chat("o1-preview", can_stream=False, allows_system_prompt=False))
|
||||
register(Chat("o1-mini", can_stream=False, allows_system_prompt=False))
|
||||
register(
|
||||
Chat("o1-preview", can_stream=False, allows_system_prompt=False),
|
||||
AsyncChat("o1-preview", can_stream=False, allows_system_prompt=False),
|
||||
)
|
||||
register(
|
||||
Chat("o1-mini", can_stream=False, allows_system_prompt=False),
|
||||
AsyncChat("o1-mini", can_stream=False, allows_system_prompt=False),
|
||||
)
|
||||
# The -instruct completion model
|
||||
register(
|
||||
Completion("gpt-3.5-turbo-instruct", default_max_tokens=256),
|
||||
|
|
@ -273,18 +298,7 @@ def _attachment(attachment):
|
|||
}
|
||||
|
||||
|
||||
class Chat(Model):
|
||||
needs_key = "openai"
|
||||
key_env_var = "OPENAI_API_KEY"
|
||||
|
||||
default_max_tokens = None
|
||||
|
||||
class Options(SharedOptions):
|
||||
json_object: Optional[bool] = Field(
|
||||
description="Output a valid JSON object {...}. Prompt must mention JSON.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
class _Shared:
|
||||
def __init__(
|
||||
self,
|
||||
model_id,
|
||||
|
|
@ -335,10 +349,8 @@ class Chat(Model):
|
|||
def __str__(self):
|
||||
return "OpenAI Chat: {}".format(self.model_id)
|
||||
|
||||
def execute(self, prompt, stream, response, conversation=None):
|
||||
def build_messages(self, prompt, conversation):
|
||||
messages = []
|
||||
if prompt.system and not self.allows_system_prompt:
|
||||
raise NotImplementedError("Model does not support system prompts")
|
||||
current_system = None
|
||||
if conversation is not None:
|
||||
for prev_response in conversation.responses:
|
||||
|
|
@ -375,7 +387,60 @@ class Chat(Model):
|
|||
for attachment in prompt.attachments:
|
||||
attachment_message.append(_attachment(attachment))
|
||||
messages.append({"role": "user", "content": attachment_message})
|
||||
return messages
|
||||
|
||||
def get_client(self, async_=False):
    """Build an OpenAI client configured from this model's settings.

    Returns ``openai.AsyncOpenAI`` when ``async_`` is true, otherwise
    ``openai.OpenAI``.
    """
    client_kwargs = {}
    # (attribute on self, keyword argument for the client) - only set values
    # are forwarded
    forwarded = (
        ("api_base", "base_url"),
        ("api_type", "api_type"),
        ("api_version", "api_version"),
        ("api_engine", "engine"),
    )
    for attribute, keyword in forwarded:
        value = getattr(self, attribute)
        if value:
            client_kwargs[keyword] = value

    # OpenAI-compatible models don't need a key, but the
    # openai client library requires one
    client_kwargs["api_key"] = self.get_key() if self.needs_key else "DUMMY_KEY"

    if self.headers:
        client_kwargs["default_headers"] = self.headers
    if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
        # NOTE(review): the same logging_client() is handed to both client
        # classes - confirm it is accepted by AsyncOpenAI
        client_kwargs["http_client"] = logging_client()

    client_class = openai.AsyncOpenAI if async_ else openai.OpenAI
    return client_class(**client_kwargs)
|
||||
|
||||
def build_kwargs(self, prompt, stream):
    """Turn the prompt's options into keyword arguments for the API call.

    ``json_object`` is removed from the options and, when truthy, replaced
    by a ``response_format`` entry. ``max_tokens`` falls back to
    ``self.default_max_tokens``, and streaming requests ask for usage data.
    """
    request_options = dict(not_nulls(prompt.options))
    wants_json = request_options.pop("json_object", None)

    if self.default_max_tokens is not None:
        # Only fills the gap; an explicit max_tokens option is left alone
        request_options.setdefault("max_tokens", self.default_max_tokens)
    if wants_json:
        request_options["response_format"] = {"type": "json_object"}
    if stream:
        request_options["stream_options"] = {"include_usage": True}
    return request_options
|
||||
|
||||
|
||||
class Chat(_Shared, Model):
|
||||
needs_key = "openai"
|
||||
key_env_var = "OPENAI_API_KEY"
|
||||
default_max_tokens = None
|
||||
|
||||
class Options(SharedOptions):
|
||||
json_object: Optional[bool] = Field(
|
||||
description="Output a valid JSON object {...}. Prompt must mention JSON.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def execute(self, prompt, stream, response, conversation=None):
|
||||
if prompt.system and not self.allows_system_prompt:
|
||||
raise NotImplementedError("Model does not support system prompts")
|
||||
messages = self.build_messages(prompt, conversation)
|
||||
kwargs = self.build_kwargs(prompt, stream)
|
||||
client = self.get_client()
|
||||
if stream:
|
||||
|
|
@ -406,38 +471,53 @@ class Chat(Model):
|
|||
yield completion.choices[0].message.content
|
||||
response._prompt_json = redact_data({"messages": messages})
|
||||
|
||||
def get_client(self):
|
||||
kwargs = {}
|
||||
if self.api_base:
|
||||
kwargs["base_url"] = self.api_base
|
||||
if self.api_type:
|
||||
kwargs["api_type"] = self.api_type
|
||||
if self.api_version:
|
||||
kwargs["api_version"] = self.api_version
|
||||
if self.api_engine:
|
||||
kwargs["engine"] = self.api_engine
|
||||
if self.needs_key:
|
||||
kwargs["api_key"] = self.get_key()
|
||||
else:
|
||||
# OpenAI-compatible models don't need a key, but the
|
||||
# openai client library requires one
|
||||
kwargs["api_key"] = "DUMMY_KEY"
|
||||
if self.headers:
|
||||
kwargs["default_headers"] = self.headers
|
||||
if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"):
|
||||
kwargs["http_client"] = logging_client()
|
||||
return openai.OpenAI(**kwargs)
|
||||
|
||||
def build_kwargs(self, prompt, stream):
|
||||
kwargs = dict(not_nulls(prompt.options))
|
||||
json_object = kwargs.pop("json_object", None)
|
||||
if "max_tokens" not in kwargs and self.default_max_tokens is not None:
|
||||
kwargs["max_tokens"] = self.default_max_tokens
|
||||
if json_object:
|
||||
kwargs["response_format"] = {"type": "json_object"}
|
||||
class AsyncChat(_Shared, AsyncModel):
|
||||
needs_key = "openai"
|
||||
key_env_var = "OPENAI_API_KEY"
|
||||
default_max_tokens = None
|
||||
|
||||
class Options(SharedOptions):
|
||||
json_object: Optional[bool] = Field(
|
||||
description="Output a valid JSON object {...}. Prompt must mention JSON.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self, prompt, stream, response, conversation=None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if prompt.system and not self.allows_system_prompt:
|
||||
raise NotImplementedError("Model does not support system prompts")
|
||||
messages = self.build_messages(prompt, conversation)
|
||||
kwargs = self.build_kwargs(prompt, stream)
|
||||
client = self.get_client(async_=True)
|
||||
if stream:
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
return kwargs
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name or self.model_id,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
**kwargs,
|
||||
)
|
||||
chunks = []
|
||||
async for chunk in completion:
|
||||
chunks.append(chunk)
|
||||
try:
|
||||
content = chunk.choices[0].delta.content
|
||||
except IndexError:
|
||||
content = None
|
||||
if content is not None:
|
||||
yield content
|
||||
response.response_json = remove_dict_none_values(combine_chunks(chunks))
|
||||
else:
|
||||
completion = await client.chat.completions.create(
|
||||
model=self.model_name or self.model_id,
|
||||
messages=messages,
|
||||
stream=False,
|
||||
**kwargs,
|
||||
)
|
||||
response.response_json = remove_dict_none_values(completion.model_dump())
|
||||
yield completion.choices[0].message.content
|
||||
response._prompt_json = redact_data({"messages": messages})
|
||||
|
||||
|
||||
class Completion(Chat):
|
||||
|
|
|
|||
410
llm/models.py
410
llm/models.py
|
|
@ -7,7 +7,17 @@ import httpx
|
|||
from itertools import islice
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Union,
|
||||
)
|
||||
from .utils import mimetype_from_path, mimetype_from_string
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
|
|
@ -94,7 +104,7 @@ class Prompt:
|
|||
attachments=None,
|
||||
system=None,
|
||||
prompt_json=None,
|
||||
options=None
|
||||
options=None,
|
||||
):
|
||||
self.prompt = prompt
|
||||
self.model = model
|
||||
|
|
@ -105,12 +115,25 @@ class Prompt:
|
|||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
model: "Model"
|
||||
class _BaseConversation:
|
||||
model: "_BaseModel"
|
||||
id: str = field(default_factory=lambda: str(ULID()).lower())
|
||||
name: Optional[str] = None
|
||||
responses: List["Response"] = field(default_factory=list)
|
||||
responses: List["_BaseResponse"] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row):
|
||||
from llm import get_model
|
||||
|
||||
return cls(
|
||||
model=get_model(row["model"]),
|
||||
id=row["id"],
|
||||
name=row["name"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation(_BaseConversation):
|
||||
def prompt(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
|
|
@ -118,8 +141,8 @@ class Conversation:
|
|||
attachments: Optional[List[Attachment]] = None,
|
||||
system: Optional[str] = None,
|
||||
stream: bool = True,
|
||||
**options
|
||||
):
|
||||
**options,
|
||||
) -> "Response":
|
||||
return Response(
|
||||
Prompt(
|
||||
prompt,
|
||||
|
|
@ -133,24 +156,45 @@ class Conversation:
|
|||
conversation=self,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row):
|
||||
from llm import get_model
|
||||
|
||||
return cls(
|
||||
model=get_model(row["model"]),
|
||||
id=row["id"],
|
||||
name=row["name"],
|
||||
@dataclass
class AsyncConversation(_BaseConversation):
    """Conversation whose prompts produce AsyncResponse objects."""

    def prompt(
        self,
        prompt: Optional[str],
        *,
        attachments: Optional[List[Attachment]] = None,
        system: Optional[str] = None,
        stream: bool = True,
        **options,
    ) -> "AsyncResponse":
        """Create an AsyncResponse for ``prompt`` tied to this conversation.

        ``options`` are validated by the model's ``Options`` class. This
        method only constructs the response object.
        """
        request = Prompt(
            prompt,
            model=self.model,
            attachments=attachments,
            system=system,
            options=self.model.Options(**options),
        )
        return AsyncResponse(request, self.model, stream, conversation=self)
|
||||
|
||||
|
||||
class Response(ABC):
|
||||
class _BaseResponse:
|
||||
"""Base response class shared between sync and async responses"""
|
||||
|
||||
prompt: "Prompt"
|
||||
stream: bool
|
||||
conversation: Optional["_BaseConversation"] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
model: "Model",
|
||||
model: "_BaseModel",
|
||||
stream: bool,
|
||||
conversation: Optional[Conversation] = None,
|
||||
conversation: Optional[_BaseConversation] = None,
|
||||
):
|
||||
self.prompt = prompt
|
||||
self._prompt_json = None
|
||||
|
|
@ -161,47 +205,46 @@ class Response(ABC):
|
|||
self.response_json = None
|
||||
self.conversation = conversation
|
||||
self.attachments: List[Attachment] = []
|
||||
self._start: Optional[float] = None
|
||||
self._end: Optional[float] = None
|
||||
self._start_utcnow: Optional[datetime.datetime] = None
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
self._start = time.monotonic()
|
||||
self._start_utcnow = datetime.datetime.utcnow()
|
||||
if self._done:
|
||||
yield from self._chunks
|
||||
for chunk in self.model.execute(
|
||||
self.prompt,
|
||||
stream=self.stream,
|
||||
response=self,
|
||||
conversation=self.conversation,
|
||||
):
|
||||
yield chunk
|
||||
self._chunks.append(chunk)
|
||||
if self.conversation:
|
||||
self.conversation.responses.append(self)
|
||||
self._end = time.monotonic()
|
||||
self._done = True
|
||||
@classmethod
|
||||
def from_row(cls, db, row):
|
||||
from llm import get_model
|
||||
|
||||
def _force(self):
|
||||
if not self._done:
|
||||
list(self)
|
||||
model = get_model(row["model"])
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.text()
|
||||
|
||||
def text(self) -> str:
|
||||
self._force()
|
||||
return "".join(self._chunks)
|
||||
|
||||
def json(self) -> Optional[Dict[str, Any]]:
|
||||
self._force()
|
||||
return self.response_json
|
||||
|
||||
def duration_ms(self) -> int:
|
||||
self._force()
|
||||
return int((self._end - self._start) * 1000)
|
||||
|
||||
def datetime_utc(self) -> str:
|
||||
self._force()
|
||||
return self._start_utcnow.isoformat()
|
||||
response = cls(
|
||||
model=model,
|
||||
prompt=Prompt(
|
||||
prompt=row["prompt"],
|
||||
model=model,
|
||||
attachments=[],
|
||||
system=row["system"],
|
||||
options=model.Options(**json.loads(row["options_json"])),
|
||||
),
|
||||
stream=False,
|
||||
)
|
||||
response.id = row["id"]
|
||||
response._prompt_json = json.loads(row["prompt_json"] or "null")
|
||||
response.response_json = json.loads(row["response_json"] or "null")
|
||||
response._done = True
|
||||
response._chunks = [row["response"]]
|
||||
# Attachments
|
||||
response.attachments = [
|
||||
Attachment.from_row(arow)
|
||||
for arow in db.query(
|
||||
"""
|
||||
select attachments.* from attachments
|
||||
join prompt_attachments on attachments.id = prompt_attachments.attachment_id
|
||||
where prompt_attachments.response_id = ?
|
||||
order by prompt_attachments."order"
|
||||
""",
|
||||
[row["id"]],
|
||||
)
|
||||
]
|
||||
return response
|
||||
|
||||
def log_to_db(self, db):
|
||||
conversation = self.conversation
|
||||
|
|
@ -257,14 +300,126 @@ class Response(ABC):
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class Response(_BaseResponse):
    """Synchronous response: iterating it runs the model and records chunks."""

    model: "Model"
    conversation: Optional["Conversation"] = None

    def __str__(self) -> str:
        return self.text()

    def _force(self):
        """Run the response to completion if that has not happened yet."""
        if not self._done:
            list(self)

    def text(self) -> str:
        """Return the full response text, executing the prompt if needed."""
        self._force()
        return "".join(self._chunks)

    def json(self) -> Optional[Dict[str, Any]]:
        """Return ``response_json``, executing the prompt if needed."""
        self._force()
        return self.response_json

    def duration_ms(self) -> int:
        """Return end minus start in milliseconds (missing values count as 0)."""
        self._force()
        return int(((self._end or 0) - (self._start or 0)) * 1000)

    def datetime_utc(self) -> str:
        """Return the recorded start time as ISO text, or "" if never started."""
        self._force()
        return self._start_utcnow.isoformat() if self._start_utcnow else ""

    def __iter__(self) -> Iterator[str]:
        """Yield response chunks, executing the model on first iteration.

        A finished response just replays its stored chunks.
        """
        if self._done:
            # Replay only: the timestamps recorded for the real execution
            # must survive, otherwise duration_ms() and datetime_utc() would
            # describe the replay instead of the model call
            yield from self._chunks
            return

        self._start = time.monotonic()
        self._start_utcnow = datetime.datetime.utcnow()

        for chunk in self.model.execute(
            self.prompt,
            stream=self.stream,
            response=self,
            conversation=self.conversation,
        ):
            yield chunk
            self._chunks.append(chunk)

        if self.conversation:
            self.conversation.responses.append(self)
        self._end = time.monotonic()
        self._done = True
|
||||
|
||||
|
||||
class AsyncResponse(_BaseResponse):
|
||||
model: "AsyncModel"
|
||||
conversation: Optional["AsyncConversation"] = None
|
||||
|
||||
def __aiter__(self):
|
||||
self._start = time.monotonic()
|
||||
self._start_utcnow = datetime.datetime.utcnow()
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> str:
|
||||
if self._done:
|
||||
if not self._chunks:
|
||||
raise StopAsyncIteration
|
||||
chunk = self._chunks.pop(0)
|
||||
if not self._chunks:
|
||||
raise StopAsyncIteration
|
||||
return chunk
|
||||
|
||||
if not hasattr(self, "_generator"):
|
||||
self._generator = self.model.execute(
|
||||
self.prompt,
|
||||
stream=self.stream,
|
||||
response=self,
|
||||
conversation=self.conversation,
|
||||
)
|
||||
|
||||
try:
|
||||
chunk = await self._generator.__anext__()
|
||||
self._chunks.append(chunk)
|
||||
return chunk
|
||||
except StopAsyncIteration:
|
||||
if self.conversation:
|
||||
self.conversation.responses.append(self)
|
||||
self._end = time.monotonic()
|
||||
self._done = True
|
||||
raise
|
||||
|
||||
async def _force(self):
|
||||
if not self._done:
|
||||
async for _ in self:
|
||||
pass
|
||||
return self
|
||||
|
||||
async def text(self) -> str:
    """Complete the response if needed and return its chunks as one string."""
    await self._force()
    pieces = self._chunks
    return "".join(pieces)
|
||||
|
||||
async def json(self) -> Optional[Dict[str, Any]]:
    """Complete the response if needed and return ``response_json``."""
    await self._force()
    payload = self.response_json
    return payload
|
||||
|
||||
async def duration_ms(self) -> int:
    """Complete the response if needed and return its duration in ms.

    Missing start/end stamps are treated as 0.
    """
    await self._force()
    finished = self._end or 0
    started = self._start or 0
    return int((finished - started) * 1000)
|
||||
|
||||
async def datetime_utc(self) -> str:
    """Complete the response if needed and return its ISO-format start time.

    Returns an empty string when no start time was recorded.
    """
    await self._force()
    if self._start_utcnow:
        return self._start_utcnow.isoformat()
    return ""
|
||||
|
||||
def __await__(self):
|
||||
return self._force().__await__()
|
||||
|
||||
@classmethod
|
||||
def fake(
|
||||
cls,
|
||||
model: "Model",
|
||||
model: "AsyncModel",
|
||||
prompt: str,
|
||||
*attachments: List[Attachment],
|
||||
system: str,
|
||||
response: str
|
||||
response: str,
|
||||
):
|
||||
"Utility method to help with writing tests"
|
||||
response_obj = cls(
|
||||
|
|
@ -281,47 +436,11 @@ class Response(ABC):
|
|||
response_obj._chunks = [response]
|
||||
return response_obj
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, db, row):
|
||||
from llm import get_model
|
||||
|
||||
model = get_model(row["model"])
|
||||
|
||||
response = cls(
|
||||
model=model,
|
||||
prompt=Prompt(
|
||||
prompt=row["prompt"],
|
||||
model=model,
|
||||
attachments=[],
|
||||
system=row["system"],
|
||||
options=model.Options(**json.loads(row["options_json"])),
|
||||
),
|
||||
stream=False,
|
||||
)
|
||||
response.id = row["id"]
|
||||
response._prompt_json = json.loads(row["prompt_json"] or "null")
|
||||
response.response_json = json.loads(row["response_json"] or "null")
|
||||
response._done = True
|
||||
response._chunks = [row["response"]]
|
||||
# Attachments
|
||||
response.attachments = [
|
||||
Attachment.from_row(arow)
|
||||
for arow in db.query(
|
||||
"""
|
||||
select attachments.* from attachments
|
||||
join prompt_attachments on attachments.id = prompt_attachments.attachment_id
|
||||
where prompt_attachments.response_id = ?
|
||||
order by prompt_attachments."order"
|
||||
""",
|
||||
[row["id"]],
|
||||
)
|
||||
]
|
||||
return response
|
||||
|
||||
def __repr__(self):
|
||||
return "<Response prompt='{}' text='{}'>".format(
|
||||
self.prompt.prompt, self.text()
|
||||
)
|
||||
text = "... not yet awaited ..."
|
||||
if self._done:
|
||||
text = "".join(self._chunks)
|
||||
return "<Response prompt='{}' text='{}'>".format(self.prompt.prompt, text)
|
||||
|
||||
|
||||
class Options(BaseModel):
|
||||
|
|
@ -362,22 +481,39 @@ class _get_key_mixin:
|
|||
raise NeedsKeyException(message)
|
||||
|
||||
|
||||
class Model(ABC, _get_key_mixin):
|
||||
class _BaseModel(ABC, _get_key_mixin):
|
||||
model_id: str
|
||||
|
||||
# API key handling
|
||||
key: Optional[str] = None
|
||||
needs_key: Optional[str] = None
|
||||
key_env_var: Optional[str] = None
|
||||
|
||||
# Model characteristics
|
||||
can_stream: bool = False
|
||||
attachment_types: Set = set()
|
||||
|
||||
class Options(_Options):
|
||||
pass
|
||||
|
||||
def conversation(self):
|
||||
def _validate_attachments(
|
||||
self, attachments: Optional[List[Attachment]] = None
|
||||
) -> None:
|
||||
if attachments and not self.attachment_types:
|
||||
raise ValueError("This model does not support attachments")
|
||||
for attachment in attachments or []:
|
||||
attachment_type = attachment.resolve_type()
|
||||
if attachment_type not in self.attachment_types:
|
||||
raise ValueError(
|
||||
f"This model does not support attachments of type '{attachment_type}', "
|
||||
f"only {', '.join(self.attachment_types)}"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}: {}".format(self.__class__.__name__, self.model_id)
|
||||
|
||||
def __repr__(self):
|
||||
return "<{} '{}'>".format(self.__class__.__name__, self.model_id)
|
||||
|
||||
|
||||
class Model(_BaseModel):
|
||||
def conversation(self) -> Conversation:
|
||||
return Conversation(model=self)
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -388,10 +524,6 @@ class Model(ABC, _get_key_mixin):
|
|||
response: Response,
|
||||
conversation: Optional[Conversation],
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Execute a prompt and yield chunks of text, or yield a single big chunk.
|
||||
Any additional useful information about the execution should be assigned to the response.
|
||||
"""
|
||||
pass
|
||||
|
||||
def prompt(
|
||||
|
|
@ -401,22 +533,10 @@ class Model(ABC, _get_key_mixin):
|
|||
attachments: Optional[List[Attachment]] = None,
|
||||
system: Optional[str] = None,
|
||||
stream: bool = True,
|
||||
**options
|
||||
):
|
||||
# Validate attachments
|
||||
if attachments and not self.attachment_types:
|
||||
raise ValueError(
|
||||
"This model does not support attachments, but some were provided"
|
||||
)
|
||||
for attachment in attachments or []:
|
||||
attachment_type = attachment.resolve_type()
|
||||
if attachment_type not in self.attachment_types:
|
||||
raise ValueError(
|
||||
"This model does not support attachments of type '{}', only {}".format(
|
||||
attachment_type, ", ".join(self.attachment_types)
|
||||
)
|
||||
)
|
||||
return self.response(
|
||||
**options,
|
||||
) -> Response:
|
||||
self._validate_attachments(attachments)
|
||||
return Response(
|
||||
Prompt(
|
||||
prompt,
|
||||
attachments=attachments,
|
||||
|
|
@ -424,17 +544,46 @@ class Model(ABC, _get_key_mixin):
|
|||
model=self,
|
||||
options=self.Options(**options),
|
||||
),
|
||||
stream=stream,
|
||||
self,
|
||||
stream,
|
||||
)
|
||||
|
||||
def response(self, prompt: Prompt, stream: bool = True) -> Response:
|
||||
return Response(prompt, self, stream)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}: {}".format(self.__class__.__name__, self.model_id)
|
||||
class AsyncModel(_BaseModel):
|
||||
def conversation(self) -> AsyncConversation:
    """Create a new ``AsyncConversation`` bound to this model."""
    new_conversation = AsyncConversation(model=self)
    return new_conversation
|
||||
|
||||
def __repr__(self):
|
||||
return "<Model '{}'>".format(self.model_id)
|
||||
@abstractmethod
async def execute(
    self,
    prompt: Prompt,
    stream: bool,
    response: AsyncResponse,
    conversation: Optional[AsyncConversation],
) -> AsyncGenerator[str, None]:
    """Abstract hook: asynchronously yield the response text in chunks.

    Subclasses must override this with an async generator that yields
    ``str`` chunks for the given prompt.

    Args:
        prompt: the prompt to execute.
        stream: streaming flag passed through from the response.
        response: the response object being populated.
        conversation: the enclosing conversation, or ``None``.
    """
    # The placeholder yield makes this definition an async generator
    # function, matching the declared AsyncGenerator return type.
    yield ""
|
||||
|
||||
def prompt(
    self,
    prompt: str,
    *,
    attachments: Optional[List[Attachment]] = None,
    system: Optional[str] = None,
    stream: bool = True,
    **options,
) -> AsyncResponse:
    """Build an ``AsyncResponse`` for the given prompt text.

    Attachments are validated first; the remaining keyword arguments are
    passed to this model's ``Options`` class. The response is returned
    without being executed here.

    Raises:
        ValueError: propagated from ``_validate_attachments`` when an
            attachment is not supported by this model.
    """
    self._validate_attachments(attachments)
    prompt_obj = Prompt(
        prompt,
        attachments=attachments,
        system=system,
        model=self,
        options=self.Options(**options),
    )
    return AsyncResponse(prompt_obj, self, stream)
|
||||
|
||||
|
||||
class EmbeddingModel(ABC, _get_key_mixin):
|
||||
|
|
@ -495,6 +644,7 @@ class EmbeddingModel(ABC, _get_key_mixin):
|
|||
@dataclass
|
||||
class ModelWithAliases:
|
||||
model: Model
|
||||
async_model: AsyncModel
|
||||
aliases: Set[str]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
[pytest]
|
||||
filterwarnings =
|
||||
ignore:The `schema` method is deprecated.*:DeprecationWarning
|
||||
ignore:Support for class-based `config` is deprecated*:DeprecationWarning
|
||||
ignore:Support for class-based `config` is deprecated*:DeprecationWarning
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
1
setup.py
1
setup.py
|
|
@ -55,6 +55,7 @@ setup(
|
|||
"pytest",
|
||||
"numpy",
|
||||
"pytest-httpx>=0.33.0",
|
||||
"pytest-asyncio",
|
||||
"cogapp",
|
||||
"mypy>=1.10.0",
|
||||
"black>=24.1.0",
|
||||
|
|
|
|||
|
|
@ -75,6 +75,29 @@ class MockModel(llm.Model):
|
|||
break
|
||||
|
||||
|
||||
class AsyncMockModel(llm.AsyncModel):
    """Async test double that replays pre-enqueued message lists.

    Each ``execute()`` call records its arguments in ``history`` and yields
    the messages of the oldest enqueued list, or nothing if the queue is
    empty.
    """

    model_id = "mock"

    def __init__(self):
        # (prompt, stream, response, conversation) tuples, one per execute().
        self.history = []
        # Pending message lists, consumed first-in first-out.
        self._queue = []

    def enqueue(self, messages):
        """Queue a list of messages to be yielded by the next ``execute()``."""
        assert isinstance(messages, list)
        self._queue.append(messages)

    async def execute(self, prompt, stream, response, conversation):
        """Record the call, then yield the next enqueued batch of messages."""
        self.history.append((prompt, stream, response, conversation))
        # Check for an empty queue explicitly instead of wrapping the yield
        # loop in try/except IndexError, which would also swallow an
        # IndexError thrown into the generator by its consumer.
        if not self._queue:
            return
        for message in self._queue.pop(0):
            yield message
|
||||
|
||||
|
||||
class EmbedDemo(llm.EmbeddingModel):
|
||||
model_id = "embed-demo"
|
||||
batch_size = 10
|
||||
|
|
@ -118,8 +141,13 @@ def mock_model():
|
|||
return MockModel()
|
||||
|
||||
|
||||
@pytest.fixture
def async_mock_model():
    """Provide a fresh ``AsyncMockModel`` instance for a test."""
    model = AsyncMockModel()
    return model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def register_embed_demo_model(embed_demo, mock_model):
|
||||
def register_embed_demo_model(embed_demo, mock_model, async_mock_model):
|
||||
class MockModelsPlugin:
|
||||
__name__ = "MockModelsPlugin"
|
||||
|
||||
|
|
@ -131,7 +159,7 @@ def register_embed_demo_model(embed_demo, mock_model):
|
|||
|
||||
@llm.hookimpl
|
||||
def register_models(self, register):
|
||||
register(mock_model)
|
||||
register(mock_model, async_model=async_mock_model)
|
||||
|
||||
pm.register(MockModelsPlugin(), name="undo-mock-models-plugin")
|
||||
try:
|
||||
|
|
|
|||
17
tests/test_async.py
Normal file
17
tests/test_async.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import llm
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_async_model(async_mock_model):
    """The async mock model works both streamed and awaited."""
    # Consumed as an async iterator
    async_mock_model.enqueue(["hello world"])
    streamed = [chunk async for chunk in async_mock_model.prompt("hello")]
    assert streamed == ["hello world"]

    # Awaited directly rather than iterated
    async_mock_model.enqueue(["hello world"])
    response = await async_mock_model.prompt("hello")
    body = await response.text()
    assert body == "hello world"
    assert isinstance(response, llm.AsyncResponse)
|
||||
|
|
@ -80,7 +80,10 @@ def test_chat_basic(mock_model, logs_db):
|
|||
# Now continue that conversation
|
||||
mock_model.enqueue(["continued"])
|
||||
result2 = runner.invoke(
|
||||
llm.cli.cli, ["chat", "-m", "mock", "-c"], input="Continue\nquit\n"
|
||||
llm.cli.cli,
|
||||
["chat", "-m", "mock", "-c"],
|
||||
input="Continue\nquit\n",
|
||||
catch_exceptions=False,
|
||||
)
|
||||
assert result2.exit_code == 0
|
||||
assert result2.output == (
|
||||
|
|
@ -176,7 +179,7 @@ def test_chat_options(mock_model, logs_db):
|
|||
"response": "Some text",
|
||||
"response_json": None,
|
||||
"conversation_id": ANY,
|
||||
"duration_ms": 0,
|
||||
"duration_ms": ANY,
|
||||
"datetime_utc": ANY,
|
||||
}
|
||||
]
|
||||
|
|
|
|||
|
|
@ -555,6 +555,14 @@ def test_llm_models_options(user_path):
|
|||
result = runner.invoke(cli, ["models", "--options"], catch_exceptions=False)
|
||||
assert result.exit_code == 0
|
||||
assert EXPECTED_OPTIONS.strip() in result.output
|
||||
assert "AsyncMockModel: mock" not in result.output
|
||||
|
||||
|
||||
def test_llm_models_async(user_path):
    """``llm models --async`` lists the registered async mock model."""
    outcome = CliRunner().invoke(cli, ["models", "--async"], catch_exceptions=False)
    assert outcome.exit_code == 0
    assert "AsyncMockModel: mock" in outcome.output
|
||||
|
||||
|
||||
def test_llm_user_dir(tmpdir, monkeypatch):
|
||||
|
|
|
|||
Loading…
Reference in a new issue