mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
Fixes --continue bug and adds --database argument to llm chat
* Fix database bug in continue conversation and adds --database to llm chat * Move the --database to proper place and update help. Closes #933
This commit is contained in:
parent
00e5ee6b5a
commit
8c7d33ee52
4 changed files with 101 additions and 5 deletions
|
|
@ -162,6 +162,7 @@ Options:
|
|||
-t, --template TEXT Template to use
|
||||
-p, --param <TEXT TEXT>... Parameters for template
|
||||
-o, --option <TEXT TEXT>... key/value options for the model
|
||||
-d, --database FILE Path to log database
|
||||
--no-stream Do not stream output
|
||||
--key TEXT API key to use
|
||||
--help Show this message and exit.
|
||||
|
|
|
|||
22
llm/cli.py
22
llm/cli.py
|
|
@ -637,7 +637,9 @@ def prompt(
|
|||
if conversation_id or _continue:
|
||||
# Load the conversation - loads most recent if no ID provided
|
||||
try:
|
||||
conversation = load_conversation(conversation_id, async_=async_)
|
||||
conversation = load_conversation(
|
||||
conversation_id, async_=async_, database=database
|
||||
)
|
||||
except UnknownModelError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
|
|
@ -834,6 +836,12 @@ def prompt(
|
|||
multiple=True,
|
||||
help="key/value options for the model",
|
||||
)
|
||||
@click.option(
|
||||
"-d",
|
||||
"--database",
|
||||
type=click.Path(readable=True, dir_okay=False),
|
||||
help="Path to log database",
|
||||
)
|
||||
@click.option("--no-stream", is_flag=True, help="Do not stream output")
|
||||
@click.option("--key", help="API key to use")
|
||||
def chat(
|
||||
|
|
@ -846,6 +854,7 @@ def chat(
|
|||
options,
|
||||
no_stream,
|
||||
key,
|
||||
database,
|
||||
):
|
||||
"""
|
||||
Hold an ongoing chat with a model.
|
||||
|
|
@ -857,7 +866,7 @@ def chat(
|
|||
else:
|
||||
readline.parse_and_bind("bind -x '\\e[D: backward-char'")
|
||||
readline.parse_and_bind("bind -x '\\e[C: forward-char'")
|
||||
log_path = logs_db_path()
|
||||
log_path = pathlib.Path(database) if database else logs_db_path()
|
||||
(log_path.parent).mkdir(parents=True, exist_ok=True)
|
||||
db = sqlite_utils.Database(log_path)
|
||||
migrate(db)
|
||||
|
|
@ -866,7 +875,7 @@ def chat(
|
|||
if conversation_id or _continue:
|
||||
# Load the conversation - loads most recent if no ID provided
|
||||
try:
|
||||
conversation = load_conversation(conversation_id)
|
||||
conversation = load_conversation(conversation_id, database=database)
|
||||
except UnknownModelError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
|
|
@ -979,9 +988,12 @@ def chat(
|
|||
|
||||
|
||||
def load_conversation(
|
||||
conversation_id: Optional[str], async_=False
|
||||
conversation_id: Optional[str],
|
||||
async_=False,
|
||||
database=None,
|
||||
) -> Optional[_BaseConversation]:
|
||||
db = sqlite_utils.Database(logs_db_path())
|
||||
log_path = pathlib.Path(database) if database else logs_db_path()
|
||||
db = sqlite_utils.Database(log_path)
|
||||
migrate(db)
|
||||
if conversation_id is None:
|
||||
# Return the most recent conversation, or None if there are none
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import llm.cli
|
|||
from unittest.mock import ANY
|
||||
import pytest
|
||||
import sys
|
||||
import sqlite_utils
|
||||
|
||||
|
||||
@pytest.mark.xfail(sys.platform == "win32", reason="Expected to fail on Windows")
|
||||
|
|
@ -238,3 +239,29 @@ def test_chat_multi(mock_model, logs_db, input, expected):
|
|||
assert result.exit_code == 0
|
||||
rows = list(logs_db["responses"].rows_where(select="prompt, response"))
|
||||
assert rows == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_database_path", (False, True))
|
||||
def test_llm_chat_creates_log_database(tmpdir, monkeypatch, custom_database_path):
|
||||
user_path = tmpdir / "user"
|
||||
custom_db_path = tmpdir / "custom_log.db"
|
||||
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
|
||||
runner = CliRunner()
|
||||
args = ["chat", "-m", "mock"]
|
||||
if custom_database_path:
|
||||
args.extend(["--database", str(custom_db_path)])
|
||||
result = runner.invoke(
|
||||
llm.cli.cli,
|
||||
args,
|
||||
catch_exceptions=False,
|
||||
input="Hi\nHi two\nquit\n",
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
# Should have created user_path and put a logs.db in it
|
||||
if custom_database_path:
|
||||
assert custom_db_path.exists()
|
||||
db_path = str(custom_db_path)
|
||||
else:
|
||||
assert (user_path / "logs.db").exists()
|
||||
db_path = str(user_path / "logs.db")
|
||||
assert sqlite_utils.Database(db_path)["responses"].count == 2
|
||||
|
|
|
|||
|
|
@ -789,3 +789,59 @@ def test_schemas_dsl():
|
|||
},
|
||||
"required": ["items"],
|
||||
}
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
|
||||
@pytest.mark.parametrize("custom_database_path", (False, True))
|
||||
def test_llm_prompt_continue_with_database(
|
||||
tmpdir, monkeypatch, httpx_mock, user_path, custom_database_path
|
||||
):
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url="https://api.openai.com/v1/chat/completions",
|
||||
json={
|
||||
"model": "gpt-4o-mini",
|
||||
"usage": {},
|
||||
"choices": [{"message": {"content": "Bob, Alice, Eve"}}],
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
httpx_mock.add_response(
|
||||
method="POST",
|
||||
url="https://api.openai.com/v1/chat/completions",
|
||||
json={
|
||||
"model": "gpt-4o-mini",
|
||||
"usage": {},
|
||||
"choices": [{"message": {"content": "Terry"}}],
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
user_path = tmpdir / "user"
|
||||
custom_db_path = tmpdir / "custom_log.db"
|
||||
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
|
||||
|
||||
# First prompt
|
||||
runner = CliRunner()
|
||||
args = ["three names \nfor a pet pelican", "--no-stream"]
|
||||
if custom_database_path:
|
||||
args.extend(["--database", str(custom_db_path)])
|
||||
result = runner.invoke(cli, args, catch_exceptions=False)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert result.output == "Bob, Alice, Eve\n"
|
||||
|
||||
# Now ask a follow-up
|
||||
args2 = ["one more", "-c", "--no-stream"]
|
||||
if custom_database_path:
|
||||
args2.extend(["--database", str(custom_db_path)])
|
||||
result2 = runner.invoke(cli, args2, catch_exceptions=False)
|
||||
assert result2.exit_code == 0, result2.output
|
||||
assert result2.output == "Terry\n"
|
||||
|
||||
if custom_database_path:
|
||||
assert custom_db_path.exists()
|
||||
db_path = str(custom_db_path)
|
||||
else:
|
||||
assert (user_path / "logs.db").exists()
|
||||
db_path = str(user_path / "logs.db")
|
||||
assert sqlite_utils.Database(db_path)["responses"].count == 2
|
||||
|
|
|
|||
Loading…
Reference in a new issue