In progress llm embed-multi, refs #215

This commit is contained in:
Simon Willison 2023-09-03 15:05:29 -07:00
parent 156bed7c65
commit 3d379fbddc

View file

@ -1,6 +1,7 @@
import click
from click_default_group import DefaultGroup
from dataclasses import asdict
import io
import json
from llm import (
Collection,
@ -29,6 +30,7 @@ import pydantic
from runpy import run_module
import shutil
import sqlite_utils
from sqlite_utils.utils import rows_from_file, Format
import sys
import textwrap
from typing import cast, Optional
@ -967,6 +969,130 @@ def embed(collection, id, input, model, store, database, content, metadata, form
click.echo(encode(embedding).hex())
@cli.command()
@click.argument("collection")
@click.argument(
"input_path",
type=click.Path(exists=True, dir_okay=False, allow_dash=True, readable=True),
required=False,
)
@click.option(
"--format",
type=click.Choice(["json", "csv", "tsv", "nl"]),
help="Format of input file - defaults to auto-detect",
)
@click.option(
"--files",
type=(click.Path(file_okay=False, dir_okay=True, allow_dash=False), str),
multiple=True,
help="Embed files in this directory - specify directory and glob pattern",
)
@click.option("--sql", help="Read input using this SQL query")
@click.option(
"--attach",
type=(str, click.Path(file_okay=True, dir_okay=False, allow_dash=False)),
multiple=True,
help="Additional databases to attach - specify alias and file path",
)
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
@click.option(
"-d",
"--database",
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
envvar="LLM_EMBEDDINGS_DB",
)
def embed_multi(
collection, input_path, format, files, sql, attach, model, store, database
):
"""
Store embeddings for multiple strings at once
Input can be CSV, TSV or a JSON list of objects.
The first column is treated as an ID - all other columns
are assumed to be text that should be concatenated together
in order to calculate the embeddings.
"""
if not input_path and not sql and not files:
raise click.UsageError("Either --sql or input path or --files is required")
if files:
if input_path or sql or format:
raise click.UsageError(
"Cannot use --files with --sql, input path or --format"
)
if database:
db = sqlite_utils.Database(database)
else:
db = sqlite_utils.Database(user_dir() / "embeddings.db")
for alias, attach_path in attach:
db.attach(alias, attach_path)
collection_obj = Collection(
collection, db=db, model_id=model or get_default_embedding_model()
)
expected_length = None
if files:
def count_files():
i = 0
for directory, pattern in files:
for path in pathlib.Path(directory).glob(pattern):
i += 1
return i
def iterate_files():
for directory, pattern in files:
for path in pathlib.Path(directory).glob(pattern):
relative = path.relative_to(directory)
yield {"id": str(relative), "content": path.read_text()}
expected_length = count_files()
rows = iterate_files()
elif sql:
rows = db.query(sql)
count_sql = "select count(*) as c from ({})".format(sql)
expected_length = next(db.query(count_sql))["c"]
else:
def load_rows(fp):
return rows_from_file(fp, Format[format.upper()] if format else None)[0]
try:
if input_path != "-":
# Read the file twice - first time is to get a count
expected_length = 0
with open(input_path, "rb") as fp:
for _ in load_rows(fp):
expected_length += 1
rows = load_rows(
open(input_path, "rb")
if input_path != "-"
else io.BufferedReader(sys.stdin.buffer)
)
except json.JSONDecodeError as ex:
raise click.ClickException(str(ex))
with click.progressbar(
rows, label="Embedding", show_percent=True, length=expected_length
) as rows:
def tuples():
for row in rows:
values = list(row.values())
id = values[0]
text = " ".join(v or "" for v in values[1:])
yield id, text
# collection_obj.max_batch_size = 1
collection_obj.embed_multi(tuples(), store=store)
@cli.command()
@click.argument("collection")
@click.argument("id", required=False)