mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-28 02:10:27 +00:00
90 lines
2.6 KiB
Python
90 lines
2.6 KiB
Python
from click.testing import CliRunner
|
|
from llm.cli import cli
|
|
from llm.migrations import migrate
|
|
import json
|
|
import os
|
|
import pytest
|
|
import sqlite_utils
|
|
from unittest import mock
|
|
|
|
|
|
def test_version():
|
|
runner = CliRunner()
|
|
with runner.isolated_filesystem():
|
|
result = runner.invoke(cli, ["--version"])
|
|
assert result.exit_code == 0
|
|
assert result.output.startswith("cli, version ")
|
|
|
|
|
|
@pytest.mark.parametrize("n", (None, 0, 2))
|
|
def test_logs(n, log_path):
|
|
db = sqlite_utils.Database(str(log_path))
|
|
migrate(db)
|
|
db["log"].insert_all(
|
|
{
|
|
"system": "system",
|
|
"prompt": "prompt",
|
|
"response": "response",
|
|
"model": "davinci",
|
|
}
|
|
for i in range(100)
|
|
)
|
|
runner = CliRunner()
|
|
args = ["logs", "-p", str(log_path)]
|
|
if n is not None:
|
|
args.extend(["-n", str(n)])
|
|
result = runner.invoke(cli, args, catch_exceptions=False)
|
|
assert result.exit_code == 0
|
|
logs = json.loads(result.output)
|
|
expected_length = 3
|
|
if n is not None:
|
|
if n == 0:
|
|
expected_length = 100
|
|
else:
|
|
expected_length = n
|
|
assert len(logs) == expected_length
|
|
|
|
|
|
@pytest.mark.parametrize("env", ({}, {"LLM_LOG_PATH": "/tmp/log.db"}))
|
|
def test_logs_path(monkeypatch, env, log_path):
|
|
for key, value in env.items():
|
|
monkeypatch.setenv(key, value)
|
|
runner = CliRunner()
|
|
result = runner.invoke(cli, ["logs", "path"])
|
|
assert result.exit_code == 0
|
|
if env:
|
|
expected = env["LLM_LOG_PATH"]
|
|
else:
|
|
expected = str(log_path)
|
|
assert result.output.strip() == expected
|
|
|
|
|
|
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
|
|
@pytest.mark.parametrize("use_stdin", (True, False))
|
|
def test_llm_default_prompt(mocked_openai, use_stdin, log_path):
|
|
# Reset the log_path database
|
|
log_db = sqlite_utils.Database(str(log_path))
|
|
log_db["log"].delete_where()
|
|
runner = CliRunner()
|
|
prompt = "three names for a pet pelican"
|
|
input = None
|
|
args = ["--no-stream"]
|
|
if use_stdin:
|
|
input = prompt
|
|
else:
|
|
args.append(prompt)
|
|
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
|
|
assert result.exit_code == 0
|
|
assert result.output == "Bob, Alice, Eve\n"
|
|
assert mocked_openai.last_request.headers["Authorization"] == "Bearer X"
|
|
# Was it logged?
|
|
rows = list(log_db["log"].rows)
|
|
assert len(rows) == 1
|
|
expected = {
|
|
"model": "gpt-3.5-turbo",
|
|
"prompt": "three names for a pet pelican",
|
|
"system": None,
|
|
"response": "Bob, Alice, Eve",
|
|
"chat_id": None,
|
|
}
|
|
assert expected.items() <= rows[0].items()
|