Implemented new logs database schema

This commit is contained in:
Simon Willison 2023-07-03 07:27:47 -07:00
parent b1c51df3f1
commit 345ad0d2dc
7 changed files with 112 additions and 75 deletions

View file

@ -62,7 +62,7 @@ import sqlite_utils
import re
db = sqlite_utils.Database(memory=True)
migrate(db)
schema = db["log"].schema
schema = db["logs"].schema
def cleanup_sql(sql):
first_line = sql.split('(')[0]
@ -75,16 +75,19 @@ cog.out(
)
]]] -->
```sql
CREATE TABLE "log" (
CREATE TABLE "logs" (
[id] INTEGER PRIMARY KEY,
[model] TEXT,
[timestamp] TEXT,
[prompt] TEXT,
[system] TEXT,
[prompt_json] TEXT,
[options_json] TEXT,
[response] TEXT,
[chat_id] INTEGER REFERENCES [log]([id]),
[debug] TEXT,
[duration_ms] INTEGER
[response_json] TEXT,
[reply_to_id] INTEGER,
[chat_id] INTEGER REFERENCES "logs"([id]),
[duration_ms] INTEGER,
[datetime_utc] TEXT
);
```
<!-- [[[end]]] -->

View file

@ -1,7 +1,5 @@
import click
from click_default_group import DefaultGroup
from dataclasses import asdict
import datetime
import json
from llm import Template
from .migrations import migrate
@ -19,7 +17,6 @@ from runpy import run_module
import shutil
import sqlite_utils
import sys
import time
import warnings
import yaml
@ -244,11 +241,7 @@ def prompt(
db = sqlite_utils.Database(log_path)
migrate(db)
log_message = response.to_log()
log_dict = asdict(log_message)
log_dict["duration_ms"] = response.duration_ms()
log_dict["timestamp_utc"] = response.timestamp_utc()
db["log2"].insert(log_dict, pk="id")
response.log_to_db(db)
# TODO: Figure out OpenAI exception handling
@ -382,7 +375,7 @@ def logs_list(count, path, truncate):
raise click.ClickException("No log database found at {}".format(path))
db = sqlite_utils.Database(path)
migrate(db)
rows = list(db["log"].rows_where(order_by="-id", limit=count or None))
rows = list(db["logs"].rows_where(order_by="-id", limit=count or None))
if truncate:
for row in rows:
row["prompt"] = _truncate_string(row["prompt"])
@ -585,31 +578,6 @@ def logs_db_path():
return user_dir() / "logs.db"
def log(no_log, system, prompt, response, model, chat_id=None, start=None):
duration_ms = None
if start is not None:
end = time.time()
duration_ms = int((end - start) * 1000)
if no_log:
return
log_path = logs_db_path()
if not log_path.exists():
return
db = sqlite_utils.Database(log_path)
migrate(db)
db["log"].insert(
{
"system": system,
"prompt": prompt,
"chat_id": chat_id,
"response": response,
"model": model,
"timestamp": str(datetime.datetime.utcnow()),
"duration_ms": duration_ms,
},
)
def load_template(name):
path = template_dir() / f"{name}.yaml"
if not path.exists():
@ -642,12 +610,12 @@ def get_history(chat_id):
migrate(db)
if chat_id == -1:
# Return the most recent chat
last_row = list(db["log"].rows_where(order_by="-id", limit=1))
last_row = list(db["logs"].rows_where(order_by="-id", limit=1))
if last_row:
chat_id = last_row[0].get("chat_id") or last_row[0].get("id")
else: # Database is empty
return None, []
rows = db["log"].rows_where(
rows = db["logs"].rows_where(
"id = ? or chat_id = ?", [chat_id, chat_id], order_by="id"
)
return chat_id, rows

View file

@ -97,13 +97,13 @@ class ChatResponse(Response):
self._response_json = response.to_dict_recursive()
yield response.choices[0].message.content
def to_log(self) -> LogMessage:
def log_message(self) -> LogMessage:
return LogMessage(
model=self.prompt.model.model_id,
prompt=self.prompt.prompt,
system=self.prompt.system,
options_json=not_nulls(self.prompt.options),
prompt_json=self._prompt_json,
options_json=not_nulls(self.prompt.options),
response=self.text(),
response_json=self.json(),
reply_to_id=None, # TODO

View file

@ -85,3 +85,55 @@ def m004_drop_provider(db):
def m005_debug(db):
db["log"].add_column("debug", str)
db["log"].add_column("duration_ms", int)
@migration
def m006_new_logs_table(db):
columns = db["log"].columns_dict
for column, type in (
("options_json", str),
("prompt_json", str),
("response_json", str),
("reply_to_id", int),
):
# It's possible people running development code like myself
# might have accidentally created these columns already
if column not in columns:
db["log"].add_column(column, type)
# Use .transform() to rename options and timestamp_utc, and set new order
db["log"].transform(
column_order=(
"id",
"model",
"prompt",
"system",
"prompt_json",
"options_json",
"response",
"response_json",
"reply_to_id",
"chat_id",
"duration_ms",
"timestamp_utc",
),
rename={
"timestamp": "timestamp_utc",
"options": "options_json",
},
)
@migration
def m007_finish_logs_table(db):
db["log"].transform(
drop={"debug"},
rename={"timestamp_utc": "datetime_utc"},
)
with db.conn:
db.execute("alter table log rename to logs")
@migration
def m008_reply_to_id_foreign_key(db):
db["logs"].add_foreign_key("reply_to_id", "logs", "id")

View file

@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, asdict
import datetime
import time
from typing import Any, Dict, Iterator, List, Optional, Set
@ -32,8 +32,8 @@ class LogMessage:
model: str # Actually the model.model_id string
prompt: str # Simplified string version of prompt
system: Optional[str] # Simplified string of system prompt
options_json: Dict[str, Any] # Any options e.g. temperature
prompt_json: Optional[Dict[str, Any]] # Detailed JSON of prompt
options_json: Dict[str, Any] # Any options e.g. temperature
response: str # Simplified string version of response
response_json: Optional[Dict[str, Any]] # Detailed JSON of response
reply_to_id: Optional[int] # ID of message this is a reply to
@ -95,15 +95,22 @@ class Response(ABC):
self._force()
return int((self._end - self._start) * 1000)
def timestamp_utc(self) -> str:
def datetime_utc(self) -> str:
self._force()
return self._start_utcnow.isoformat()
@abstractmethod
def to_log(self) -> LogMessage:
def log_message(self) -> LogMessage:
"Return a LogMessage of data to log"
pass
def log_to_db(self, db):
message = self.log_message()
message_dict = asdict(message)
message_dict["duration_ms"] = self.duration_ms()
message_dict["datetime_utc"] = self.datetime_utc()
db["logs"].insert(message_dict, pk="id")
class Model(ABC):
model_id: str

View file

@ -21,7 +21,7 @@ def test_logs(n, user_path):
log_path = str(user_path / "logs.db")
db = sqlite_utils.Database(log_path)
migrate(db)
db["log"].insert_all(
db["logs"].insert_all(
{
"system": "system",
"prompt": "prompt",
@ -66,7 +66,7 @@ def test_llm_default_prompt(mocked_openai, use_stdin, user_path):
# Reset the log_path database
log_path = user_path / "logs.db"
log_db = sqlite_utils.Database(str(log_path))
log_db["log"].delete_where()
log_db["logs"].delete_where()
runner = CliRunner()
prompt = "three names for a pet pelican"
input = None
@ -82,7 +82,7 @@ def test_llm_default_prompt(mocked_openai, use_stdin, user_path):
return
# Was it logged?
rows = list(log_db["log"].rows)
rows = list(log_db["logs"].rows)
assert len(rows) == 1
expected = {
"model": "gpt-3.5-turbo",

View file

@ -2,21 +2,38 @@ from llm.migrations import migrate
import sqlite_utils
EXPECTED = {
"id": int,
"model": str,
"prompt": str,
"system": str,
"prompt_json": str,
"options_json": str,
"response": str,
"response_json": str,
"reply_to_id": int,
"chat_id": int,
"duration_ms": int,
"datetime_utc": str,
}
def test_migrate_blank():
db = sqlite_utils.Database(memory=True)
migrate(db)
assert set(db.table_names()) == {"_llm_migrations", "log"}
assert db["log"].columns_dict == {
"id": int,
"model": str,
"timestamp": str,
"prompt": str,
"system": str,
"response": str,
"chat_id": int,
"debug": str,
"duration_ms": int,
}
assert set(db.table_names()) == {"_llm_migrations", "logs"}
assert db["logs"].columns_dict == EXPECTED
foreign_keys = db["logs"].foreign_keys
for expected_fk in (
sqlite_utils.db.ForeignKey(
table="logs", column="reply_to_id", other_table="logs", other_column="id"
),
sqlite_utils.db.ForeignKey(
table="logs", column="chat_id", other_table="logs", other_column="id"
),
):
assert expected_fk in foreign_keys
def test_migrate_from_original_schema():
@ -33,15 +50,5 @@ def test_migrate_from_original_schema():
},
)
migrate(db)
assert set(db.table_names()) == {"_llm_migrations", "log"}
assert db["log"].columns_dict == {
"id": int,
"model": str,
"timestamp": str,
"prompt": str,
"system": str,
"response": str,
"chat_id": int,
"debug": str,
"duration_ms": int,
}
assert set(db.table_names()) == {"_llm_migrations", "logs"}
assert db["logs"].columns_dict == EXPECTED