diff --git a/llm/cli.py b/llm/cli.py index ef54dbe..93ab433 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -33,8 +33,17 @@ def cli(): @click.option("-m", "--model", help="Model to use") @click.option("-s", "--stream", is_flag=True, help="Stream output") @click.option("-n", "--no-log", is_flag=True, help="Don't log to database") +@click.option( + "-c", + "--continue", + is_flag=False, + flag_value=-1, + help="Continue the last conversation. Optionally takes a chat ID of a specific conversation.", + default=None, + type=int, +) @click.option("--code", is_flag=True, help="System prompt to optimize for code output") -def chatgpt(prompt, system, gpt4, model, stream, no_log, code): +def chatgpt(prompt, system, gpt4, model, stream, no_log, code, chat_id): "Execute prompt against ChatGPT" if prompt is None: # Read from stdin instead @@ -49,6 +58,13 @@ def chatgpt(prompt, system, gpt4, model, stream, no_log, code): if code: system = CODE_SYSTEM_PROMPT messages = [] + chat_id, history = get_history(chat_id) + if history: + for entry in history: + if entry.get("system"): + messages.append({"role": "system", "content": entry["system"]}) + messages.append({"role": "user", "content": entry["prompt"]}) + messages.append({"role": "assistant", "content": entry["response"]}) if system: messages.append({"role": "system", "content": system}) messages.append({"role": "user", "content": prompt}) @@ -66,14 +82,14 @@ def chatgpt(prompt, system, gpt4, model, stream, no_log, code): print(content, end="") sys.stdout.flush() print("") - log(no_log, "chatgpt", system, prompt, "".join(response), model) + log(no_log, "chatgpt", system, prompt, "".join(response), model, chat_id) else: response = openai.ChatCompletion.create( model=model, messages=messages, ) content = response.choices[0].message.content - log(no_log, "chatgpt", system, prompt, content, model) + log(no_log, "chatgpt", system, prompt, content, model, chat_id) if code: content = unwrap_markdown(content) print(content) @@ -133,7 +149,7 @@ def get_log_db_path(): return os.path.expanduser("~/.llm/log.db") -def log(no_log, provider, system, prompt, response, model): +def log(no_log, provider, system, prompt, response, model, chat_id=None): if no_log: return log_path = get_log_db_path() @@ -145,6 +161,7 @@ def log(no_log, provider, system, prompt, response, model): "provider": provider, "system": system, "prompt": prompt, + "chat_id": chat_id, "response": response, "model": model, "timestamp": str(datetime.datetime.utcnow()), @@ -152,6 +169,37 @@ def log(no_log, provider, system, prompt, response, model): ) +def get_history(chat_id): + if chat_id is None: + return None, [] + log_path = get_log_db_path() + if not os.path.exists(log_path): + raise click.ClickException( + "This feature requires logging. Run `llm init-db` to create ~/.llm/log.db" + ) + db = sqlite_utils.Database(log_path) + # Check if the chat_id column exists in the DB. If not create it. This is a + # migration path for people who have been using llm before chat_id was + # added. + if db["log"].columns and "chat_id" not in { + column.name for column in db["log"].columns + }: + db["log"].add_column("chat_id", int) + if chat_id == -1: + # Return the most recent chat + last_row = list( + db["log"].rows_where(order_by="-rowid", limit=1, select="rowid, *") + ) + if last_row: + chat_id = last_row[0].get("chat_id") or last_row[0].get("rowid") + else: # Database is empty + return None, [] + rows = db["log"].rows_where( + "rowid = ? or chat_id = ?", [chat_id, chat_id], order_by="rowid" + ) + return chat_id, rows + + def unwrap_markdown(content): # Remove first and last line if they are triple backticks lines = [l for l in content.split("\n")]