mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-10 00:10:58 +00:00
In progress llm embed-multi, refs #215
This commit is contained in:
parent
156bed7c65
commit
3d379fbddc
1 changed files with 126 additions and 0 deletions
126
llm/cli.py
126
llm/cli.py
|
|
@ -1,6 +1,7 @@
|
|||
import click
|
||||
from click_default_group import DefaultGroup
|
||||
from dataclasses import asdict
|
||||
import io
|
||||
import json
|
||||
from llm import (
|
||||
Collection,
|
||||
|
|
@ -29,6 +30,7 @@ import pydantic
|
|||
from runpy import run_module
|
||||
import shutil
|
||||
import sqlite_utils
|
||||
from sqlite_utils.utils import rows_from_file, Format
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import cast, Optional
|
||||
|
|
@ -967,6 +969,130 @@ def embed(collection, id, input, model, store, database, content, metadata, form
|
|||
click.echo(encode(embedding).hex())
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("collection")
|
||||
@click.argument(
|
||||
"input_path",
|
||||
type=click.Path(exists=True, dir_okay=False, allow_dash=True, readable=True),
|
||||
required=False,
|
||||
)
|
||||
@click.option(
|
||||
"--format",
|
||||
type=click.Choice(["json", "csv", "tsv", "nl"]),
|
||||
help="Format of input file - defaults to auto-detect",
|
||||
)
|
||||
@click.option(
|
||||
"--files",
|
||||
type=(click.Path(file_okay=False, dir_okay=True, allow_dash=False), str),
|
||||
multiple=True,
|
||||
help="Embed files in this directory - specify directory and glob pattern",
|
||||
)
|
||||
@click.option("--sql", help="Read input using this SQL query")
|
||||
@click.option(
|
||||
"--attach",
|
||||
type=(str, click.Path(file_okay=True, dir_okay=False, allow_dash=False)),
|
||||
multiple=True,
|
||||
help="Additional databases to attach - specify alias and file path",
|
||||
)
|
||||
@click.option("-m", "--model", help="Embedding model to use")
|
||||
@click.option("--store", is_flag=True, help="Store the text itself in the database")
|
||||
@click.option(
|
||||
"-d",
|
||||
"--database",
|
||||
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
|
||||
envvar="LLM_EMBEDDINGS_DB",
|
||||
)
|
||||
def embed_multi(
|
||||
collection, input_path, format, files, sql, attach, model, store, database
|
||||
):
|
||||
"""
|
||||
Store embeddings for multiple strings at once
|
||||
|
||||
Input can be CSV, TSV or a JSON list of objects.
|
||||
|
||||
The first column is treated as an ID - all other columns
|
||||
are assumed to be text that should be concatenated together
|
||||
in order to calculate the embeddings.
|
||||
"""
|
||||
if not input_path and not sql and not files:
|
||||
raise click.UsageError("Either --sql or input path or --files is required")
|
||||
|
||||
if files:
|
||||
if input_path or sql or format:
|
||||
raise click.UsageError(
|
||||
"Cannot use --files with --sql, input path or --format"
|
||||
)
|
||||
|
||||
if database:
|
||||
db = sqlite_utils.Database(database)
|
||||
else:
|
||||
db = sqlite_utils.Database(user_dir() / "embeddings.db")
|
||||
|
||||
for alias, attach_path in attach:
|
||||
db.attach(alias, attach_path)
|
||||
|
||||
collection_obj = Collection(
|
||||
collection, db=db, model_id=model or get_default_embedding_model()
|
||||
)
|
||||
|
||||
expected_length = None
|
||||
if files:
|
||||
|
||||
def count_files():
|
||||
i = 0
|
||||
for directory, pattern in files:
|
||||
for path in pathlib.Path(directory).glob(pattern):
|
||||
i += 1
|
||||
return i
|
||||
|
||||
def iterate_files():
|
||||
for directory, pattern in files:
|
||||
for path in pathlib.Path(directory).glob(pattern):
|
||||
relative = path.relative_to(directory)
|
||||
yield {"id": str(relative), "content": path.read_text()}
|
||||
|
||||
expected_length = count_files()
|
||||
rows = iterate_files()
|
||||
elif sql:
|
||||
rows = db.query(sql)
|
||||
count_sql = "select count(*) as c from ({})".format(sql)
|
||||
expected_length = next(db.query(count_sql))["c"]
|
||||
else:
|
||||
|
||||
def load_rows(fp):
|
||||
return rows_from_file(fp, Format[format.upper()] if format else None)[0]
|
||||
|
||||
try:
|
||||
if input_path != "-":
|
||||
# Read the file twice - first time is to get a count
|
||||
expected_length = 0
|
||||
with open(input_path, "rb") as fp:
|
||||
for _ in load_rows(fp):
|
||||
expected_length += 1
|
||||
|
||||
rows = load_rows(
|
||||
open(input_path, "rb")
|
||||
if input_path != "-"
|
||||
else io.BufferedReader(sys.stdin.buffer)
|
||||
)
|
||||
except json.JSONDecodeError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
with click.progressbar(
|
||||
rows, label="Embedding", show_percent=True, length=expected_length
|
||||
) as rows:
|
||||
|
||||
def tuples():
|
||||
for row in rows:
|
||||
values = list(row.values())
|
||||
id = values[0]
|
||||
text = " ".join(v or "" for v in values[1:])
|
||||
yield id, text
|
||||
|
||||
# collection_obj.max_batch_size = 1
|
||||
collection_obj.embed_multi(tuples(), store=store)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("collection")
|
||||
@click.argument("id", required=False)
|
||||
|
|
|
|||
Loading…
Reference in a new issue