Store debug info, closes #34

This commit is contained in:
Simon Willison 2023-06-16 08:47:15 +01:00
parent 43f2cbd2e3
commit 8308fe5cbf
6 changed files with 63 additions and 39 deletions

View file

@@ -8,6 +8,7 @@ import os
import pathlib
import sqlite_utils
import sys
import time
import warnings
warnings.simplefilter("ignore", ResourceWarning)
@@ -79,28 +80,34 @@ def prompt(prompt, system, model, no_stream, no_log, _continue, chat_id, key):
# Resolve model aliases
model = MODEL_ALIASES.get(model, model)
try:
debug = {}
if no_stream:
start = time.time()
response = openai.ChatCompletion.create(
model=model,
messages=messages,
)
debug["model"] = response.model
debug["usage"] = response.usage
content = response.choices[0].message.content
log(no_log, system, prompt, content, model, chat_id)
log(no_log, system, prompt, content, model, chat_id, debug, start)
print(content)
else:
start = time.time()
response = []
for chunk in openai.ChatCompletion.create(
model=model,
messages=messages,
stream=True,
):
debug["model"] = chunk.model
content = chunk["choices"][0].get("delta", {}).get("content")
if content is not None:
response.append(content)
print(content, end="")
sys.stdout.flush()
print("")
log(no_log, system, prompt, "".join(response), model, chat_id)
log(no_log, system, prompt, "".join(response), model, chat_id, debug, start)
except openai.error.AuthenticationError as ex:
raise click.ClickException("{}: {}".format(ex.error.type, ex.error.code))
except openai.error.OpenAIError as ex:
@@ -251,7 +258,11 @@ def log_db_path():
return user_dir() / "log.db"
def log(no_log, system, prompt, response, model, chat_id=None):
def log(no_log, system, prompt, response, model, chat_id=None, debug=None, start=None):
duration_ms = None
if start is not None:
end = time.time()
duration_ms = int((end - start) * 1000)
if no_log:
return
log_path = log_db_path()
@@ -267,6 +278,8 @@ def log(no_log, system, prompt, response, model, chat_id=None):
"response": response,
"model": model,
"timestamp": str(datetime.datetime.utcnow()),
"debug": debug,
"duration_ms": duration_ms,
},
)

View file

@@ -78,3 +78,9 @@ def m004_column_order(db):
@migration
def m004_drop_provider(db):
    # The provider column is no longer recorded; rebuild the log table without it.
    log_table = db["log"]
    log_table.transform(drop=("provider",))
@migration
def m005_debug(db):
    # New columns capturing the raw API debug payload and request timing.
    for column_name, column_type in (("debug", str), ("duration_ms", int)):
        db["log"].add_column(column_name, column_type)

View file

@@ -15,3 +15,16 @@ def keys_path(tmpdir):
def env_setup(monkeypatch, log_path, keys_path):
    # Redirect LLM's key and log storage into per-test temporary paths.
    for variable, location in (("LLM_KEYS_PATH", keys_path), ("LLM_LOG_PATH", log_path)):
        monkeypatch.setenv(variable, str(location))
@pytest.fixture
def mocked_openai(requests_mock):
    # Canned OpenAI chat completion response so tests never hit the real API.
    payload = {
        "model": "gpt-3.5-turbo",
        "usage": {},
        "choices": [{"message": {"content": "Bob, Alice, Eve"}}],
    }
    return requests_mock.post(
        "https://api.openai.com/v1/chat/completions",
        json=payload,
        headers={"Content-Type": "application/json"},
    )

View file

@@ -31,7 +31,7 @@ def test_keys_set(monkeypatch, tmpdir):
}
def test_uses_correct_key(requests_mock, monkeypatch, tmpdir):
def test_uses_correct_key(mocked_openai, monkeypatch, tmpdir):
keys_path = tmpdir / "keys.json"
keys_path.write_text(
json.dumps(
@@ -44,14 +44,11 @@ def test_uses_correct_key(requests_mock, monkeypatch, tmpdir):
)
monkeypatch.setenv("LLM_KEYS_PATH", str(keys_path))
monkeypatch.setenv("OPENAI_API_KEY", "from-env")
mocked = requests_mock.post(
"https://api.openai.com/v1/chat/completions",
json={"choices": [{"message": {"content": "Bob, Alice, Eve"}}]},
headers={"Content-Type": "application/json"},
)
def assert_key(key):
assert mocked.last_request.headers["Authorization"] == "Bearer {}".format(key)
assert mocked_openai.last_request.headers[
"Authorization"
] == "Bearer {}".format(key)
runner = CliRunner()
# Called without --key uses environment variable

View file

@@ -61,15 +61,10 @@ def test_logs_path(monkeypatch, env, log_path):
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
@pytest.mark.parametrize("use_stdin", (True, False))
def test_llm_default_prompt(requests_mock, use_stdin, log_path):
def test_llm_default_prompt(mocked_openai, use_stdin, log_path):
# Reset the log_path database
log_db = sqlite_utils.Database(str(log_path))
log_db["log"].delete_where()
mocked = requests_mock.post(
"https://api.openai.com/v1/chat/completions",
json={"choices": [{"message": {"content": "Bob, Alice, Eve"}}]},
headers={"Content-Type": "application/json"},
)
runner = CliRunner()
prompt = "three names for a pet pelican"
input = None
@@ -81,7 +76,7 @@ def test_llm_default_prompt(requests_mock, use_stdin, log_path):
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
assert result.exit_code == 0
assert result.output == "Bob, Alice, Eve\n"
assert mocked.last_request.headers["Authorization"] == "Bearer X"
assert mocked_openai.last_request.headers["Authorization"] == "Bearer X"
# Was it logged?
rows = list(log_db["log"].rows)
assert len(rows) == 1

View file

@@ -6,17 +6,17 @@ def test_migrate_blank():
db = sqlite_utils.Database(memory=True)
migrate(db)
assert set(db.table_names()) == {"_llm_migrations", "log"}
assert db["log"].schema == (
'CREATE TABLE "log" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [model] TEXT,\n"
" [timestamp] TEXT,\n"
" [prompt] TEXT,\n"
" [system] TEXT,\n"
" [response] TEXT,\n"
" [chat_id] INTEGER REFERENCES [log]([id])\n"
")"
)
assert db["log"].columns_dict == {
"id": int,
"model": str,
"timestamp": str,
"prompt": str,
"system": str,
"response": str,
"chat_id": int,
"debug": str,
"duration_ms": int,
}
def test_migrate_from_original_schema():
@@ -35,14 +35,14 @@ def test_migrate_from_original_schema():
migrate(db)
assert set(db.table_names()) == {"_llm_migrations", "log"}
schema = db["log"].schema
assert schema == (
'CREATE TABLE "log" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [model] TEXT,\n"
" [timestamp] TEXT,\n"
" [prompt] TEXT,\n"
" [system] TEXT,\n"
" [response] TEXT,\n"
" [chat_id] INTEGER REFERENCES [log]([id])\n"
")"
)
assert db["log"].columns_dict == {
"id": int,
"model": str,
"timestamp": str,
"prompt": str,
"system": str,
"response": str,
"chat_id": int,
"debug": str,
"duration_ms": int,
}