From 7bf1ea665eeaf39f2db4d961273f927faf82f36d Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 16 Feb 2025 20:19:38 -0800 Subject: [PATCH] Made load_conversation() async_ aware, closes #742 --- llm/__init__.py | 2 ++ llm/cli.py | 14 +++++++++---- llm/models.py | 46 +++++++++++++++++++++++++++++++----------- tests/test_llm.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 16 deletions(-) diff --git a/llm/__init__.py b/llm/__init__.py index dc38382..5b18c10 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -4,6 +4,7 @@ from .errors import ( NeedsKeyException, ) from .models import ( + AsyncConversation, AsyncKeyModel, AsyncModel, AsyncResponse, @@ -29,6 +30,7 @@ import pathlib import struct __all__ = [ + "AsyncConversation", "AsyncKeyModel", "AsyncResponse", "Attachment", diff --git a/llm/cli.py b/llm/cli.py index 9d01b11..e821c28 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -7,6 +7,7 @@ import json import re from llm import ( Attachment, + AsyncConversation, AsyncKeyModel, AsyncResponse, Collection, @@ -32,6 +33,7 @@ from llm import ( set_default_embedding_model, remove_alias, ) +from llm.models import _BaseConversation from .migrations import migrate from .plugins import pm, load_plugins @@ -363,7 +365,7 @@ def prompt( if conversation_id or _continue: # Load the conversation - loads most recent if no ID provided try: - conversation = load_conversation(conversation_id) + conversation = load_conversation(conversation_id, async_=async_) except UnknownModelError as ex: raise click.ClickException(str(ex)) @@ -656,7 +658,9 @@ def chat( print("") -def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]: +def load_conversation( + conversation_id: Optional[str], async_=False +) -> Optional[_BaseConversation]: db = sqlite_utils.Database(logs_db_path()) migrate(db) if conversation_id is None: @@ -673,11 +677,13 @@ def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]: "No 
conversation found with id={}".format(conversation_id) ) # Inflate that conversation - conversation = Conversation.from_row(row) + conversation_class = AsyncConversation if async_ else Conversation + response_class = AsyncResponse if async_ else Response + conversation = conversation_class.from_row(row) for response in db["responses"].rows_where( "conversation_id = ?", [conversation_id] ): - conversation.responses.append(Response.from_row(db, response)) + conversation.responses.append(response_class.from_row(db, response)) return conversation diff --git a/llm/models.py b/llm/models.py index de6857b..c7e3c72 100644 --- a/llm/models.py +++ b/llm/models.py @@ -131,14 +131,9 @@ class _BaseConversation: responses: List["_BaseResponse"] = field(default_factory=list) @classmethod - def from_row(cls, row): - from llm import get_model - - return cls( - model=get_model(row["model"]), - id=row["id"], - name=row["name"], - ) + @abstractmethod + def from_row(cls, row: Any) -> "_BaseConversation": + raise NotImplementedError @dataclass @@ -167,6 +162,16 @@ class Conversation(_BaseConversation): key=key, ) + @classmethod + def from_row(cls, row): + from llm import get_model + + return cls( + model=get_model(row["model"]), + id=row["id"], + name=row["name"], + ) + def __repr__(self): count = len(self.responses) s = "s" if count == 1 else "" @@ -199,6 +204,16 @@ class AsyncConversation(_BaseConversation): key=key, ) + @classmethod + def from_row(cls, row): + from llm import get_async_model + + return cls( + model=get_async_model(row["model"]), + id=row["id"], + name=row["name"], + ) + def __repr__(self): count = len(self.responses) s = "s" if count == 1 else "" @@ -251,10 +266,13 @@ class _BaseResponse: self.token_details = details @classmethod - def from_row(cls, db, row): - from llm import get_model + def from_row(cls, db, row, _async=False): + from llm import get_model, get_async_model - model = get_model(row["model"]) + if _async: + model = get_async_model(row["model"]) + else: 
+            model = get_model(row["model"])
 
         response = cls(
             model=model,
@@ -427,7 +445,7 @@ class Response(_BaseResponse):
                     yield chunk
                     self._chunks.append(chunk)
             else:
-                raise Exception("self.model must be a Model or KeyModel")
+                raise ValueError("self.model must be a Model or KeyModel")
 
         if self.conversation:
             self.conversation.responses.append(self)
@@ -446,6 +464,10 @@ class AsyncResponse(_BaseResponse):
     model: "AsyncModel"
     conversation: Optional["AsyncConversation"] = None
 
+    @classmethod
+    def from_row(cls, db, row, _async=False):
+        return super().from_row(db, row, _async=True)
+
     async def on_done(self, callback):
         if not self._done:
             self.done_callbacks.append(callback)
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 7ff2524..4b7e7ee 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -398,6 +398,57 @@ def test_llm_default_prompt(
     )
 
 
+@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
+@pytest.mark.parametrize("async_", (False, True))
+def test_llm_prompt_continue(httpx_mock, user_path, async_):
+    httpx_mock.add_response(
+        method="POST",
+        url="https://api.openai.com/v1/chat/completions",
+        json={
+            "model": "gpt-4o-mini",
+            "usage": {},
+            "choices": [{"message": {"content": "Bob, Alice, Eve"}}],
+        },
+        headers={"Content-Type": "application/json"},
+    )
+    httpx_mock.add_response(
+        method="POST",
+        url="https://api.openai.com/v1/chat/completions",
+        json={
+            "model": "gpt-4o-mini",
+            "usage": {},
+            "choices": [{"message": {"content": "Terry"}}],
+        },
+        headers={"Content-Type": "application/json"},
+    )
+
+    log_path = user_path / "logs.db"
+    log_db = sqlite_utils.Database(str(log_path))
+    log_db["responses"].delete_where()
+
+    # First prompt
+    runner = CliRunner()
+    args = ["three names \nfor a pet pelican", "--no-stream"] + (
+        ["--async"] if async_ else []
+    )
+    result = runner.invoke(cli, args, catch_exceptions=False)
+    assert result.exit_code == 0, result.output
+    assert result.output == "Bob, Alice, Eve\n"
+
+    # Should be logged
+    rows = 
list(log_db["responses"].rows) + assert len(rows) == 1 + + # Now ask a follow-up + args2 = ["one more", "-c", "--no-stream"] + (["--async"] if async_ else []) + result2 = runner.invoke(cli, args2, catch_exceptions=False) + assert result2.exit_code == 0, result2.output + assert result2.output == "Terry\n" + + rows = list(log_db["responses"].rows) + assert len(rows) == 2 + + @pytest.mark.parametrize( "args,expect_just_code", (