llm/llm/cli.py
2023-05-17 14:08:53 -07:00

159 lines
4.6 KiB
Python

import click
from click_default_group import DefaultGroup
import datetime
import json
import openai
import os
import sqlite_utils
import sys
import warnings
# Suppress every ResourceWarning for the whole process (filter is installed
# at import time, so it applies to all commands below).
warnings.simplefilter("ignore", ResourceWarning)
# System prompt substituted for --system when `chatgpt --code` is used: it asks
# the model for bare code with no surrounding explanation. The triple-quoted
# literal is .strip()ped so the leading/trailing newlines are not sent.
CODE_SYSTEM_PROMPT = """
You are a code generating tool. Return just the code, with no explanation
or context other than comments in the code itself.
""".strip()
@click.group(cls=DefaultGroup, default="chatgpt", default_if_no_args=True)
@click.version_option()
def cli():
    """Access large language models from the command-line"""
    # Group entry point only. Because of default="chatgpt" and
    # default_if_no_args=True, DefaultGroup routes an invocation that names no
    # known subcommand (or has no arguments) to the `chatgpt` command.
@cli.command()
@click.argument("prompt", required=False)
@click.option("--system", help="System prompt to use")
@click.option("-4", "--gpt4", is_flag=True, help="Use GPT-4")
@click.option("-m", "--model", help="Model to use")
@click.option("-s", "--stream", is_flag=True, help="Stream output")
@click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
@click.option("--code", is_flag=True, help="System prompt to optimize for code output")
def chatgpt(prompt, system, gpt4, model, stream, no_log, code):
    "Execute prompt against ChatGPT"
    # Validate option combinations before doing any I/O, so a conflict is
    # reported immediately rather than after blocking on stdin or failing the
    # API-key lookup.
    if code and system:
        raise click.ClickException("Cannot use --code and --system together")
    if code:
        system = CODE_SYSTEM_PROMPT
    if gpt4:
        # -4/--gpt4 takes precedence over any explicit --model value.
        model = "gpt-4"
    if not model:
        model = "gpt-3.5-turbo"

    if prompt is None:
        # No positional prompt given: read it from stdin instead
        prompt = sys.stdin.read()
    openai.api_key = get_openai_api_key()

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})

    if stream:
        # Print each delta as it arrives and keep the pieces so the full
        # response can be logged once the stream ends.
        # NOTE: --code fences are not unwrapped here, since output has
        # already been printed incrementally.
        response = []
        for chunk in openai.ChatCompletion.create(
            model=model,
            messages=messages,
            stream=True,
        ):
            content = chunk["choices"][0].get("delta", {}).get("content")
            if content is not None:
                response.append(content)
                print(content, end="")
                sys.stdout.flush()
        print("")
        log(no_log, "chatgpt", system, prompt, "".join(response), model)
    else:
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
        )
        content = response.choices[0].message.content
        # Log the raw model output; markdown unwrapping only affects display.
        log(no_log, "chatgpt", system, prompt, content, model)
        if code:
            content = unwrap_markdown(content)
        print(content)
@cli.command()
def init_db():
    "Ensure ~/.llm/log.db SQLite database exists"
    db_path = get_log_db_path()
    if not os.path.exists(db_path):
        # Make sure the containing directory exists first.
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        # Open the new database and VACUUM it, presumably just to force the
        # file into existence; no tables are created here.
        sqlite_utils.Database(db_path).vacuum()
@cli.command()
@click.option(
    "-n",
    "--count",
    default=3,
    help="Number of entries to show - 0 for all",
)
@click.option(
    "-p",
    "--path",
    type=click.Path(readable=True, exists=True, dir_okay=False),
    help="Path to log database",
)
def logs(count, path):
    "Show recent prompts and responses from the log database"
    # --path is validated by click (must exist); the default path is not,
    # so check it explicitly for a friendly error.
    path = path or get_log_db_path()
    if not os.path.exists(path):
        raise click.ClickException(f"No log database found at {path}")
    db = sqlite_utils.Database(path)
    # Newest first; a count of 0 becomes limit=None, i.e. all rows.
    rows = db["log"].rows_where(order_by="-rowid", limit=count or None)
    click.echo(json.dumps(list(rows), indent=2))
def get_openai_api_key():
    """Return the OpenAI API key to use for requests.

    The OPENAI_API_KEY environment variable takes precedence; otherwise the
    key is read from ~/.openai-api-key.txt. Surrounding whitespace is stripped
    from either source, and an empty value is treated as "not set" so that an
    empty variable does not hide a valid key file.

    Raises:
        click.ClickException: if no non-empty key is found in either place.
    """
    key = os.environ.get("OPENAI_API_KEY", "").strip()
    if key:
        return key
    path = os.path.expanduser("~/.openai-api-key.txt")
    if os.path.exists(path):
        with open(path, encoding="utf-8") as fp:
            key = fp.read().strip()
        if key:
            return key
    raise click.ClickException(
        "No OpenAI API key found. Set OPENAI_API_KEY environment variable or create ~/.openai-api-key.txt"
    )
def get_log_db_path():
    """Return the log database location: ~/.llm/log.db with ~ expanded."""
    home_relative = "~/.llm/log.db"
    return os.path.expanduser(home_relative)
def _utc_timestamp():
    # Current UTC time as a naive "YYYY-MM-DD HH:MM:SS[.ffffff]" string, the
    # same format str(datetime.utcnow()) produced, without the deprecated call.
    return str(datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None))


def log(no_log, provider, system, prompt, response, model):
    """Append one prompt/response record to the "log" table.

    Does nothing when ``no_log`` is true, or when the log database file does
    not exist yet (it is only created by the ``init_db`` command).
    """
    if no_log:
        return
    log_path = get_log_db_path()
    if not os.path.exists(log_path):
        return
    db = sqlite_utils.Database(log_path)
    db["log"].insert(
        {
            "provider": provider,
            "system": system,
            "prompt": prompt,
            "response": response,
            "model": model,
            "timestamp": _utc_timestamp(),
        }
    )
def unwrap_markdown(content):
    """Remove a surrounding Markdown code fence from *content*.

    The first line is dropped if it starts with ``` (optionally followed by a
    language tag such as ```python). The last non-blank line is dropped, along
    with any blank lines after it, if it is exactly ```. Content without
    fences is returned unchanged.
    """
    lines = content.split("\n")
    if lines and lines[0].strip().startswith("```"):
        lines = lines[1:]
    # Look past trailing blank lines so "```\n" at the very end is still
    # recognised as the closing fence.
    end = len(lines)
    while end and not lines[end - 1].strip():
        end -= 1
    # `end` guard avoids IndexError when the content was only a fence line.
    if end and lines[end - 1].strip() == "```":
        lines = lines[: end - 1]
    return "\n".join(lines)