diff --git a/llm/models.py b/llm/models.py
index d31faca..a45cb2b 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -1325,6 +1325,7 @@ class AsyncResponse(_BaseResponse):
         response.response_json = self.response_json
         response._tool_calls = list(self._tool_calls)
         response.attachments = list(self.attachments)
+        response.resolved_model = self.resolved_model
         return response
 
     @classmethod
diff --git a/tests/conftest.py b/tests/conftest.py
index 64ed5b6..f64dce5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -109,6 +109,7 @@ class AsyncMockModel(llm.AsyncModel):
     def __init__(self):
         self.history = []
         self._queue = []
+        self.resolved_model_name = None
 
     def enqueue(self, messages):
         assert isinstance(messages, list)
@@ -129,6 +130,8 @@ class AsyncMockModel(llm.AsyncModel):
         response.set_usage(
             input=len((prompt.prompt or "").split()), output=len(gathered)
         )
+        if self.resolved_model_name is not None:
+            response.set_resolved_model(self.resolved_model_name)
 
 
 class EmbedDemo(llm.EmbeddingModel):
diff --git a/tests/test_llm_logs.py b/tests/test_llm_logs.py
index 5c4a387..2dac375 100644
--- a/tests/test_llm_logs.py
+++ b/tests/test_llm_logs.py
@@ -946,10 +946,14 @@ def test_logs_backup(logs_db):
     assert expected_path.exists()
 
 
-def test_logs_resolved_model(logs_db, mock_model):
+@pytest.mark.parametrize("async_", (False, True))
+def test_logs_resolved_model(logs_db, mock_model, async_mock_model, async_):
     mock_model.resolved_model_name = "resolved-mock"
+    async_mock_model.resolved_model_name = "resolved-mock"
     runner = CliRunner()
-    result = runner.invoke(cli, ["-m", "mock", "simple prompt"])
+    result = runner.invoke(
+        cli, ["-m", "mock", "simple prompt"] + (["--async"] if async_ else [])
+    )
     assert result.exit_code == 0
     # Should have logged the resolved model name
     assert logs_db["responses"].count