From 3d379fbddc35faaf1e7b57a42a7d38b4eef5dcbe Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 3 Sep 2023 15:05:29 -0700 Subject: [PATCH] In progress llm embed-multi, refs #215 --- llm/cli.py | 126 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/llm/cli.py b/llm/cli.py index fa16488..0312415 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1,6 +1,7 @@ import click from click_default_group import DefaultGroup from dataclasses import asdict +import io import json from llm import ( Collection, @@ -29,6 +30,7 @@ import pydantic from runpy import run_module import shutil import sqlite_utils +from sqlite_utils.utils import rows_from_file, Format import sys import textwrap from typing import cast, Optional @@ -967,6 +969,130 @@ def embed(collection, id, input, model, store, database, content, metadata, form click.echo(encode(embedding).hex()) +@cli.command() +@click.argument("collection") +@click.argument( + "input_path", + type=click.Path(exists=True, dir_okay=False, allow_dash=True, readable=True), + required=False, +) +@click.option( + "--format", + type=click.Choice(["json", "csv", "tsv", "nl"]), + help="Format of input file - defaults to auto-detect", +) +@click.option( + "--files", + type=(click.Path(file_okay=False, dir_okay=True, allow_dash=False), str), + multiple=True, + help="Embed files in this directory - specify directory and glob pattern", +) +@click.option("--sql", help="Read input using this SQL query") +@click.option( + "--attach", + type=(str, click.Path(file_okay=True, dir_okay=False, allow_dash=False)), + multiple=True, + help="Additional databases to attach - specify alias and file path", +) +@click.option("-m", "--model", help="Embedding model to use") +@click.option("--store", is_flag=True, help="Store the text itself in the database") +@click.option( + "-d", + "--database", + type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True), + envvar="LLM_EMBEDDINGS_DB", +) +def embed_multi( + collection, input_path, format, files, sql, attach, model, store, database +): + """ + Store embeddings for multiple strings at once + + Input can be CSV, TSV or a JSON list of objects. + + The first column is treated as an ID - all other columns + are assumed to be text that should be concatenated together + in order to calculate the embeddings. + """ + if not input_path and not sql and not files: + raise click.UsageError("Either --sql or input path or --files is required") + + if files: + if input_path or sql or format: + raise click.UsageError( + "Cannot use --files with --sql, input path or --format" + ) + + if database: + db = sqlite_utils.Database(database) + else: + db = sqlite_utils.Database(user_dir() / "embeddings.db") + + for alias, attach_path in attach: + db.attach(alias, attach_path) + + collection_obj = Collection( + collection, db=db, model_id=model or get_default_embedding_model() + ) + + expected_length = None + if files: + + def count_files(): + i = 0 + for directory, pattern in files: + for path in pathlib.Path(directory).glob(pattern): + i += 1 + return i + + def iterate_files(): + for directory, pattern in files: + for path in pathlib.Path(directory).glob(pattern): + relative = path.relative_to(directory) + yield {"id": str(relative), "content": path.read_text()} + + expected_length = count_files() + rows = iterate_files() + elif sql: + rows = db.query(sql) + count_sql = "select count(*) as c from ({})".format(sql) + expected_length = next(db.query(count_sql))["c"] + else: + + def load_rows(fp): + return rows_from_file(fp, Format[format.upper()] if format else None)[0] + + try: + if input_path != "-": + # Read the file twice - first time is to get a count + expected_length = 0 + with open(input_path, "rb") as fp: + for _ in load_rows(fp): + expected_length += 1 + + rows = load_rows( + open(input_path, "rb") + if input_path != "-" + else io.BufferedReader(sys.stdin.buffer) + ) + except json.JSONDecodeError as ex: + raise click.ClickException(str(ex)) + + with click.progressbar( + rows, label="Embedding", show_percent=True, length=expected_length + ) as rows: + + def tuples(): + for row in rows: + values = list(row.values()) + id = values[0] + text = " ".join(v or "" for v in values[1:]) + yield id, text + + # collection_obj.max_batch_size = 1 + collection_obj.embed_multi(tuples(), store=store) + + @cli.command() @click.argument("collection") @click.argument("id", required=False)