mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-24 15:04:46 +00:00
Binary embeddings (#254)
* Binary embeddings support, refs #253 * Write binary content to content_blob, with tests - refs #253 * supports_text and supports_binary embedding validation, refs #253
This commit is contained in:
parent
4fab55f253
commit
52cec1304b
11 changed files with 227 additions and 58 deletions
|
|
@ -10,8 +10,13 @@ embedding_model = llm.get_embedding_model("ada-002")
|
|||
To embed a string, returning a Python list of floating point numbers, use the `.embed()` method:
|
||||
```python
|
||||
vector = embedding_model.embed("my happy hound")
|
||||
|
||||
If the embedding model can handle binary input, you can call `.embed()` with a byte string instead. You can check the `supports_binary` property to see if this is supported:
|
||||
```python
|
||||
if embedding_model.supports_binary:
|
||||
vector = embedding_model.embed(open("my-image.jpg", "rb").read())
|
||||
```
|
||||
Many embeddings models are more efficient when you embed multiple strings at once. To embed multiple strings at once, use the `.embed_multi()` method:
|
||||
Many embeddings models are more efficient when you embed multiple strings or binary strings at once. To embed multiple strings at once, use the `.embed_multi()` method:
|
||||
```python
|
||||
vectors = list(embedding_model.embed_multi(["my happy hound", "my dissatisfied cat"]))
|
||||
```
|
||||
|
|
@ -63,7 +68,7 @@ This additional metadata will be stored as JSON in the `metadata` column of the
|
|||
(embeddings-python-bulk)=
|
||||
### Storing embeddings in bulk
|
||||
|
||||
The `collection.embed_multi()` method can be used to store embeddings for multiple strings at once. This can be more efficient for some embedding models.
|
||||
The `collection.embed_multi()` method can be used to store embeddings for multiple items at once. This can be more efficient for some embedding models.
|
||||
|
||||
```python
|
||||
collection.embed_multi(
|
||||
|
|
@ -177,6 +182,7 @@ CREATE TABLE "embeddings" (
|
|||
[id] TEXT,
|
||||
[embedding] BLOB,
|
||||
[content] TEXT,
|
||||
[content_blob] BLOB,
|
||||
[content_hash] BLOB,
|
||||
[metadata] TEXT,
|
||||
[updated] INTEGER,
|
||||
|
|
|
|||
|
|
@ -476,11 +476,12 @@ Usage: llm embed [OPTIONS] [COLLECTION] [ID]
|
|||
Embed text and store or return the result
|
||||
|
||||
Options:
|
||||
-i, --input FILENAME File to embed
|
||||
-i, --input PATH File to embed
|
||||
-m, --model TEXT Embedding model to use
|
||||
--store Store the text itself in the database
|
||||
-d, --database FILE
|
||||
-c, --content TEXT Content to embed
|
||||
--binary Treat input as binary data
|
||||
--metadata TEXT JSON object metadata to store
|
||||
-f, --format [json|blob|base64|hex]
|
||||
Output format
|
||||
|
|
@ -511,6 +512,7 @@ Options:
|
|||
--files <DIRECTORY TEXT>... Embed files in this directory - specify directory
|
||||
and glob pattern
|
||||
--encoding TEXT Encoding to use when reading --files
|
||||
--binary Treat --files as binary data
|
||||
--sql TEXT Read input using this SQL query
|
||||
--attach <TEXT FILE>... Additional databases to attach - specify alias
|
||||
and file path
|
||||
|
|
|
|||
47
llm/cli.py
47
llm/cli.py
|
|
@ -34,7 +34,7 @@ import sqlite_utils
|
|||
from sqlite_utils.utils import rows_from_file, Format
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import cast, Optional
|
||||
from typing import cast, Optional, Iterable, Union, Tuple
|
||||
import warnings
|
||||
import yaml
|
||||
|
||||
|
|
@ -1025,7 +1025,7 @@ def uninstall(packages, yes):
|
|||
@click.option(
|
||||
"-i",
|
||||
"--input",
|
||||
type=click.File("r"),
|
||||
type=click.Path(exists=True, readable=True, allow_dash=True),
|
||||
help="File to embed",
|
||||
)
|
||||
@click.option("-m", "--model", help="Embedding model to use")
|
||||
|
|
@ -1041,6 +1041,7 @@ def uninstall(packages, yes):
|
|||
"--content",
|
||||
help="Content to embed",
|
||||
)
|
||||
@click.option("--binary", is_flag=True, help="Treat input as binary data")
|
||||
@click.option(
|
||||
"--metadata",
|
||||
help="JSON object metadata to store",
|
||||
|
|
@ -1053,7 +1054,9 @@ def uninstall(packages, yes):
|
|||
type=click.Choice(["json", "blob", "base64", "hex"]),
|
||||
help="Output format",
|
||||
)
|
||||
def embed(collection, id, input, model, store, database, content, metadata, format_):
|
||||
def embed(
|
||||
collection, id, input, model, store, database, content, binary, metadata, format_
|
||||
):
|
||||
"""Embed text and store or return the result"""
|
||||
if collection and not id:
|
||||
raise click.ClickException("Must provide both collection and id")
|
||||
|
|
@ -1101,10 +1104,15 @@ def embed(collection, id, input, model, store, database, content, metadata, form
|
|||
|
||||
# Resolve input text
|
||||
if not content:
|
||||
if not input:
|
||||
if not input or input == "-":
|
||||
# Read from stdin
|
||||
input = sys.stdin
|
||||
content = input.read()
|
||||
input_source = sys.stdin.buffer if binary else sys.stdin
|
||||
content = input_source.read()
|
||||
else:
|
||||
mode = "rb" if binary else "r"
|
||||
with open(input, mode) as f:
|
||||
content = f.read()
|
||||
|
||||
if not content:
|
||||
raise click.ClickException("No content provided")
|
||||
|
||||
|
|
@ -1148,6 +1156,7 @@ def embed(collection, id, input, model, store, database, content, metadata, form
|
|||
help="Encoding to use when reading --files",
|
||||
multiple=True,
|
||||
)
|
||||
@click.option("--binary", is_flag=True, help="Treat --files as binary data")
|
||||
@click.option("--sql", help="Read input using this SQL query")
|
||||
@click.option(
|
||||
"--attach",
|
||||
|
|
@ -1170,6 +1179,7 @@ def embed_multi(
|
|||
format,
|
||||
files,
|
||||
encodings,
|
||||
binary,
|
||||
sql,
|
||||
attach,
|
||||
prefix,
|
||||
|
|
@ -1193,6 +1203,10 @@ def embed_multi(
|
|||
2. A SQL query against a SQLite database
|
||||
3. A directory of files
|
||||
"""
|
||||
if binary and not files:
|
||||
raise click.UsageError("--binary must be used with --files")
|
||||
if binary and encodings:
|
||||
raise click.UsageError("--binary cannot be used with --encoding")
|
||||
if not input_path and not sql and not files:
|
||||
raise click.UsageError("Either --sql or input path or --files is required")
|
||||
|
||||
|
|
@ -1235,11 +1249,14 @@ def embed_multi(
|
|||
for path in pathlib.Path(directory).glob(pattern):
|
||||
relative = path.relative_to(directory)
|
||||
content = None
|
||||
for encoding in encodings:
|
||||
try:
|
||||
content = path.read_text(encoding=encoding)
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
if binary:
|
||||
content = path.read_bytes()
|
||||
else:
|
||||
for encoding in encodings:
|
||||
try:
|
||||
content = path.read_text(encoding=encoding)
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
if content is None:
|
||||
# Log to stderr
|
||||
click.echo(
|
||||
|
|
@ -1280,12 +1297,14 @@ def embed_multi(
|
|||
rows, label="Embedding", show_percent=True, length=expected_length
|
||||
) as rows:
|
||||
|
||||
def tuples():
|
||||
def tuples() -> Iterable[Tuple[str, Union[bytes, str]]]:
|
||||
for row in rows:
|
||||
values = list(row.values())
|
||||
id = prefix + str(values[0])
|
||||
text = " ".join(v or "" for v in values[1:])
|
||||
yield id, text
|
||||
if binary:
|
||||
yield id, cast(bytes, values[1])
|
||||
else:
|
||||
yield id, " ".join(v or "" for v in values[1:])
|
||||
|
||||
# collection_obj.max_batch_size = 1
|
||||
collection_obj.embed_multi(tuples(), store=store)
|
||||
|
|
|
|||
|
|
@ -67,9 +67,9 @@ class Ada002(EmbeddingModel):
|
|||
key_env_var = "OPENAI_API_KEY"
|
||||
batch_size = 100 # Maybe this should be 2048
|
||||
|
||||
def embed_batch(self, texts: Iterable[str]) -> Iterator[List[float]]:
|
||||
def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:
|
||||
results = openai.Embedding.create(
|
||||
input=texts, model="text-embedding-ada-002", api_key=self.get_key()
|
||||
input=items, model="text-embedding-ada-002", api_key=self.get_key()
|
||||
)["data"]
|
||||
return ([float(r) for r in result["embedding"]] for result in results)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import json
|
|||
from sqlite_utils import Database
|
||||
from sqlite_utils.db import Table
|
||||
import time
|
||||
from typing import cast, Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import cast, Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -117,33 +117,34 @@ class Collection:
|
|||
def embed(
|
||||
self,
|
||||
id: str,
|
||||
text: str,
|
||||
value: Union[str, bytes],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
store: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Embed text and store it in the collection with a given ID.
|
||||
Embed value and store it in the collection with a given ID.
|
||||
|
||||
Args:
|
||||
id (str): ID for the text
|
||||
text (str): Text to be embedded
|
||||
id (str): ID for the value
|
||||
value (str or bytes): value to be embedded
|
||||
metadata (dict, optional): Metadata to be stored
|
||||
store (bool, optional): Whether to store the text in the content column
|
||||
store (bool, optional): Whether to store the value in the content or content_blob column
|
||||
"""
|
||||
from llm import encode
|
||||
|
||||
content_hash = self.content_hash(text)
|
||||
content_hash = self.content_hash(value)
|
||||
if self.db["embeddings"].count_where(
|
||||
"content_hash = ? and collection_id = ?", [content_hash, self.id]
|
||||
):
|
||||
return
|
||||
embedding = self.model().embed(text)
|
||||
embedding = self.model().embed(value)
|
||||
cast(Table, self.db["embeddings"]).insert(
|
||||
{
|
||||
"collection_id": self.id,
|
||||
"id": id,
|
||||
"embedding": encode(embedding),
|
||||
"content": text if store else None,
|
||||
"content": value if (store and isinstance(value, str)) else None,
|
||||
"content_blob": value if (store and isinstance(value, bytes)) else None,
|
||||
"content_hash": content_hash,
|
||||
"metadata": json.dumps(metadata) if metadata else None,
|
||||
"updated": int(time.time()),
|
||||
|
|
@ -152,7 +153,7 @@ class Collection:
|
|||
)
|
||||
|
||||
def embed_multi(
|
||||
self, entries: Iterable[Tuple[str, str]], store: bool = False
|
||||
self, entries: Iterable[Tuple[str, Union[str, bytes]]], store: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Embed multiple texts and store them in the collection with given IDs.
|
||||
|
|
@ -162,20 +163,20 @@ class Collection:
|
|||
store (bool, optional): Whether to store the text in the content column
|
||||
"""
|
||||
self.embed_multi_with_metadata(
|
||||
((id, text, None) for id, text in entries), store=store
|
||||
((id, value, None) for id, value in entries), store=store
|
||||
)
|
||||
|
||||
def embed_multi_with_metadata(
|
||||
self,
|
||||
entries: Iterable[Tuple[str, str, Optional[Dict[str, Any]]]],
|
||||
entries: Iterable[Tuple[str, Union[str, bytes], Optional[Dict[str, Any]]]],
|
||||
store: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Embed multiple texts along with metadata and store them in the collection with given IDs.
|
||||
Embed multiple values along with metadata and store them in the collection with given IDs.
|
||||
|
||||
Args:
|
||||
entries (iterable): Iterable of (id: str, text: str, metadata: None or dict)
|
||||
store (bool, optional): Whether to store the text in the content column
|
||||
entries (iterable): Iterable of (id: str, value: str or bytes, metadata: None or dict)
|
||||
store (bool, optional): Whether to store the value in the content or content_blob column
|
||||
"""
|
||||
import llm
|
||||
|
||||
|
|
@ -215,12 +216,17 @@ class Collection:
|
|||
"collection_id": collection_id,
|
||||
"id": id,
|
||||
"embedding": llm.encode(embedding),
|
||||
"content": text if store else None,
|
||||
"content_hash": self.content_hash(text),
|
||||
"content": value
|
||||
if (store and isinstance(value, str))
|
||||
else None,
|
||||
"content_blob": value
|
||||
if (store and isinstance(value, bytes))
|
||||
else None,
|
||||
"content_hash": self.content_hash(value),
|
||||
"metadata": json.dumps(metadata) if metadata else None,
|
||||
"updated": int(time.time()),
|
||||
}
|
||||
for (embedding, (id, text, metadata)) in zip(
|
||||
for (embedding, (id, value, metadata)) in zip(
|
||||
embeddings, filtered_batch
|
||||
)
|
||||
),
|
||||
|
|
@ -300,18 +306,18 @@ class Collection:
|
|||
comparison_vector = llm.decode(embedding)
|
||||
return self.similar_by_vector(comparison_vector, number, skip_id=id)
|
||||
|
||||
def similar(self, text: str, number: int = 10) -> List[Entry]:
|
||||
def similar(self, value: Union[str, bytes], number: int = 10) -> List[Entry]:
|
||||
"""
|
||||
Find similar items in the collection by a given text.
|
||||
Find similar items in the collection by a given value.
|
||||
|
||||
Args:
|
||||
text (str): Text to search by
|
||||
value (str or bytes): value to search by
|
||||
number (int, optional): Number of similar items to return
|
||||
|
||||
Returns:
|
||||
list: List of Entry objects
|
||||
"""
|
||||
comparison_vector = self.model().embed(text)
|
||||
comparison_vector = self.model().embed(value)
|
||||
return self.similar_by_vector(comparison_vector, number)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -334,6 +340,8 @@ class Collection:
|
|||
self.db.execute("delete from collections where id = ?", [self.id])
|
||||
|
||||
@staticmethod
|
||||
def content_hash(text: str) -> bytes:
|
||||
def content_hash(input: Union[str, bytes]) -> bytes:
|
||||
"Hash content for deduplication. Override to change hashing behavior."
|
||||
return hashlib.md5(text.encode("utf8")).digest()
|
||||
if isinstance(input, str):
|
||||
input = input.encode("utf8")
|
||||
return hashlib.md5(input).digest()
|
||||
|
|
|
|||
|
|
@ -83,3 +83,11 @@ def m004_store_content_hash(db):
|
|||
# De-register functions
|
||||
db.conn.create_function("temp_md5", 1, None)
|
||||
db.conn.create_function("temp_random_md5", 0, None)
|
||||
|
||||
|
||||
@embeddings_migrations()
|
||||
def m005_add_content_blob(db):
|
||||
db["embeddings"].add_column("content_blob", bytes)
|
||||
db["embeddings"].transform(
|
||||
column_order=("collection_id", "id", "embedding", "content", "content_blob")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from .errors import NeedsKeyException
|
|||
from itertools import islice
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -291,29 +291,49 @@ class EmbeddingModel(ABC, _get_key_mixin):
|
|||
key: Optional[str] = None
|
||||
needs_key: Optional[str] = None
|
||||
key_env_var: Optional[str] = None
|
||||
|
||||
supports_text: bool = True
|
||||
supports_binary: bool = False
|
||||
batch_size: Optional[int] = None
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"Embed a single text string, return a list of floats"
|
||||
return next(iter(self.embed_batch([text])))
|
||||
def _check(self, item: Union[str, bytes]):
|
||||
if not self.supports_binary and isinstance(item, bytes):
|
||||
raise ValueError(
|
||||
"This model does not support binary data, only text strings"
|
||||
)
|
||||
if not self.supports_text and isinstance(item, str):
|
||||
raise ValueError(
|
||||
"This model does not support text strings, only binary data"
|
||||
)
|
||||
|
||||
def embed_multi(self, texts: Iterable[str]) -> Iterator[List[float]]:
|
||||
"Embed multiple texts in batches according to the model batch_size"
|
||||
iter_texts = iter(texts)
|
||||
def embed(self, item: Union[str, bytes]) -> List[float]:
|
||||
"Embed a single text string or binary blob, return a list of floats"
|
||||
self._check(item)
|
||||
return next(iter(self.embed_batch([item])))
|
||||
|
||||
def embed_multi(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:
|
||||
"Embed multiple items in batches according to the model batch_size"
|
||||
iter_items = iter(items)
|
||||
if (not self.supports_binary) or (not self.supports_text):
|
||||
|
||||
def checking_iter(items):
|
||||
for item in items:
|
||||
self._check(item)
|
||||
yield item
|
||||
|
||||
iter_items = checking_iter(items)
|
||||
if self.batch_size is None:
|
||||
yield from self.embed_batch(iter_texts)
|
||||
yield from self.embed_batch(iter_items)
|
||||
return
|
||||
while True:
|
||||
batch_texts = list(islice(iter_texts, self.batch_size))
|
||||
if not batch_texts:
|
||||
batch_items = list(islice(iter_items, self.batch_size))
|
||||
if not batch_items:
|
||||
break
|
||||
yield from self.embed_batch(batch_texts)
|
||||
yield from self.embed_batch(batch_items)
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: Iterable[str]) -> Iterator[List[float]]:
|
||||
def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float]]:
|
||||
"""
|
||||
Embed a batch of text strings, return a list of lists of floats
|
||||
Embed a batch of strings or blobs, return a list of lists of floats
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ class MockModel(llm.Model):
|
|||
class EmbedDemo(llm.EmbeddingModel):
|
||||
model_id = "embed-demo"
|
||||
batch_size = 10
|
||||
supports_binary = True
|
||||
|
||||
def __init__(self):
|
||||
self.embedded_content = []
|
||||
|
|
@ -92,6 +93,18 @@ class EmbedDemo(llm.EmbeddingModel):
|
|||
yield embedding
|
||||
|
||||
|
||||
class EmbedBinaryOnly(EmbedDemo):
|
||||
model_id = "embed-binary-only"
|
||||
supports_text = False
|
||||
supports_binary = True
|
||||
|
||||
|
||||
class EmbedTextOnly(EmbedDemo):
|
||||
model_id = "embed-text-only"
|
||||
supports_text = True
|
||||
supports_binary = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embed_demo():
|
||||
return EmbedDemo()
|
||||
|
|
@ -110,6 +123,8 @@ def register_embed_demo_model(embed_demo, mock_model):
|
|||
@llm.hookimpl
|
||||
def register_embedding_models(self, register):
|
||||
register(embed_demo)
|
||||
register(EmbedBinaryOnly())
|
||||
register(EmbedTextOnly())
|
||||
|
||||
@llm.hookimpl
|
||||
def register_models(self, register):
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ def test_collection(collection):
|
|||
"id": "1",
|
||||
"embedding": llm.encode([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
|
||||
"content": None,
|
||||
"content_blob": None,
|
||||
"content_hash": collection.content_hash("hello world"),
|
||||
"metadata": None,
|
||||
"updated": ANY,
|
||||
|
|
@ -66,6 +67,7 @@ def test_collection(collection):
|
|||
"id": "2",
|
||||
"embedding": llm.encode([7, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
|
||||
"content": None,
|
||||
"content_blob": None,
|
||||
"content_hash": collection.content_hash("goodbye world"),
|
||||
"metadata": None,
|
||||
"updated": ANY,
|
||||
|
|
@ -121,3 +123,35 @@ def test_collection_delete(collection):
|
|||
collection.delete()
|
||||
assert db["embeddings"].count == 0
|
||||
assert db["collections"].count == 0
|
||||
|
||||
|
||||
def test_binary_only_and_text_only_embedding_models():
|
||||
binary_only = llm.get_embedding_model("embed-binary-only")
|
||||
text_only = llm.get_embedding_model("embed-text-only")
|
||||
|
||||
assert binary_only.supports_binary
|
||||
assert not binary_only.supports_text
|
||||
assert not text_only.supports_binary
|
||||
assert text_only.supports_text
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
binary_only.embed("hello world")
|
||||
|
||||
binary_only.embed(b"hello world")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
text_only.embed(b"hello world")
|
||||
|
||||
text_only.embed("hello world")
|
||||
|
||||
# Try the multi versions too
|
||||
# Have to call list() on this or the generator is not evaluated
|
||||
with pytest.raises(ValueError):
|
||||
list(binary_only.embed_multi(["hello world"]))
|
||||
|
||||
list(binary_only.embed_multi([b"hello world"]))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
list(text_only.embed_multi([b"hello world"]))
|
||||
|
||||
list(text_only.embed_multi(["hello world"]))
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ def test_embed_store(user_path, metadata, metadata_error):
|
|||
b"\x00\x00\x00\x00\x00\x00\x00"
|
||||
),
|
||||
"content": None,
|
||||
"content_blob": None,
|
||||
"content_hash": Collection.content_hash("hello"),
|
||||
"metadata": expected_metadata,
|
||||
"updated": ANY,
|
||||
|
|
@ -146,6 +147,32 @@ def test_embed_store(user_path, metadata, metadata_error):
|
|||
assert db["embeddings"].count == 0
|
||||
|
||||
|
||||
def test_embed_store_binary(user_path):
|
||||
runner = CliRunner()
|
||||
args = ["embed", "-m", "embed-demo", "items", "2", "--binary", "--store"]
|
||||
result = runner.invoke(cli, args, input=b"\x00\x01\x02")
|
||||
assert result.exit_code == 0
|
||||
db = sqlite_utils.Database(str(user_path / "embeddings.db"))
|
||||
rows = list(db["embeddings"].rows)
|
||||
assert rows == [
|
||||
{
|
||||
"collection_id": 1,
|
||||
"id": "2",
|
||||
"embedding": (
|
||||
b"\x00\x00@@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
),
|
||||
"content": None,
|
||||
"content_blob": b"\x00\x01\x02",
|
||||
"content_hash": b'\xb9_g\xf6\x1e\xbb\x03a\x96"\xd7\x98\xf4_\xc2\xd3',
|
||||
"metadata": None,
|
||||
"updated": ANY,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_collection_delete_errors(user_path):
|
||||
db = sqlite_utils.Database(str(user_path / "embeddings.db"))
|
||||
collection = Collection("items", db, model_id="embed-demo")
|
||||
|
|
@ -254,7 +281,7 @@ def test_embed_multi_file_input(tmpdir, use_stdin, prefix, filename, content):
|
|||
if filename.endswith(".jsonl"):
|
||||
args.extend(("--format", "nl"))
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args, input=input)
|
||||
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
|
||||
assert result.exit_code == 0
|
||||
# Check that everything was embedded correctly
|
||||
db = sqlite_utils.Database(str(db_path))
|
||||
|
|
@ -266,6 +293,35 @@ def test_embed_multi_file_input(tmpdir, use_stdin, prefix, filename, content):
|
|||
assert ids == expected_ids
|
||||
|
||||
|
||||
def test_embed_multi_files_binary_store(tmpdir):
|
||||
db_path = tmpdir / "embeddings.db"
|
||||
args = ["embed-multi", "binfiles", "-d", str(db_path), "-m", "embed-demo"]
|
||||
bin_path = tmpdir / "file.bin"
|
||||
bin_path.write(b"\x00\x01\x02")
|
||||
args.extend(("--files", str(tmpdir), "*.bin", "--store", "--binary"))
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args, catch_exceptions=False)
|
||||
assert result.exit_code == 0
|
||||
db = sqlite_utils.Database(str(db_path))
|
||||
assert db["embeddings"].count == 1
|
||||
row = list(db["embeddings"].rows)[0]
|
||||
assert row == {
|
||||
"collection_id": 1,
|
||||
"id": "file.bin",
|
||||
"embedding": (
|
||||
b"\x00\x00@@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
),
|
||||
"content": None,
|
||||
"content_blob": b"\x00\x01\x02",
|
||||
"content_hash": b'\xb9_g\xf6\x1e\xbb\x03a\x96"\xd7\x98\xf4_\xc2\xd3',
|
||||
"metadata": None,
|
||||
"updated": ANY,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_other_db", (True, False))
|
||||
@pytest.mark.parametrize("prefix", (None, "prefix"))
|
||||
def test_embed_multi_sql(tmpdir, use_other_db, prefix):
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ def test_migrations_for_embeddings():
|
|||
"id": str,
|
||||
"embedding": bytes,
|
||||
"content": str,
|
||||
"content_blob": bytes,
|
||||
"content_hash": bytes,
|
||||
"metadata": str,
|
||||
"updated": int,
|
||||
|
|
|
|||
Loading…
Reference in a new issue