Made load_conversation() async-aware, closes #742
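
load_conversation() now takes an async_ flag and, when it is set, inflates the
stored rows into an AsyncConversation of AsyncResponse objects, so -c/--cid can
continue a conversation that was started with --async. Intended usage, as
exercised by the new test (a sketch; assumes an OpenAI key is configured):

    llm 'three names for a pet pelican' --async
    llm 'one more' -c --async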

Simon Willison 2025-02-16 20:19:38 -08:00
parent 3445f9a112
commit 7bf1ea665e
4 changed files with 97 additions and 16 deletions

View file

@@ -4,6 +4,7 @@ from .errors import (
     NeedsKeyException,
 )
 from .models import (
+    AsyncConversation,
     AsyncKeyModel,
     AsyncModel,
     AsyncResponse,
@@ -29,6 +30,7 @@ import pathlib
 import struct
 
 __all__ = [
+    "AsyncConversation",
     "AsyncKeyModel",
     "AsyncResponse",
     "Attachment",

View file

@@ -7,6 +7,7 @@ import json
 import re
 from llm import (
     Attachment,
+    AsyncConversation,
     AsyncKeyModel,
     AsyncResponse,
     Collection,
@@ -32,6 +33,7 @@ from llm import (
     set_default_embedding_model,
     remove_alias,
 )
+from llm.models import _BaseConversation
 
 from .migrations import migrate
 from .plugins import pm, load_plugins
@@ -363,7 +365,7 @@ def prompt(
     if conversation_id or _continue:
         # Load the conversation - loads most recent if no ID provided
         try:
-            conversation = load_conversation(conversation_id)
+            conversation = load_conversation(conversation_id, async_=async_)
         except UnknownModelError as ex:
             raise click.ClickException(str(ex))
@@ -656,7 +658,9 @@ def chat(
         print("")
 
 
-def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]:
+def load_conversation(
+    conversation_id: Optional[str], async_=False
+) -> Optional[_BaseConversation]:
     db = sqlite_utils.Database(logs_db_path())
     migrate(db)
     if conversation_id is None:
@@ -673,11 +677,13 @@ def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]:
             "No conversation found with id={}".format(conversation_id)
         )
     # Inflate that conversation
-    conversation = Conversation.from_row(row)
+    conversation_class = AsyncConversation if async_ else Conversation
+    response_class = AsyncResponse if async_ else Response
+    conversation = conversation_class.from_row(row)
     for response in db["responses"].rows_where(
         "conversation_id = ?", [conversation_id]
     ):
-        conversation.responses.append(Response.from_row(db, response))
+        conversation.responses.append(response_class.from_row(db, response))
     return conversation
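
A rough sketch of calling the updated helper directly from Python (hypothetical
usage, not part of the diff; assumes at least one conversation has already been
logged to logs.db):

    from llm.cli import load_conversation

    # conversation_id=None loads the most recent logged conversation
    conversation = load_conversation(None, async_=True)
    print(type(conversation).__name__)  # AsyncConversation
    print(len(conversation.responses))  # prior responses, inflated as AsyncResponse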

View file

@@ -131,14 +131,9 @@ class _BaseConversation:
     responses: List["_BaseResponse"] = field(default_factory=list)
 
     @classmethod
-    def from_row(cls, row):
-        from llm import get_model
-
-        return cls(
-            model=get_model(row["model"]),
-            id=row["id"],
-            name=row["name"],
-        )
+    @abstractmethod
+    def from_row(cls, row: Any) -> "_BaseConversation":
+        raise NotImplementedError
 
 
 @dataclass
@@ -167,6 +162,16 @@ class Conversation(_BaseConversation):
             key=key,
         )
 
+    @classmethod
+    def from_row(cls, row):
+        from llm import get_model
+
+        return cls(
+            model=get_model(row["model"]),
+            id=row["id"],
+            name=row["name"],
+        )
+
     def __repr__(self):
         count = len(self.responses)
         s = "s" if count != 1 else ""
@@ -199,6 +204,16 @@ class AsyncConversation(_BaseConversation):
             key=key,
         )
 
+    @classmethod
+    def from_row(cls, row):
+        from llm import get_async_model
+
+        return cls(
+            model=get_async_model(row["model"]),
+            id=row["id"],
+            name=row["name"],
+        )
+
     def __repr__(self):
         count = len(self.responses)
         s = "s" if count != 1 else ""
@@ -251,10 +266,13 @@ class _BaseResponse:
         self.token_details = details
 
     @classmethod
-    def from_row(cls, db, row):
-        from llm import get_model
+    def from_row(cls, db, row, _async=False):
+        from llm import get_model, get_async_model
 
-        model = get_model(row["model"])
+        if _async:
+            model = get_async_model(row["model"])
+        else:
+            model = get_model(row["model"])
 
         response = cls(
             model=model,
@@ -427,7 +445,7 @@ class Response(_BaseResponse):
                 yield chunk
                 self._chunks.append(chunk)
         else:
-            raise ValueError("self.model must be a Model or KeyModel")
+            raise Exception("self.model must be a Model or KeyModel")
 
         if self.conversation:
             self.conversation.responses.append(self)
@@ -446,6 +464,10 @@ class AsyncResponse(_BaseResponse):
     model: "AsyncModel"
     conversation: Optional["AsyncConversation"] = None
 
+    @classmethod
+    def from_row(cls, db, row, _async=False):
+        return super().from_row(db, row, _async=True)
+
     async def on_done(self, callback):
         if not self._done:
             self.done_callbacks.append(callback)
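
The net effect is that the sync/async split is decided by which class you call
from_row() on; a small illustration (assumes an existing llm logs database and
registered models, not code from the commit):

    import sqlite_utils
    from llm.models import AsyncResponse, Response

    db = sqlite_utils.Database("logs.db")
    row = next(db["responses"].rows)

    sync_response = Response.from_row(db, row)        # model via get_model()
    async_response = AsyncResponse.from_row(db, row)  # model via get_async_model()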

View file

@@ -398,6 +398,57 @@ def test_llm_default_prompt(
     )
 
 
+@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
+@pytest.mark.parametrize("async_", (False, True))
+def test_llm_prompt_continue(httpx_mock, user_path, async_):
+    httpx_mock.add_response(
+        method="POST",
+        url="https://api.openai.com/v1/chat/completions",
+        json={
+            "model": "gpt-4o-mini",
+            "usage": {},
+            "choices": [{"message": {"content": "Bob, Alice, Eve"}}],
+        },
+        headers={"Content-Type": "application/json"},
+    )
+    httpx_mock.add_response(
+        method="POST",
+        url="https://api.openai.com/v1/chat/completions",
+        json={
+            "model": "gpt-4o-mini",
+            "usage": {},
+            "choices": [{"message": {"content": "Terry"}}],
+        },
+        headers={"Content-Type": "application/json"},
+    )
+
+    log_path = user_path / "logs.db"
+    log_db = sqlite_utils.Database(str(log_path))
+    log_db["responses"].delete_where()
+
+    # First prompt
+    runner = CliRunner()
+    args = ["three names \nfor a pet pelican", "--no-stream"] + (
+        ["--async"] if async_ else []
+    )
+    result = runner.invoke(cli, args, catch_exceptions=False)
+    assert result.exit_code == 0, result.output
+    assert result.output == "Bob, Alice, Eve\n"
+
+    # Should be logged
+    rows = list(log_db["responses"].rows)
+    assert len(rows) == 1
+
+    # Now ask a follow-up
+    args2 = ["one more", "-c", "--no-stream"] + (["--async"] if async_ else [])
+    result2 = runner.invoke(cli, args2, catch_exceptions=False)
+    assert result2.exit_code == 0, result2.output
+    assert result2.output == "Terry\n"
+
+    rows = list(log_db["responses"].rows)
+    assert len(rows) == 2
+
+
 @pytest.mark.parametrize(
     "args,expect_just_code",
     (