mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-17 05:00:25 +00:00
llm embed-multi --batch-size option, closes #273
This commit is contained in:
parent
b9478e6a17
commit
33dee4762e
4 changed files with 44 additions and 1 deletions
|
|
@ -148,6 +148,7 @@ All three mechanisms support these options:
|
|||
- `-d database.db` to specify a different database file to store the embeddings in
|
||||
- `--store` to store the original content in the embeddings table in addition to the embedding vector
|
||||
- `--prefix` to prepend a prefix to the stored ID of each item
|
||||
- `--batch-size SIZE` to process embeddings in batches of the specified size
|
||||
|
||||
(embeddings-cli-embed-multi-csv-etc)=
|
||||
### Embedding data from a CSV, TSV or JSON file
|
||||
|
|
|
|||
|
|
@ -516,6 +516,7 @@ Options:
|
|||
--sql TEXT Read input using this SQL query
|
||||
--attach <TEXT FILE>... Additional databases to attach - specify alias
|
||||
and file path
|
||||
--batch-size INTEGER Batch size to use when running embeddings
|
||||
--prefix TEXT Prefix to add to the IDs
|
||||
-m, --model TEXT Embedding model to use
|
||||
--store Store the text itself in the database
|
||||
|
|
|
|||
|
|
@ -1182,6 +1182,9 @@ def embed(
|
|||
multiple=True,
|
||||
help="Additional databases to attach - specify alias and file path",
|
||||
)
|
||||
@click.option(
|
||||
"--batch-size", type=int, help="Batch size to use when running embeddings"
|
||||
)
|
||||
@click.option("--prefix", help="Prefix to add to the IDs", default="")
|
||||
@click.option("-m", "--model", help="Embedding model to use")
|
||||
@click.option("--store", is_flag=True, help="Store the text itself in the database")
|
||||
|
|
@ -1200,6 +1203,7 @@ def embed_multi(
|
|||
binary,
|
||||
sql,
|
||||
attach,
|
||||
batch_size,
|
||||
prefix,
|
||||
model,
|
||||
store,
|
||||
|
|
@ -1324,7 +1328,10 @@ def embed_multi(
|
|||
else:
|
||||
yield id, " ".join(v or "" for v in values[1:])
|
||||
|
||||
collection_obj.embed_multi(tuples(), store=store)
|
||||
embed_kwargs = {"store": store}
|
||||
if batch_size:
|
||||
embed_kwargs["batch_size"] = batch_size
|
||||
collection_obj.embed_multi(tuples(), **embed_kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
|
|
|||
|
|
@ -369,6 +369,40 @@ def test_embed_multi_sql(tmpdir, use_other_db, prefix):
|
|||
]
|
||||
|
||||
|
||||
def test_embed_multi_batch_size(embed_demo, tmpdir):
|
||||
db_path = str(tmpdir / "data.db")
|
||||
runner = CliRunner()
|
||||
sql = """
|
||||
with recursive cte (id) as (
|
||||
select 1
|
||||
union all
|
||||
select id+1 from cte where id < 100
|
||||
)
|
||||
select id, 'Row ' || cast(id as text) as value from cte
|
||||
"""
|
||||
assert getattr(embed_demo, "batch_count", 0) == 0
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"embed-multi",
|
||||
"rows",
|
||||
"--sql",
|
||||
sql,
|
||||
"-d",
|
||||
db_path,
|
||||
"-m",
|
||||
"embed-demo",
|
||||
"--store",
|
||||
"--batch-size",
|
||||
"8",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
db = sqlite_utils.Database(db_path)
|
||||
assert db["embeddings"].count == 100
|
||||
assert embed_demo.batch_count == 13
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multi_files(tmpdir):
|
||||
db_path = str(tmpdir / "files.db")
|
||||
|
|
|
|||
Loading…
Reference in a new issue