mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-12 15:53:11 +00:00
-c option for continuing a chat (using new chat_id column) (#14)
Refs #6 * Add a chat_id to requests * Fail early if the log.db is absent * Automatically add the chat_id column. Co-authored-by: Simon Willison <swillison@gmail.com>
This commit is contained in:
parent
d222550f66
commit
37999ce641
1 changed files with 52 additions and 4 deletions
56
llm/cli.py
56
llm/cli.py
|
|
@ -33,8 +33,17 @@ def cli():
|
|||
@click.option("-m", "--model", help="Model to use")
|
||||
@click.option("-s", "--stream", is_flag=True, help="Stream output")
|
||||
@click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
|
||||
@click.option(
|
||||
"-c",
|
||||
"--continue",
|
||||
is_flag=False,
|
||||
flag_value=-1,
|
||||
help="Continue the last conversation. Optionally takes a chat ID of a specific conversation.",
|
||||
default=None,
|
||||
type=int,
|
||||
)
|
||||
@click.option("--code", is_flag=True, help="System prompt to optimize for code output")
|
||||
def chatgpt(prompt, system, gpt4, model, stream, no_log, code):
|
||||
def chatgpt(prompt, system, gpt4, model, stream, no_log, code, chat_id):
|
||||
"Execute prompt against ChatGPT"
|
||||
if prompt is None:
|
||||
# Read from stdin instead
|
||||
|
|
@ -49,6 +58,13 @@ def chatgpt(prompt, system, gpt4, model, stream, no_log, code):
|
|||
if code:
|
||||
system = CODE_SYSTEM_PROMPT
|
||||
messages = []
|
||||
chat_id, history = get_history(chat_id)
|
||||
if history:
|
||||
for entry in history:
|
||||
if entry.get("system"):
|
||||
messages.append({"role": "system", "content": entry["system"]})
|
||||
messages.append({"role": "user", "content": entry["prompt"]})
|
||||
messages.append({"role": "assistant", "content": entry["response"]})
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
|
@ -66,14 +82,14 @@ def chatgpt(prompt, system, gpt4, model, stream, no_log, code):
|
|||
print(content, end="")
|
||||
sys.stdout.flush()
|
||||
print("")
|
||||
log(no_log, "chatgpt", system, prompt, "".join(response), model)
|
||||
log(no_log, "chatgpt", system, prompt, "".join(response), model, chat_id)
|
||||
else:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
log(no_log, "chatgpt", system, prompt, content, model)
|
||||
log(no_log, "chatgpt", system, prompt, content, model, chat_id)
|
||||
if code:
|
||||
content = unwrap_markdown(content)
|
||||
print(content)
|
||||
|
|
@ -133,7 +149,7 @@ def get_log_db_path():
|
|||
return os.path.expanduser("~/.llm/log.db")
|
||||
|
||||
|
||||
def log(no_log, provider, system, prompt, response, model):
|
||||
def log(no_log, provider, system, prompt, response, model, chat_id=None):
|
||||
if no_log:
|
||||
return
|
||||
log_path = get_log_db_path()
|
||||
|
|
@ -145,6 +161,7 @@ def log(no_log, provider, system, prompt, response, model):
|
|||
"provider": provider,
|
||||
"system": system,
|
||||
"prompt": prompt,
|
||||
"chat_id": chat_id,
|
||||
"response": response,
|
||||
"model": model,
|
||||
"timestamp": str(datetime.datetime.utcnow()),
|
||||
|
|
@ -152,6 +169,37 @@ def log(no_log, provider, system, prompt, response, model):
|
|||
)
|
||||
|
||||
|
||||
def get_history(chat_id):
|
||||
if chat_id is None:
|
||||
return None, []
|
||||
log_path = get_log_db_path()
|
||||
if not os.path.exists(log_path):
|
||||
raise click.ClickException(
|
||||
"This feature requires logging. Run `llm init-db` to create ~/.llm/log.db"
|
||||
)
|
||||
db = sqlite_utils.Database(log_path)
|
||||
# Check if the chat_id column exists in the DB. If not create it. This is a
|
||||
# migration path for people who have been using llm before chat_id was
|
||||
# added.
|
||||
if db["log"].columns and "chat_id" not in {
|
||||
column.name for column in db["log"].columns
|
||||
}:
|
||||
db["log"].add_column("chat_id", int)
|
||||
if chat_id == -1:
|
||||
# Return the most recent chat
|
||||
last_row = list(
|
||||
db["log"].rows_where(order_by="-rowid", limit=1, select="rowid, *")
|
||||
)
|
||||
if last_row:
|
||||
chat_id = last_row[0].get("chat_id") or last_row[0].get("rowid")
|
||||
else: # Database is empty
|
||||
return None, []
|
||||
rows = db["log"].rows_where(
|
||||
"rowid = ? or chat_id = ?", [chat_id, chat_id], order_by="rowid"
|
||||
)
|
||||
return chat_id, rows
|
||||
|
||||
|
||||
def unwrap_markdown(content):
|
||||
# Remove first and last line if they are triple backticks
|
||||
lines = [l for l in content.split("\n")]
|
||||
|
|
|
|||
Loading…
Reference in a new issue