llm/llm/cli.py

598 lines
16 KiB
Python
Raw Normal View History

2023-04-01 21:28:24 +00:00
import click
from click_default_group import DefaultGroup
2023-04-01 22:28:46 +00:00
import datetime
2023-04-02 01:52:46 +00:00
import json
2023-06-17 07:40:46 +00:00
from llm import Template
from .migrations import migrate
from .plugins import pm, get_plugins, get_model_aliases, get_models_with_aliases
2023-04-01 21:28:24 +00:00
import openai
import os
import pathlib
2023-06-17 07:40:46 +00:00
import pydantic
from runpy import run_module
2023-06-17 07:40:46 +00:00
import shutil
2023-04-01 22:28:46 +00:00
import sqlite_utils
2023-06-17 07:40:46 +00:00
from string import Template as StringTemplate
2023-04-01 21:28:24 +00:00
import sys
2023-06-16 07:47:15 +00:00
import time
2023-04-02 01:37:55 +00:00
import warnings
2023-06-17 07:40:46 +00:00
import yaml
2023-04-02 01:37:55 +00:00
warnings.simplefilter("ignore", ResourceWarning)
2023-04-01 21:28:24 +00:00
DEFAULT_MODEL = "gpt-3.5-turbo"
2023-06-17 07:40:46 +00:00
DEFAULT_TEMPLATE = "prompt: "
2023-04-01 21:28:24 +00:00
@click.group(
cls=DefaultGroup,
default="prompt",
2023-04-01 21:28:24 +00:00
default_if_no_args=True,
)
2023-04-01 22:00:16 +00:00
@click.version_option()
2023-04-01 21:28:24 +00:00
def cli():
2023-06-17 08:17:52 +00:00
"""
Access large language models from the command-line
Documentation: https://llm.datasette.io/
To get started, obtain an OpenAI key and set it like this:
\b
$ llm keys set openai
Enter key: ...
Then execute a prompt like this:
llm 'Five outrageous names for a pet pelican'
"""
2023-04-01 21:28:24 +00:00
@cli.command(name="prompt")
@click.argument("prompt", required=False)
2023-07-01 17:46:20 +00:00
@click.option("-s", "--system", help="System prompt to use")
@click.option("model_id", "-m", "--model", help="Model to use")
2023-06-17 07:40:46 +00:00
@click.option("-t", "--template", help="Template to use")
@click.option(
"-p",
"--param",
multiple=True,
type=(str, str),
help="Parameters for template",
)
@click.option("--no-stream", is_flag=True, help="Do not stream output")
2023-04-01 22:28:46 +00:00
@click.option("-n", "--no-log", is_flag=True, help="Don't log to database")
@click.option(
"_continue",
"-c",
"--continue",
is_flag=True,
flag_value=-1,
help="Continue the most recent conversation.",
)
@click.option(
"chat_id",
"--chat",
help="Continue the conversation with the given chat ID.",
type=int,
)
@click.option("--key", help="API key to use")
@click.option("--save", help="Save prompt with this template name")
2023-06-17 07:40:46 +00:00
def prompt(
prompt,
system,
model_id,
template,
param,
no_stream,
no_log,
_continue,
chat_id,
key,
save,
2023-06-17 07:40:46 +00:00
):
2023-06-17 08:17:52 +00:00
"""
Execute a prompt
Documentation: https://llm.datasette.io/en/stable/usage.html
"""
if prompt is None:
2023-06-17 07:40:46 +00:00
if template:
# If running a template only consume from stdin if it has data
if not sys.stdin.isatty():
prompt = sys.stdin.read()
elif not save:
# Hang waiting for input to stdin (unless --save)
2023-06-17 07:40:46 +00:00
prompt = sys.stdin.read()
model_aliases = get_model_aliases()
if save:
# We are saving their prompt/system/etc to a new template
# Fields to save: prompt, system, model - and more in the future
bad_options = []
for option, var in (
("--template", template),
("--continue", _continue),
("--chat", chat_id),
):
if var:
bad_options.append(option)
if bad_options:
raise click.ClickException(
"--save cannot be used with {}".format(", ".join(bad_options))
)
path = template_dir() / f"{save}.yaml"
to_save = {}
if model_id:
try:
to_save["model"] = model_aliases[model_id].model_id
except KeyError:
raise click.ClickException("'{}' is not a known model".format(model_id))
if prompt:
to_save["prompt"] = prompt
if system:
to_save["system"] = system
if param:
to_save["defaults"] = dict(param)
path.write_text(
yaml.dump(
to_save,
indent=4,
default_flow_style=False,
),
"utf-8",
)
return
2023-06-17 07:40:46 +00:00
if template:
params = dict(param)
# Cannot be used with system
if system:
raise click.ClickException("Cannot use -t/--template and --system together")
template_obj = load_template(template)
try:
prompt, system = template_obj.execute(prompt, params)
except Template.MissingVariables as ex:
raise click.ClickException(str(ex))
if model_id is None and template_obj.model:
model_id = template_obj.model
history_model = None
# TODO: Re-introduce --continue mode
# messages = []
# if _continue:
# _continue = -1
# if chat_id:
# raise click.ClickException("Cannot use --continue and --chat together")
# else:
# _continue = chat_id
# chat_id, history = get_history(_continue)
# history_model = None
# if history:
# for entry in history:
# if entry.get("system"):
# messages.append({"role": "system", "content": entry["system"]})
# messages.append({"role": "user", "content": entry["prompt"]})
# messages.append({"role": "assistant", "content": entry["response"]})
# history_model = entry["model"]
# Figure out which model we are using
if model_id is None:
model_id = history_model or DEFAULT_MODEL
# Now resolve the model
try:
model = model_aliases[model_id]
except KeyError:
raise click.ClickException("'{}' is not a known model".format(model_id))
# Provide the API key, if one is needed
if model.needs_key and not model.key:
model.key = get_key(key, model.needs_key, model.key_env_var)
2023-07-01 16:09:19 +00:00
should_stream = model.can_stream and not no_stream
if should_stream:
method = model.stream
else:
2023-07-01 16:09:19 +00:00
method = model.prompt
2023-07-01 16:09:19 +00:00
response = method(prompt, system)
if should_stream:
for chunk in response:
print(chunk, end="")
sys.stdout.flush()
print("")
2023-07-01 16:09:19 +00:00
else:
print(response.text())
# TODO: Figure out OpenAI exception handling
# TODO: Log to database
return
# Original code:
try:
2023-06-16 07:47:15 +00:00
debug = {}
if no_stream:
2023-06-16 07:47:15 +00:00
start = time.time()
response = openai.ChatCompletion.create(
model=model,
messages=messages,
)
2023-06-16 07:47:15 +00:00
debug["model"] = response.model
debug["usage"] = response.usage
content = response.choices[0].message.content
2023-06-16 07:47:15 +00:00
log(no_log, system, prompt, content, model, chat_id, debug, start)
print(content)
else:
2023-06-16 07:47:15 +00:00
start = time.time()
response = []
for chunk in openai.ChatCompletion.create(
model=model,
messages=messages,
stream=True,
):
2023-06-16 07:47:15 +00:00
debug["model"] = chunk.model
content = chunk["choices"][0].get("delta", {}).get("content")
if content is not None:
response.append(content)
print(content, end="")
sys.stdout.flush()
print("")
2023-06-16 07:47:15 +00:00
log(no_log, system, prompt, "".join(response), model, chat_id, debug, start)
except openai.error.AuthenticationError as ex:
raise click.ClickException("{}: {}".format(ex.error.type, ex.error.code))
except openai.error.OpenAIError as ex:
raise click.ClickException(str(ex))
2023-04-01 22:28:46 +00:00
@cli.command()
def init_db():
2023-06-17 08:17:52 +00:00
"""
2023-06-17 08:29:36 +00:00
Ensure the logs.db SQLite database exists
2023-06-17 08:17:52 +00:00
All subsequent prompts will be logged to this database.
"""
path = log_db_path()
if path.exists():
2023-04-01 22:28:46 +00:00
return
# Ensure directory exists
path.parent.mkdir(parents=True, exist_ok=True)
2023-04-01 22:28:46 +00:00
db = sqlite_utils.Database(path)
db.vacuum()
2023-04-01 21:28:24 +00:00
@cli.group()
def keys():
2023-06-17 08:17:52 +00:00
"Manage stored API keys for different models"
@keys.command(name="path")
def keys_path_command():
2023-06-17 08:17:52 +00:00
"Output the path to the keys.json file"
click.echo(keys_path())
def keys_path():
llm_keys_path = os.environ.get("LLM_KEYS_PATH")
if llm_keys_path:
return pathlib.Path(llm_keys_path)
else:
return user_dir() / "keys.json"
@keys.command(name="set")
@click.argument("name")
@click.option("--value", prompt="Enter key", hide_input=True, help="Value to set")
def set_(name, value):
"""
2023-06-17 08:17:52 +00:00
Save a key in the keys.json file
Example usage:
\b
$ llm keys set openai
Enter key: ...
"""
default = {"// Note": "This file stores secret API credentials. Do not share!"}
path = pathlib.Path(keys_path())
path.parent.mkdir(parents=True, exist_ok=True)
if not path.exists():
path.write_text(json.dumps(default))
try:
current = json.loads(path.read_text())
except json.decoder.JSONDecodeError:
current = default
current[name] = value
path.write_text(json.dumps(current, indent=2) + "\n")
@cli.group(
cls=DefaultGroup,
default="list",
default_if_no_args=True,
)
def logs():
2023-06-17 08:17:52 +00:00
"Tools for exploring logged prompts and responses"
@logs.command(name="path")
def logs_path():
2023-06-17 08:29:36 +00:00
"Output the path to the logs.db file"
click.echo(log_db_path())
@logs.command(name="list")
2023-04-02 01:52:46 +00:00
@click.option(
"-n",
"--count",
default=3,
help="Number of entries to show - 0 for all",
)
@click.option(
"-p",
"--path",
type=click.Path(readable=True, exists=True, dir_okay=False),
help="Path to log database",
)
@click.option("-t", "--truncate", is_flag=True, help="Truncate long strings in output")
def logs_list(count, path, truncate):
2023-06-17 08:17:52 +00:00
"Show recent logged prompts and their responses"
path = pathlib.Path(path or log_db_path())
if not path.exists():
2023-04-02 01:52:46 +00:00
raise click.ClickException("No log database found at {}".format(path))
db = sqlite_utils.Database(path)
migrate(db)
2023-06-15 21:09:19 +00:00
rows = list(db["log"].rows_where(order_by="-id", limit=count or None))
if truncate:
for row in rows:
row["prompt"] = _truncate_string(row["prompt"])
row["response"] = _truncate_string(row["response"])
2023-04-02 01:52:46 +00:00
click.echo(json.dumps(list(rows), indent=2))
@cli.group()
def models():
"Manage available models"
@models.command(name="list")
def models_list():
"List available models"
for model_with_aliases in get_models_with_aliases():
extra = ""
if model_with_aliases.aliases:
extra = " (aliases: {})".format(", ".join(model_with_aliases.aliases))
output = str(model_with_aliases.model) + extra
click.echo(output)
2023-06-17 07:40:46 +00:00
@cli.group()
def templates():
2023-06-17 08:17:52 +00:00
"Manage stored prompt templates"
2023-06-17 07:40:46 +00:00
@templates.command(name="list")
def templates_list():
2023-06-17 08:17:52 +00:00
"List available prompt templates"
2023-06-17 07:40:46 +00:00
path = template_dir()
pairs = []
for file in path.glob("*.yaml"):
name = file.stem
template = load_template(name)
text = []
if template.system:
text.append(f"system: {template.system}")
if template.prompt:
text.append(f" prompt: {template.prompt}")
else:
text = [template.prompt]
pairs.append((name, "".join(text).replace("\n", " ")))
try:
max_name_len = max(len(p[0]) for p in pairs)
except ValueError:
return
else:
fmt = "{name:<" + str(max_name_len) + "} : {prompt}"
for name, prompt in sorted(pairs):
text = fmt.format(name=name, prompt=prompt)
click.echo(display_truncated(text))
2023-06-17 07:40:46 +00:00
@cli.command(name="plugins")
def plugins_list():
"List installed plugins"
click.echo(json.dumps(get_plugins(), indent=2))
2023-06-17 07:40:46 +00:00
def display_truncated(text):
console_width = shutil.get_terminal_size()[0]
if len(text) > console_width:
return text[: console_width - 3] + "..."
else:
return text
@templates.command(name="show")
@click.argument("name")
def templates_show(name):
2023-06-17 08:17:52 +00:00
"Show the specified prompt template"
2023-06-17 07:40:46 +00:00
template = load_template(name)
click.echo(
yaml.dump(
dict((k, v) for k, v in template.dict().items() if v is not None),
indent=4,
default_flow_style=False,
)
)
@templates.command(name="edit")
@click.argument("name")
def templates_edit(name):
2023-06-17 08:17:52 +00:00
"Edit the specified prompt template using the default $EDITOR"
2023-06-17 07:40:46 +00:00
# First ensure it exists
path = template_dir() / f"{name}.yaml"
if not path.exists():
path.write_text(DEFAULT_TEMPLATE, "utf-8")
click.edit(filename=path)
# Validate that template
load_template(name)
@templates.command(name="path")
def templates_path():
2023-06-17 08:17:52 +00:00
"Output the path to the templates directory"
2023-06-17 07:40:46 +00:00
click.echo(template_dir())
@cli.command()
@click.argument("packages", nargs=-1, required=True)
@click.option(
"-U", "--upgrade", is_flag=True, help="Upgrade packages to latest version"
)
def install(packages, upgrade):
"""Install packages from PyPI into the same environment as LLM"""
args = ["pip", "install"]
if upgrade:
args += ["--upgrade"]
args += list(packages)
sys.argv = args
run_module("pip", run_name="__main__")
@cli.command()
@click.argument("packages", nargs=-1, required=True)
@click.option("-y", "--yes", is_flag=True, help="Don't ask for confirmation")
def uninstall(packages, yes):
"""Uninstall Python packages from the LLM environment"""
sys.argv = ["pip", "uninstall"] + list(packages) + (["-y"] if yes else [])
run_module("pip", run_name="__main__")
2023-06-17 07:40:46 +00:00
def template_dir():
llm_templates_path = os.environ.get("LLM_TEMPLATES_PATH")
if llm_templates_path:
path = pathlib.Path(llm_templates_path)
else:
path = user_dir() / "templates"
path.mkdir(parents=True, exist_ok=True)
return path
def _truncate_string(s, max_length=100):
if len(s) > max_length:
return s[: max_length - 3] + "..."
return s
def get_key(key_arg, default_key, env_var=None):
keys = load_keys()
if key_arg in keys:
return keys[key_arg]
if key_arg:
return key_arg
if env_var and os.environ.get(env_var):
return os.environ[env_var]
default = keys.get(default_key)
if not default:
message = "No key found - add one using 'llm keys set {}'".format(default_key)
if env_var:
message += " or set the {} environment variable".format(env_var)
raise click.ClickException(message)
return default
def load_keys():
path = pathlib.Path(keys_path())
if path.exists():
return json.loads(path.read_text())
else:
return {}
2023-04-01 22:28:46 +00:00
def user_dir():
return pathlib.Path(click.get_app_dir("io.datasette.llm"))
def log_db_path():
llm_log_path = os.environ.get("LLM_LOG_PATH")
if llm_log_path:
return pathlib.Path(llm_log_path)
else:
2023-06-17 08:29:36 +00:00
return user_dir() / "logs.db"
2023-04-01 22:28:46 +00:00
2023-06-16 07:47:15 +00:00
def log(no_log, system, prompt, response, model, chat_id=None, debug=None, start=None):
duration_ms = None
if start is not None:
end = time.time()
duration_ms = int((end - start) * 1000)
2023-04-01 22:28:46 +00:00
if no_log:
return
log_path = log_db_path()
if not log_path.exists():
2023-04-01 22:28:46 +00:00
return
db = sqlite_utils.Database(log_path)
migrate(db)
2023-04-01 22:28:46 +00:00
db["log"].insert(
{
"system": system,
"prompt": prompt,
"chat_id": chat_id,
2023-04-01 22:28:46 +00:00
"response": response,
"model": model,
"timestamp": str(datetime.datetime.utcnow()),
2023-06-16 07:47:15 +00:00
"debug": debug,
"duration_ms": duration_ms,
},
2023-04-01 22:28:46 +00:00
)
2023-04-02 01:36:27 +00:00
2023-06-17 07:40:46 +00:00
def load_template(name):
path = template_dir() / f"{name}.yaml"
if not path.exists():
raise click.ClickException(f"Invalid template: {name}")
try:
loaded = yaml.safe_load(path.read_text())
except yaml.YAMLError as ex:
raise click.ClickException("Invalid YAML: {}".format(str(ex)))
if isinstance(loaded, str):
return Template(name=name, prompt=loaded)
loaded["name"] = name
try:
2023-07-01 18:56:35 +00:00
return Template.model_validate(loaded)
2023-06-17 07:40:46 +00:00
except pydantic.ValidationError as e:
msg = "A validation error occurred:"
for error in e.errors():
msg += f"\n {error['loc'][0]}: {error['msg']}"
raise click.ClickException(msg)
def get_history(chat_id):
if chat_id is None:
return None, []
log_path = log_db_path()
if not log_path.exists():
raise click.ClickException(
2023-06-17 08:29:36 +00:00
"This feature requires logging. Run `llm init-db` to create logs.db"
)
db = sqlite_utils.Database(log_path)
migrate(db)
if chat_id == -1:
# Return the most recent chat
last_row = list(db["log"].rows_where(order_by="-id", limit=1))
if last_row:
chat_id = last_row[0].get("chat_id") or last_row[0].get("id")
else: # Database is empty
return None, []
rows = db["log"].rows_where(
"id = ? or chat_id = ?", [chat_id, chat_id], order_by="id"
)
return chat_id, rows
pm.hook.register_commands(cli=cli)