llm embed-multi --batch-size option, closes #273

This commit is contained in:
Simon Willison 2023-09-13 16:33:27 -07:00
parent b9478e6a17
commit 33dee4762e
4 changed files with 44 additions and 1 deletions

View file

@ -148,6 +148,7 @@ All three mechanisms support these options:
- `-d database.db` to specify a different database file to store the embeddings in
- `--store` to store the original content in the embeddings table in addition to the embedding vector
- `--prefix` to prepend a prefix to the stored ID of each item
- `--batch-size SIZE` to process embeddings in batches of the specified size
(embeddings-cli-embed-multi-csv-etc)=
### Embedding data from a CSV, TSV or JSON file

View file

@ -516,6 +516,7 @@ Options:
--sql TEXT Read input using this SQL query
--attach <TEXT FILE>... Additional databases to attach - specify alias
and file path
--batch-size INTEGER Batch size to use when running embeddings
--prefix TEXT Prefix to add to the IDs
-m, --model TEXT Embedding model to use
--store Store the text itself in the database

View file

@ -1182,6 +1182,9 @@ def embed(
multiple=True,
help="Additional databases to attach - specify alias and file path",
)
@click.option(
"--batch-size", type=int, help="Batch size to use when running embeddings"
)
@click.option("--prefix", help="Prefix to add to the IDs", default="")
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
@ -1200,6 +1203,7 @@ def embed_multi(
binary,
sql,
attach,
batch_size,
prefix,
model,
store,
@ -1324,7 +1328,10 @@ def embed_multi(
else:
yield id, " ".join(v or "" for v in values[1:])
collection_obj.embed_multi(tuples(), store=store)
embed_kwargs = {"store": store}
if batch_size:
embed_kwargs["batch_size"] = batch_size
collection_obj.embed_multi(tuples(), **embed_kwargs)
@cli.command()

View file

@ -369,6 +369,40 @@ def test_embed_multi_sql(tmpdir, use_other_db, prefix):
]
def test_embed_multi_batch_size(embed_demo, tmpdir):
db_path = str(tmpdir / "data.db")
runner = CliRunner()
sql = """
with recursive cte (id) as (
select 1
union all
select id+1 from cte where id < 100
)
select id, 'Row ' || cast(id as text) as value from cte
"""
assert getattr(embed_demo, "batch_count", 0) == 0
result = runner.invoke(
cli,
[
"embed-multi",
"rows",
"--sql",
sql,
"-d",
db_path,
"-m",
"embed-demo",
"--store",
"--batch-size",
"8",
],
)
assert result.exit_code == 0
db = sqlite_utils.Database(db_path)
assert db["embeddings"].count == 100
assert embed_demo.batch_count == 13
@pytest.fixture
def multi_files(tmpdir):
db_path = str(tmpdir / "files.db")