mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-24 08:20:25 +00:00
Store debug info, closes #34
This commit is contained in:
parent
43f2cbd2e3
commit
8308fe5cbf
6 changed files with 63 additions and 39 deletions
19
llm/cli.py
19
llm/cli.py
|
|
@ -8,6 +8,7 @@ import os
|
|||
import pathlib
|
||||
import sqlite_utils
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
|
|
@ -79,28 +80,34 @@ def prompt(prompt, system, model, no_stream, no_log, _continue, chat_id, key):
|
|||
# Resolve model aliases
|
||||
model = MODEL_ALIASES.get(model, model)
|
||||
try:
|
||||
debug = {}
|
||||
if no_stream:
|
||||
start = time.time()
|
||||
response = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
debug["model"] = response.model
|
||||
debug["usage"] = response.usage
|
||||
content = response.choices[0].message.content
|
||||
log(no_log, system, prompt, content, model, chat_id)
|
||||
log(no_log, system, prompt, content, model, chat_id, debug, start)
|
||||
print(content)
|
||||
else:
|
||||
start = time.time()
|
||||
response = []
|
||||
for chunk in openai.ChatCompletion.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
):
|
||||
debug["model"] = chunk.model
|
||||
content = chunk["choices"][0].get("delta", {}).get("content")
|
||||
if content is not None:
|
||||
response.append(content)
|
||||
print(content, end="")
|
||||
sys.stdout.flush()
|
||||
print("")
|
||||
log(no_log, system, prompt, "".join(response), model, chat_id)
|
||||
log(no_log, system, prompt, "".join(response), model, chat_id, debug, start)
|
||||
except openai.error.AuthenticationError as ex:
|
||||
raise click.ClickException("{}: {}".format(ex.error.type, ex.error.code))
|
||||
except openai.error.OpenAIError as ex:
|
||||
|
|
@ -251,7 +258,11 @@ def log_db_path():
|
|||
return user_dir() / "log.db"
|
||||
|
||||
|
||||
def log(no_log, system, prompt, response, model, chat_id=None):
|
||||
def log(no_log, system, prompt, response, model, chat_id=None, debug=None, start=None):
|
||||
duration_ms = None
|
||||
if start is not None:
|
||||
end = time.time()
|
||||
duration_ms = int((end - start) * 1000)
|
||||
if no_log:
|
||||
return
|
||||
log_path = log_db_path()
|
||||
|
|
@ -267,6 +278,8 @@ def log(no_log, system, prompt, response, model, chat_id=None):
|
|||
"response": response,
|
||||
"model": model,
|
||||
"timestamp": str(datetime.datetime.utcnow()),
|
||||
"debug": debug,
|
||||
"duration_ms": duration_ms,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -78,3 +78,9 @@ def m004_column_order(db):
|
|||
@migration
|
||||
def m004_drop_provider(db):
|
||||
db["log"].transform(drop=("provider",))
|
||||
|
||||
|
||||
@migration
|
||||
def m005_debug(db):
|
||||
db["log"].add_column("debug", str)
|
||||
db["log"].add_column("duration_ms", int)
|
||||
|
|
|
|||
|
|
@ -15,3 +15,16 @@ def keys_path(tmpdir):
|
|||
def env_setup(monkeypatch, log_path, keys_path):
|
||||
monkeypatch.setenv("LLM_KEYS_PATH", str(keys_path))
|
||||
monkeypatch.setenv("LLM_LOG_PATH", str(log_path))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_openai(requests_mock):
|
||||
return requests_mock.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
json={
|
||||
"model": "gpt-3.5-turbo",
|
||||
"usage": {},
|
||||
"choices": [{"message": {"content": "Bob, Alice, Eve"}}],
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def test_keys_set(monkeypatch, tmpdir):
|
|||
}
|
||||
|
||||
|
||||
def test_uses_correct_key(requests_mock, monkeypatch, tmpdir):
|
||||
def test_uses_correct_key(mocked_openai, monkeypatch, tmpdir):
|
||||
keys_path = tmpdir / "keys.json"
|
||||
keys_path.write_text(
|
||||
json.dumps(
|
||||
|
|
@ -44,14 +44,11 @@ def test_uses_correct_key(requests_mock, monkeypatch, tmpdir):
|
|||
)
|
||||
monkeypatch.setenv("LLM_KEYS_PATH", str(keys_path))
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "from-env")
|
||||
mocked = requests_mock.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
json={"choices": [{"message": {"content": "Bob, Alice, Eve"}}]},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def assert_key(key):
|
||||
assert mocked.last_request.headers["Authorization"] == "Bearer {}".format(key)
|
||||
assert mocked_openai.last_request.headers[
|
||||
"Authorization"
|
||||
] == "Bearer {}".format(key)
|
||||
|
||||
runner = CliRunner()
|
||||
# Called without --key uses environment variable
|
||||
|
|
|
|||
|
|
@ -61,15 +61,10 @@ def test_logs_path(monkeypatch, env, log_path):
|
|||
|
||||
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
|
||||
@pytest.mark.parametrize("use_stdin", (True, False))
|
||||
def test_llm_default_prompt(requests_mock, use_stdin, log_path):
|
||||
def test_llm_default_prompt(mocked_openai, use_stdin, log_path):
|
||||
# Reset the log_path database
|
||||
log_db = sqlite_utils.Database(str(log_path))
|
||||
log_db["log"].delete_where()
|
||||
mocked = requests_mock.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
json={"choices": [{"message": {"content": "Bob, Alice, Eve"}}]},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
runner = CliRunner()
|
||||
prompt = "three names for a pet pelican"
|
||||
input = None
|
||||
|
|
@ -81,7 +76,7 @@ def test_llm_default_prompt(requests_mock, use_stdin, log_path):
|
|||
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
|
||||
assert result.exit_code == 0
|
||||
assert result.output == "Bob, Alice, Eve\n"
|
||||
assert mocked.last_request.headers["Authorization"] == "Bearer X"
|
||||
assert mocked_openai.last_request.headers["Authorization"] == "Bearer X"
|
||||
# Was it logged?
|
||||
rows = list(log_db["log"].rows)
|
||||
assert len(rows) == 1
|
||||
|
|
|
|||
|
|
@ -6,17 +6,17 @@ def test_migrate_blank():
|
|||
db = sqlite_utils.Database(memory=True)
|
||||
migrate(db)
|
||||
assert set(db.table_names()) == {"_llm_migrations", "log"}
|
||||
assert db["log"].schema == (
|
||||
'CREATE TABLE "log" (\n'
|
||||
" [id] INTEGER PRIMARY KEY,\n"
|
||||
" [model] TEXT,\n"
|
||||
" [timestamp] TEXT,\n"
|
||||
" [prompt] TEXT,\n"
|
||||
" [system] TEXT,\n"
|
||||
" [response] TEXT,\n"
|
||||
" [chat_id] INTEGER REFERENCES [log]([id])\n"
|
||||
")"
|
||||
)
|
||||
assert db["log"].columns_dict == {
|
||||
"id": int,
|
||||
"model": str,
|
||||
"timestamp": str,
|
||||
"prompt": str,
|
||||
"system": str,
|
||||
"response": str,
|
||||
"chat_id": int,
|
||||
"debug": str,
|
||||
"duration_ms": int,
|
||||
}
|
||||
|
||||
|
||||
def test_migrate_from_original_schema():
|
||||
|
|
@ -35,14 +35,14 @@ def test_migrate_from_original_schema():
|
|||
migrate(db)
|
||||
assert set(db.table_names()) == {"_llm_migrations", "log"}
|
||||
schema = db["log"].schema
|
||||
assert schema == (
|
||||
'CREATE TABLE "log" (\n'
|
||||
" [id] INTEGER PRIMARY KEY,\n"
|
||||
" [model] TEXT,\n"
|
||||
" [timestamp] TEXT,\n"
|
||||
" [prompt] TEXT,\n"
|
||||
" [system] TEXT,\n"
|
||||
" [response] TEXT,\n"
|
||||
" [chat_id] INTEGER REFERENCES [log]([id])\n"
|
||||
")"
|
||||
)
|
||||
assert db["log"].columns_dict == {
|
||||
"id": int,
|
||||
"model": str,
|
||||
"timestamp": str,
|
||||
"prompt": str,
|
||||
"system": str,
|
||||
"response": str,
|
||||
"chat_id": int,
|
||||
"debug": str,
|
||||
"duration_ms": int,
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue