-c option for continuing a chat (using new chat_id column) (#14)

Refs #6

* Add a chat_id to requests
* Fail early if the log.db is absent
* Automatically add the chat_id column.

Co-authored-by: Simon Willison <swillison@gmail.com>
This commit is contained in:
Amjith Ramanujam 2023-06-13 23:03:35 -07:00 committed by GitHub
parent d222550f66
commit 37999ce641

View file

@ -33,8 +33,17 @@ def cli():
@click.option("-m", "--model", help="Model to use")
@click.option("-s", "--stream", is_flag=True, help="Stream output")
@click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
@click.option(
"-c",
"--continue",
is_flag=False,
flag_value=-1,
help="Continue the last conversation. Optionally takes a chat ID of a specific conversation.",
default=None,
type=int,
)
@click.option("--code", is_flag=True, help="System prompt to optimize for code output")
def chatgpt(prompt, system, gpt4, model, stream, no_log, code):
def chatgpt(prompt, system, gpt4, model, stream, no_log, code, chat_id):
"Execute prompt against ChatGPT"
if prompt is None:
# Read from stdin instead
@ -49,6 +58,13 @@ def chatgpt(prompt, system, gpt4, model, stream, no_log, code):
if code:
system = CODE_SYSTEM_PROMPT
messages = []
chat_id, history = get_history(chat_id)
if history:
for entry in history:
if entry.get("system"):
messages.append({"role": "system", "content": entry["system"]})
messages.append({"role": "user", "content": entry["prompt"]})
messages.append({"role": "assistant", "content": entry["response"]})
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
@ -66,14 +82,14 @@ def chatgpt(prompt, system, gpt4, model, stream, no_log, code):
print(content, end="")
sys.stdout.flush()
print("")
log(no_log, "chatgpt", system, prompt, "".join(response), model)
log(no_log, "chatgpt", system, prompt, "".join(response), model, chat_id)
else:
response = openai.ChatCompletion.create(
model=model,
messages=messages,
)
content = response.choices[0].message.content
log(no_log, "chatgpt", system, prompt, content, model)
log(no_log, "chatgpt", system, prompt, content, model, chat_id)
if code:
content = unwrap_markdown(content)
print(content)
@ -133,7 +149,7 @@ def get_log_db_path():
return os.path.expanduser("~/.llm/log.db")
def log(no_log, provider, system, prompt, response, model):
def log(no_log, provider, system, prompt, response, model, chat_id=None):
if no_log:
return
log_path = get_log_db_path()
@ -145,6 +161,7 @@ def log(no_log, provider, system, prompt, response, model):
"provider": provider,
"system": system,
"prompt": prompt,
"chat_id": chat_id,
"response": response,
"model": model,
"timestamp": str(datetime.datetime.utcnow()),
@ -152,6 +169,37 @@ def log(no_log, provider, system, prompt, response, model):
)
def get_history(chat_id):
    """Resolve a conversation and return (chat_id, rows) for it.

    chat_id semantics:
      * None -> no continuation requested; returns (None, []).
      * -1   -> sentinel meaning "continue the most recent conversation".
      * any other int -> a specific rowid/chat_id to continue.

    Raises click.ClickException when the log database does not exist,
    since continuing a chat requires logging to be enabled.
    """
    if chat_id is None:
        return None, []
    log_path = get_log_db_path()
    if not os.path.exists(log_path):
        raise click.ClickException(
            "This feature requires logging. Run `llm init-db` to create ~/.llm/log.db"
        )
    db = sqlite_utils.Database(log_path)
    # Lazy migration for databases created before the chat_id column
    # existed: add it on first use (only if the table already has columns,
    # i.e. the log table actually exists).
    column_names = {column.name for column in db["log"].columns}
    if column_names and "chat_id" not in column_names:
        db["log"].add_column("chat_id", int)
    if chat_id == -1:
        # Sentinel: look up the newest log row and adopt its conversation.
        recent = list(
            db["log"].rows_where(order_by="-rowid", limit=1, select="rowid, *")
        )
        if not recent:
            # Empty database — nothing to continue.
            return None, []
        # Prefer the row's chat_id; fall back to its rowid when the row
        # started a conversation before chat_id was populated.
        chat_id = recent[0].get("chat_id") or recent[0].get("rowid")
    rows = db["log"].rows_where(
        "rowid = ? or chat_id = ?", [chat_id, chat_id], order_by="rowid"
    )
    return chat_id, rows
def unwrap_markdown(content):
# Remove first and last line if they are triple backticks
lines = [l for l in content.split("\n")]