diff --git a/llm/cli.py b/llm/cli.py index 73847e1..38e1e33 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -908,10 +908,12 @@ def prompt( err=True, ) - # Log to the database + # Log responses to the database if (logs_on() or log) and not no_log: + # Could be Response, AsyncResponse, ChainResponse, AsyncChainResponse if isinstance(response, AsyncResponse): response = asyncio.run(response.to_sync_response()) + # At this point ALL forms should have a log_to_db() method that works: response.log_to_db(db) diff --git a/llm/models.py b/llm/models.py index 9afbe97..bebe0a8 100644 --- a/llm/models.py +++ b/llm/models.py @@ -557,6 +557,16 @@ class AsyncConversation(_BaseConversation): key=key, ) + def to_sync_conversation(self): + return Conversation( + model=self.model, + id=self.id, + name=self.name, + responses=[], # Because we only use this in logging + tools=self.tools, + chain_limit=self.chain_limit, + ) + @classmethod def from_row(cls, row): from llm import get_async_model @@ -1403,10 +1413,7 @@ class AsyncResponse(_BaseResponse): self.stream, # conversation type needs to be compatible too. conversation=( - self.conversation.to_sync_conversation() - if self.conversation - and hasattr(self.conversation, "to_sync_conversation") - else None + self.conversation.to_sync_conversation() if self.conversation else None ), ) response.id = self.id diff --git a/tests/test_tools.py b/tests/test_tools.py index e803a49..6fd7070 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -393,7 +393,7 @@ def trigger_error(msg: str): """ -@pytest.mark.parametrize("async_", [False]) # Add True again +@pytest.mark.parametrize("async_", (False, True)) def test_tool_errors(async_): # https://github.com/simonw/llm/issues/1107 runner = CliRunner()