From 345ad0d2dc014e3bf8e697251686e4b48e3bb55d Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Mon, 3 Jul 2023 07:27:47 -0700 Subject: [PATCH] Implemented new logs database schema --- docs/logging.md | 15 +++++--- llm/cli.py | 40 ++------------------ llm/default_plugins/openai_models.py | 4 +- llm/migrations.py | 52 ++++++++++++++++++++++++++ llm/models.py | 15 ++++++-- tests/test_llm.py | 6 +-- tests/test_migrate.py | 55 ++++++++++++++++------------ 7 files changed, 112 insertions(+), 75 deletions(-) diff --git a/docs/logging.md b/docs/logging.md index 6c4e35c..e79775d 100644 --- a/docs/logging.md +++ b/docs/logging.md @@ -62,7 +62,7 @@ import sqlite_utils import re db = sqlite_utils.Database(memory=True) migrate(db) -schema = db["log"].schema +schema = db["logs"].schema def cleanup_sql(sql): first_line = sql.split('(')[0] @@ -75,16 +75,19 @@ cog.out( ) ]]] --> ```sql -CREATE TABLE "log" ( +CREATE TABLE "logs" ( [id] INTEGER PRIMARY KEY, [model] TEXT, - [timestamp] TEXT, [prompt] TEXT, [system] TEXT, + [prompt_json] TEXT, + [options_json] TEXT, [response] TEXT, - [chat_id] INTEGER REFERENCES [log]([id]), - [debug] TEXT, - [duration_ms] INTEGER + [response_json] TEXT, + [reply_to_id] INTEGER, + [chat_id] INTEGER REFERENCES "logs"([id]), + [duration_ms] INTEGER, + [datetime_utc] TEXT ); ``` diff --git a/llm/cli.py b/llm/cli.py index 77b78bd..7fcaf73 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1,7 +1,5 @@ import click from click_default_group import DefaultGroup -from dataclasses import asdict -import datetime import json from llm import Template from .migrations import migrate @@ -19,7 +17,6 @@ from runpy import run_module import shutil import sqlite_utils import sys -import time import warnings import yaml @@ -244,11 +241,7 @@ def prompt( db = sqlite_utils.Database(log_path) migrate(db) - log_message = response.to_log() - log_dict = asdict(log_message) - log_dict["duration_ms"] = response.duration_ms() - log_dict["timestamp_utc"] = response.timestamp_utc() - db["log2"].insert(log_dict, pk="id") + response.log_to_db(db) # TODO: Figure out OpenAI exception handling @@ -382,7 +375,7 @@ def logs_list(count, path, truncate): raise click.ClickException("No log database found at {}".format(path)) db = sqlite_utils.Database(path) migrate(db) - rows = list(db["log"].rows_where(order_by="-id", limit=count or None)) + rows = list(db["logs"].rows_where(order_by="-id", limit=count or None)) if truncate: for row in rows: row["prompt"] = _truncate_string(row["prompt"]) @@ -585,31 +578,6 @@ def logs_db_path(): return user_dir() / "logs.db" -def log(no_log, system, prompt, response, model, chat_id=None, start=None): - duration_ms = None - if start is not None: - end = time.time() - duration_ms = int((end - start) * 1000) - if no_log: - return - log_path = logs_db_path() - if not log_path.exists(): - return - db = sqlite_utils.Database(log_path) - migrate(db) - db["log"].insert( - { - "system": system, - "prompt": prompt, - "chat_id": chat_id, - "response": response, - "model": model, - "timestamp": str(datetime.datetime.utcnow()), - "duration_ms": duration_ms, - }, - ) - - def load_template(name): path = template_dir() / f"{name}.yaml" if not path.exists(): @@ -642,12 +610,12 @@ def get_history(chat_id): migrate(db) if chat_id == -1: # Return the most recent chat - last_row = list(db["log"].rows_where(order_by="-id", limit=1)) + last_row = list(db["logs"].rows_where(order_by="-id", limit=1)) if last_row: chat_id = last_row[0].get("chat_id") or last_row[0].get("id") else: # Database is empty return None, [] - rows = db["log"].rows_where( + rows = db["logs"].rows_where( "id = ? or chat_id = ?", [chat_id, chat_id], order_by="id" ) return chat_id, rows diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 3ce5824..956591d 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -97,13 +97,13 @@ class ChatResponse(Response): self._response_json = response.to_dict_recursive() yield response.choices[0].message.content - def to_log(self) -> LogMessage: + def log_message(self) -> LogMessage: return LogMessage( model=self.prompt.model.model_id, prompt=self.prompt.prompt, system=self.prompt.system, - options_json=not_nulls(self.prompt.options), prompt_json=self._prompt_json, + options_json=not_nulls(self.prompt.options), response=self.text(), response_json=self.json(), reply_to_id=None, # TODO diff --git a/llm/migrations.py b/llm/migrations.py index e3e47fb..2e8756b 100644 --- a/llm/migrations.py +++ b/llm/migrations.py @@ -85,3 +85,55 @@ def m004_drop_provider(db): def m005_debug(db): db["log"].add_column("debug", str) db["log"].add_column("duration_ms", int) + + +@migration +def m006_new_logs_table(db): + columns = db["log"].columns_dict + for column, type in ( + ("options_json", str), + ("prompt_json", str), + ("response_json", str), + ("reply_to_id", int), + ): + # It's possible people running development code like myself + # might have accidentally created these columns already + if column not in columns: + db["log"].add_column(column, type) + + # Use .transform() to rename options and timestamp_utc, and set new order + db["log"].transform( + column_order=( + "id", + "model", + "prompt", + "system", + "prompt_json", + "options_json", + "response", + "response_json", + "reply_to_id", + "chat_id", + "duration_ms", + "timestamp_utc", + ), + rename={ + "timestamp": "timestamp_utc", + "options": "options_json", + }, + ) + + +@migration +def m007_finish_logs_table(db): + db["log"].transform( + drop={"debug"}, + rename={"timestamp_utc": "datetime_utc"}, + ) + with db.conn: + db.execute("alter table log rename to logs") + + +@migration +def m008_reply_to_id_foreign_key(db): + db["logs"].add_foreign_key("reply_to_id", "logs", "id") diff --git a/llm/models.py b/llm/models.py index 195b5d9..46c7114 100644 --- a/llm/models.py +++ b/llm/models.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, asdict import datetime import time from typing import Any, Dict, Iterator, List, Optional, Set @@ -32,8 +32,8 @@ class LogMessage: model: str # Actually the model.model_id string prompt: str # Simplified string version of prompt system: Optional[str] # Simplified string of system prompt - options_json: Dict[str, Any] # Any options e.g. temperature prompt_json: Optional[Dict[str, Any]] # Detailed JSON of prompt + options_json: Dict[str, Any] # Any options e.g. temperature response: str # Simplified string version of response response_json: Optional[Dict[str, Any]] # Detailed JSON of response reply_to_id: Optional[int] # ID of message this is a reply to @@ -95,15 +95,22 @@ class Response(ABC): self._force() return int((self._end - self._start) * 1000) - def timestamp_utc(self) -> str: + def datetime_utc(self) -> str: self._force() return self._start_utcnow.isoformat() @abstractmethod - def to_log(self) -> LogMessage: + def log_message(self) -> LogMessage: "Return a LogMessage of data to log" pass + def log_to_db(self, db): + message = self.log_message() + message_dict = asdict(message) + message_dict["duration_ms"] = self.duration_ms() + message_dict["datetime_utc"] = self.datetime_utc() + db["logs"].insert(message_dict, pk="id") + class Model(ABC): model_id: str diff --git a/tests/test_llm.py b/tests/test_llm.py index c0de921..c68f092 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -21,7 +21,7 @@ def test_logs(n, user_path): log_path = str(user_path / "logs.db") db = sqlite_utils.Database(log_path) migrate(db) - db["log"].insert_all( + db["logs"].insert_all( { "system": "system", "prompt": "prompt", @@ -66,7 +66,7 @@ def test_llm_default_prompt(mocked_openai, use_stdin, user_path): # Reset the log_path database log_path = user_path / "logs.db" log_db = sqlite_utils.Database(str(log_path)) - log_db["log"].delete_where() + log_db["logs"].delete_where() runner = CliRunner() prompt = "three names for a pet pelican" input = None @@ -82,7 +82,7 @@ def test_llm_default_prompt(mocked_openai, use_stdin, user_path): return # Was it logged? - rows = list(log_db["log"].rows) + rows = list(log_db["logs"].rows) assert len(rows) == 1 expected = { "model": "gpt-3.5-turbo", diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 9450af7..64ffa39 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -2,21 +2,38 @@ from llm.migrations import migrate import sqlite_utils +EXPECTED = { + "id": int, + "model": str, + "prompt": str, + "system": str, + "prompt_json": str, + "options_json": str, + "response": str, + "response_json": str, + "reply_to_id": int, + "chat_id": int, + "duration_ms": int, + "datetime_utc": str, +} + + def test_migrate_blank(): db = sqlite_utils.Database(memory=True) migrate(db) - assert set(db.table_names()) == {"_llm_migrations", "log"} - assert db["log"].columns_dict == { - "id": int, - "model": str, - "timestamp": str, - "prompt": str, - "system": str, - "response": str, - "chat_id": int, - "debug": str, - "duration_ms": int, - } + assert set(db.table_names()) == {"_llm_migrations", "logs"} + assert db["logs"].columns_dict == EXPECTED + + foreign_keys = db["logs"].foreign_keys + for expected_fk in ( + sqlite_utils.db.ForeignKey( + table="logs", column="reply_to_id", other_table="logs", other_column="id" + ), + sqlite_utils.db.ForeignKey( + table="logs", column="chat_id", other_table="logs", other_column="id" + ), + ): + assert expected_fk in foreign_keys def test_migrate_from_original_schema(): @@ -33,15 +50,5 @@ def test_migrate_from_original_schema(): }, ) migrate(db) - assert set(db.table_names()) == {"_llm_migrations", "log"} - assert db["log"].columns_dict == { - "id": int, - "model": str, - "timestamp": str, - "prompt": str, - "system": str, - "response": str, - "chat_id": int, - "debug": str, - "duration_ms": int, - } + assert set(db.table_names()) == {"_llm_migrations", "logs"} + assert db["logs"].columns_dict == EXPECTED