llm/tests/test_llm.py

91 lines
2.6 KiB
Python
Raw Normal View History

2023-04-01 21:28:24 +00:00
from click.testing import CliRunner
2023-04-01 22:00:16 +00:00
from llm.cli import cli
from llm.migrations import migrate
2023-04-02 01:52:46 +00:00
import json
import os
2023-04-02 01:52:46 +00:00
import pytest
import sqlite_utils
from unittest import mock
2023-04-01 21:28:24 +00:00
2023-04-01 22:00:16 +00:00
def test_version():
    # The top-level --version flag should print "cli, version X" and exit 0.
    cli_runner = CliRunner()
    with cli_runner.isolated_filesystem():
        outcome = cli_runner.invoke(cli, ["--version"])
        assert outcome.exit_code == 0
        assert outcome.output.startswith("cli, version ")
2023-04-02 01:52:46 +00:00
@pytest.mark.parametrize("n", (None, 0, 2))
def test_logs(n, log_path):
    """`llm logs` shows the last 3 entries by default, all of them with -n 0,
    or exactly N entries with -n N."""
    database = sqlite_utils.Database(str(log_path))
    migrate(database)
    # Seed 100 identical rows so the truncation behavior is observable
    row = {
        "system": "system",
        "prompt": "prompt",
        "response": "response",
        "model": "davinci",
    }
    database["log"].insert_all(dict(row) for _ in range(100))
    invoke_args = ["logs", "-p", str(log_path)]
    if n is not None:
        invoke_args += ["-n", str(n)]
    result = CliRunner().invoke(cli, invoke_args, catch_exceptions=False)
    assert result.exit_code == 0
    returned = json.loads(result.output)
    if n is None:
        expected_length = 3  # default truncation
    elif n == 0:
        expected_length = 100  # -n 0 means "no limit"
    else:
        expected_length = n
    assert len(returned) == expected_length
2023-06-17 08:29:36 +00:00
@pytest.mark.parametrize("env", ({}, {"LLM_LOG_PATH": "/tmp/logs.db"}))
def test_logs_path(monkeypatch, env, log_path):
    """`llm logs path` prints LLM_LOG_PATH when set, otherwise the default path."""
    for name, value in env.items():
        monkeypatch.setenv(name, value)
    result = CliRunner().invoke(cli, ["logs", "path"])
    assert result.exit_code == 0
    expected = env["LLM_LOG_PATH"] if env else str(log_path)
    assert result.output.strip() == expected
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
@pytest.mark.parametrize("use_stdin", (True, False))
def test_llm_default_prompt(mocked_openai, use_stdin, log_path):
    """The default prompt command sends the prompt to the (mocked) OpenAI API,
    prints the response, authenticates with the env var key, and logs one row.

    The prompt is delivered either on stdin or as a CLI argument, depending
    on the use_stdin parameter.
    """
    # Reset the log_path database so the single-row assertion below holds
    log_db = sqlite_utils.Database(str(log_path))
    log_db["log"].delete_where()
    runner = CliRunner()
    prompt = "three names for a pet pelican"
    # Renamed from `input` — the original shadowed the `input` builtin
    stdin_text = None
    args = ["--no-stream"]
    if use_stdin:
        stdin_text = prompt
    else:
        args.append(prompt)
    result = runner.invoke(cli, args, input=stdin_text, catch_exceptions=False)
    assert result.exit_code == 0
    assert result.output == "Bob, Alice, Eve\n"
    # The mocked API must have seen the key from the patched environment
    assert mocked_openai.last_request.headers["Authorization"] == "Bearer X"
    # Was it logged? Exactly one row, carrying the expected fields.
    rows = list(log_db["log"].rows)
    assert len(rows) == 1
    expected = {
        "model": "gpt-3.5-turbo",
        "prompt": "three names for a pet pelican",
        "system": None,
        "response": "Bob, Alice, Eve",
        "chat_id": None,
    }
    assert expected.items() <= rows[0].items()