From 8c7d33ee522adb094643ca06f15fd5172451b1a5 Mon Sep 17 00:00:00 2001 From: Sukhbinder Singh Date: Mon, 5 May 2025 04:27:13 +0530 Subject: [PATCH] Fixes `--continue` bug and adds `--database` argument to `llm chat` * Fix database bug in continue conversation and adds --database to llm chat * Move the --database to proper place and update help. Closes #933 --- docs/help.md | 1 + llm/cli.py | 22 +++++++++++++----- tests/test_chat.py | 27 ++++++++++++++++++++++ tests/test_llm.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 5 deletions(-) diff --git a/docs/help.md b/docs/help.md index b05a525..6840a47 100644 --- a/docs/help.md +++ b/docs/help.md @@ -162,6 +162,7 @@ Options: -t, --template TEXT Template to use -p, --param ... Parameters for template -o, --option ... key/value options for the model + -d, --database FILE Path to log database --no-stream Do not stream output --key TEXT API key to use --help Show this message and exit. diff --git a/llm/cli.py b/llm/cli.py index ca66838..09e4e2d 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -637,7 +637,9 @@ def prompt( if conversation_id or _continue: # Load the conversation - loads most recent if no ID provided try: - conversation = load_conversation(conversation_id, async_=async_) + conversation = load_conversation( + conversation_id, async_=async_, database=database + ) except UnknownModelError as ex: raise click.ClickException(str(ex)) @@ -834,6 +836,12 @@ def prompt( multiple=True, help="key/value options for the model", ) +@click.option( + "-d", + "--database", + type=click.Path(readable=True, dir_okay=False), + help="Path to log database", +) @click.option("--no-stream", is_flag=True, help="Do not stream output") @click.option("--key", help="API key to use") def chat( @@ -846,6 +854,7 @@ def chat( options, no_stream, key, + database, ): """ Hold an ongoing chat with a model. @@ -857,7 +866,7 @@ def chat( else: readline.parse_and_bind("bind -x '\\e[D: backward-char'") readline.parse_and_bind("bind -x '\\e[C: forward-char'") - log_path = logs_db_path() + log_path = pathlib.Path(database) if database else logs_db_path() (log_path.parent).mkdir(parents=True, exist_ok=True) db = sqlite_utils.Database(log_path) migrate(db) @@ -866,7 +875,7 @@ def chat( if conversation_id or _continue: # Load the conversation - loads most recent if no ID provided try: - conversation = load_conversation(conversation_id) + conversation = load_conversation(conversation_id, database=database) except UnknownModelError as ex: raise click.ClickException(str(ex)) @@ -979,9 +988,12 @@ def chat( def load_conversation( - conversation_id: Optional[str], async_=False + conversation_id: Optional[str], + async_=False, + database=None, ) -> Optional[_BaseConversation]: - db = sqlite_utils.Database(logs_db_path()) + log_path = pathlib.Path(database) if database else logs_db_path() + db = sqlite_utils.Database(log_path) migrate(db) if conversation_id is None: # Return the most recent conversation, or None if there are none diff --git a/tests/test_chat.py b/tests/test_chat.py index 89fc102..278a3f2 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -3,6 +3,7 @@ import llm.cli from unittest.mock import ANY import pytest import sys +import sqlite_utils @pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows") @@ -238,3 +239,29 @@ def test_chat_multi(mock_model, logs_db, input, expected): assert result.exit_code == 0 rows = list(logs_db["responses"].rows_where(select="prompt, response")) assert rows == expected + + +@pytest.mark.parametrize("custom_database_path", (False, True)) +def test_llm_chat_creates_log_database(tmpdir, monkeypatch, custom_database_path): + user_path = tmpdir / "user" + custom_db_path = tmpdir / "custom_log.db" + monkeypatch.setenv("LLM_USER_PATH", str(user_path)) + runner = CliRunner() + args = ["chat", "-m", "mock"] + if custom_database_path: + args.extend(["--database", str(custom_db_path)]) + result = runner.invoke( + llm.cli.cli, + args, + catch_exceptions=False, + input="Hi\nHi two\nquit\n", + ) + assert result.exit_code == 0 + # Should have created user_path and put a logs.db in it + if custom_database_path: + assert custom_db_path.exists() + db_path = str(custom_db_path) + else: + assert (user_path / "logs.db").exists() + db_path = str(user_path / "logs.db") + assert sqlite_utils.Database(db_path)["responses"].count == 2 diff --git a/tests/test_llm.py b/tests/test_llm.py index 4f273b3..661e116 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -789,3 +789,59 @@ def test_schemas_dsl(): }, "required": ["items"], } + + +@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"}) +@pytest.mark.parametrize("custom_database_path", (False, True)) +def test_llm_prompt_continue_with_database( + tmpdir, monkeypatch, httpx_mock, user_path, custom_database_path +): + httpx_mock.add_response( + method="POST", + url="https://api.openai.com/v1/chat/completions", + json={ + "model": "gpt-4o-mini", + "usage": {}, + "choices": [{"message": {"content": "Bob, Alice, Eve"}}], + }, + headers={"Content-Type": "application/json"}, + ) + httpx_mock.add_response( + method="POST", + url="https://api.openai.com/v1/chat/completions", + json={ + "model": "gpt-4o-mini", + "usage": {}, + "choices": [{"message": {"content": "Terry"}}], + }, + headers={"Content-Type": "application/json"}, + ) + + user_path = tmpdir / "user" + custom_db_path = tmpdir / "custom_log.db" + monkeypatch.setenv("LLM_USER_PATH", str(user_path)) + + # First prompt + runner = CliRunner() + args = ["three names \nfor a pet pelican", "--no-stream"] + if custom_database_path: + args.extend(["--database", str(custom_db_path)]) + result = runner.invoke(cli, args, catch_exceptions=False) + assert result.exit_code == 0, result.output + assert result.output == "Bob, Alice, Eve\n" + + # Now ask a follow-up + args2 = ["one more", "-c", "--no-stream"] + if custom_database_path: + args2.extend(["--database", str(custom_db_path)]) + result2 = runner.invoke(cli, args2, catch_exceptions=False) + assert result2.exit_code == 0, result2.output + assert result2.output == "Terry\n" + + if custom_database_path: + assert custom_db_path.exists() + db_path = str(custom_db_path) + else: + assert (user_path / "logs.db").exists() + db_path = str(user_path / "logs.db") + assert sqlite_utils.Database(db_path)["responses"].count == 2