mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-28 14:58:18 +00:00
In progress llm embed-multi, refs #215
This commit is contained in:
parent
156bed7c65
commit
3d379fbddc
1 changed file with 126 additions and 0 deletions
126
llm/cli.py
126
llm/cli.py
|
|
@ -1,6 +1,7 @@
|
||||||
import click
|
import click
|
||||||
from click_default_group import DefaultGroup
|
from click_default_group import DefaultGroup
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
from llm import (
|
from llm import (
|
||||||
Collection,
|
Collection,
|
||||||
|
|
@ -29,6 +30,7 @@ import pydantic
|
||||||
from runpy import run_module
|
from runpy import run_module
|
||||||
import shutil
|
import shutil
|
||||||
import sqlite_utils
|
import sqlite_utils
|
||||||
|
from sqlite_utils.utils import rows_from_file, Format
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import cast, Optional
|
from typing import cast, Optional
|
||||||
|
|
@ -967,6 +969,130 @@ def embed(collection, id, input, model, store, database, content, metadata, form
|
||||||
click.echo(encode(embedding).hex())
|
click.echo(encode(embedding).hex())
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
@click.argument("collection")
@click.argument(
    "input_path",
    type=click.Path(exists=True, dir_okay=False, allow_dash=True, readable=True),
    required=False,
)
@click.option(
    "--format",
    type=click.Choice(["json", "csv", "tsv", "nl"]),
    help="Format of input file - defaults to auto-detect",
)
@click.option(
    "--files",
    type=(click.Path(file_okay=False, dir_okay=True, allow_dash=False), str),
    multiple=True,
    help="Embed files in this directory - specify directory and glob pattern",
)
@click.option("--sql", help="Read input using this SQL query")
@click.option(
    "--attach",
    type=(str, click.Path(file_okay=True, dir_okay=False, allow_dash=False)),
    multiple=True,
    help="Additional databases to attach - specify alias and file path",
)
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
@click.option(
    "-d",
    "--database",
    type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
    envvar="LLM_EMBEDDINGS_DB",
)
def embed_multi(
    collection, input_path, format, files, sql, attach, model, store, database
):
    """
    Store embeddings for multiple strings at once

    Input can be CSV, TSV or a JSON list of objects.

    The first column is treated as an ID - all other columns
    are assumed to be text that should be concatenated together
    in order to calculate the embeddings.
    """
    # Parameters (all supplied by click; the docstring above is the --help text,
    # so the developer-facing notes live in comments instead):
    #   collection  - name of the collection to embed into
    #   input_path  - file to read rows from, or "-" for standard input
    #   format      - explicit input format; None means auto-detect
    #   files       - (directory, glob pattern) pairs; each matching file is a row
    #   sql         - SQL query whose result rows are embedded
    #   attach      - (alias, path) pairs of extra databases to attach
    #   model       - embedding model ID; falls back to the default model
    #   store       - passed through to Collection.embed_multi(store=...)
    #   database    - path to the embeddings database; defaults to the user dir
    #
    # Raises click.UsageError for invalid option combinations and
    # click.ClickException when the input cannot be decoded as JSON.
    if not input_path and not sql and not files:
        raise click.UsageError("Either --sql or input path or --files is required")

    if files:
        if input_path or sql or format:
            raise click.UsageError(
                "Cannot use --files with --sql, input path or --format"
            )

    if database:
        db = sqlite_utils.Database(database)
    else:
        db = sqlite_utils.Database(user_dir() / "embeddings.db")

    for alias, attach_path in attach:
        db.attach(alias, attach_path)

    collection_obj = Collection(
        collection, db=db, model_id=model or get_default_embedding_model()
    )

    # Total row count for the progress bar; stays None when it cannot be
    # known up front (input streamed from stdin).
    expected_length = None
    # File handle that must remain open while rows are lazily consumed;
    # closed in the ``finally`` below so it is never leaked.
    input_fp = None
    try:
        if files:

            def matching_files():
                # Yield (directory, path) for every regular file matched by a
                # --files pair. Globs can match directories too; those are
                # skipped because they cannot be read as text. Sharing this
                # generator keeps the count and the rows in agreement.
                for directory, pattern in files:
                    for path in pathlib.Path(directory).glob(pattern):
                        if path.is_file():
                            yield directory, path

            def iterate_files():
                for directory, path in matching_files():
                    relative = path.relative_to(directory)
                    yield {"id": str(relative), "content": path.read_text()}

            expected_length = sum(1 for _ in matching_files())
            rows = iterate_files()
        elif sql:
            rows = db.query(sql)
            count_sql = "select count(*) as c from ({})".format(sql)
            expected_length = next(db.query(count_sql))["c"]
        else:

            def load_rows(fp):
                return rows_from_file(fp, Format[format.upper()] if format else None)[0]

            try:
                if input_path != "-":
                    # Read the file twice - first time is to get a count
                    with open(input_path, "rb") as counting_fp:
                        expected_length = sum(1 for _ in load_rows(counting_fp))
                    input_fp = open(input_path, "rb")
                    rows = load_rows(input_fp)
                else:
                    # stdin can only be consumed once, so no count is taken
                    rows = load_rows(io.BufferedReader(sys.stdin.buffer))
            except json.JSONDecodeError as ex:
                raise click.ClickException(str(ex)) from ex

        with click.progressbar(
            rows, label="Embedding", show_percent=True, length=expected_length
        ) as bar:

            def tuples():
                # Turn each row into (id, text): the first value is the ID and
                # the remaining values are joined with spaces. None becomes an
                # empty string and other non-string values (e.g. numbers from
                # JSON or SQL) are converted with str() so the join cannot
                # raise TypeError.
                for row in bar:
                    values = list(row.values())
                    row_id = values[0]
                    text = " ".join(
                        "" if value is None else str(value) for value in values[1:]
                    )
                    yield row_id, text

            collection_obj.embed_multi(tuples(), store=store)
    finally:
        if input_fp is not None:
            input_fp.close()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("collection")
|
@click.argument("collection")
|
||||||
@click.argument("id", required=False)
|
@click.argument("id", required=False)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue