Fixes --continue bug and adds --database argument to llm chat

* Fix database bug in continue conversation and adds --database to llm chat
* Move the --database to proper place and update help.

Closes #933
This commit is contained in:
Sukhbinder Singh 2025-05-05 04:27:13 +05:30 committed by GitHub
parent 00e5ee6b5a
commit 8c7d33ee52
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 101 additions and 5 deletions

View file

@ -162,6 +162,7 @@ Options:
-t, --template TEXT Template to use
-p, --param <TEXT TEXT>... Parameters for template
-o, --option <TEXT TEXT>... key/value options for the model
-d, --database FILE Path to log database
--no-stream Do not stream output
--key TEXT API key to use
--help Show this message and exit.

View file

@ -637,7 +637,9 @@ def prompt(
if conversation_id or _continue:
# Load the conversation - loads most recent if no ID provided
try:
conversation = load_conversation(conversation_id, async_=async_)
conversation = load_conversation(
conversation_id, async_=async_, database=database
)
except UnknownModelError as ex:
raise click.ClickException(str(ex))
@ -834,6 +836,12 @@ def prompt(
multiple=True,
help="key/value options for the model",
)
@click.option(
"-d",
"--database",
type=click.Path(readable=True, dir_okay=False),
help="Path to log database",
)
@click.option("--no-stream", is_flag=True, help="Do not stream output")
@click.option("--key", help="API key to use")
def chat(
@ -846,6 +854,7 @@ def chat(
options,
no_stream,
key,
database,
):
"""
Hold an ongoing chat with a model.
@ -857,7 +866,7 @@ def chat(
else:
readline.parse_and_bind("bind -x '\\e[D: backward-char'")
readline.parse_and_bind("bind -x '\\e[C: forward-char'")
log_path = logs_db_path()
log_path = pathlib.Path(database) if database else logs_db_path()
(log_path.parent).mkdir(parents=True, exist_ok=True)
db = sqlite_utils.Database(log_path)
migrate(db)
@ -866,7 +875,7 @@ def chat(
if conversation_id or _continue:
# Load the conversation - loads most recent if no ID provided
try:
conversation = load_conversation(conversation_id)
conversation = load_conversation(conversation_id, database=database)
except UnknownModelError as ex:
raise click.ClickException(str(ex))
@ -979,9 +988,12 @@ def chat(
def load_conversation(
conversation_id: Optional[str], async_=False
conversation_id: Optional[str],
async_=False,
database=None,
) -> Optional[_BaseConversation]:
db = sqlite_utils.Database(logs_db_path())
log_path = pathlib.Path(database) if database else logs_db_path()
db = sqlite_utils.Database(log_path)
migrate(db)
if conversation_id is None:
# Return the most recent conversation, or None if there are none

View file

@ -3,6 +3,7 @@ import llm.cli
from unittest.mock import ANY
import pytest
import sys
import sqlite_utils
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
@ -238,3 +239,29 @@ def test_chat_multi(mock_model, logs_db, input, expected):
assert result.exit_code == 0
rows = list(logs_db["responses"].rows_where(select="prompt, response"))
assert rows == expected
@pytest.mark.parametrize("custom_database_path", (False, True))
def test_llm_chat_creates_log_database(tmpdir, monkeypatch, custom_database_path):
user_path = tmpdir / "user"
custom_db_path = tmpdir / "custom_log.db"
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
runner = CliRunner()
args = ["chat", "-m", "mock"]
if custom_database_path:
args.extend(["--database", str(custom_db_path)])
result = runner.invoke(
llm.cli.cli,
args,
catch_exceptions=False,
input="Hi\nHi two\nquit\n",
)
assert result.exit_code == 0
# Should have created user_path and put a logs.db in it
if custom_database_path:
assert custom_db_path.exists()
db_path = str(custom_db_path)
else:
assert (user_path / "logs.db").exists()
db_path = str(user_path / "logs.db")
assert sqlite_utils.Database(db_path)["responses"].count == 2

View file

@ -789,3 +789,59 @@ def test_schemas_dsl():
},
"required": ["items"],
}
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
@pytest.mark.parametrize("custom_database_path", (False, True))
def test_llm_prompt_continue_with_database(
tmpdir, monkeypatch, httpx_mock, user_path, custom_database_path
):
httpx_mock.add_response(
method="POST",
url="https://api.openai.com/v1/chat/completions",
json={
"model": "gpt-4o-mini",
"usage": {},
"choices": [{"message": {"content": "Bob, Alice, Eve"}}],
},
headers={"Content-Type": "application/json"},
)
httpx_mock.add_response(
method="POST",
url="https://api.openai.com/v1/chat/completions",
json={
"model": "gpt-4o-mini",
"usage": {},
"choices": [{"message": {"content": "Terry"}}],
},
headers={"Content-Type": "application/json"},
)
user_path = tmpdir / "user"
custom_db_path = tmpdir / "custom_log.db"
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
# First prompt
runner = CliRunner()
args = ["three names \nfor a pet pelican", "--no-stream"]
if custom_database_path:
args.extend(["--database", str(custom_db_path)])
result = runner.invoke(cli, args, catch_exceptions=False)
assert result.exit_code == 0, result.output
assert result.output == "Bob, Alice, Eve\n"
# Now ask a follow-up
args2 = ["one more", "-c", "--no-stream"]
if custom_database_path:
args2.extend(["--database", str(custom_db_path)])
result2 = runner.invoke(cli, args2, catch_exceptions=False)
assert result2.exit_code == 0, result2.output
assert result2.output == "Terry\n"
if custom_database_path:
assert custom_db_path.exists()
db_path = str(custom_db_path)
else:
assert (user_path / "logs.db").exists()
db_path = str(user_path / "logs.db")
assert sqlite_utils.Database(db_path)["responses"].count == 2