embed_multi and embed_multi_with_metadata, closes #202

This commit is contained in:
Simon Willison 2023-09-01 20:15:28 -07:00
parent 4be89facb5
commit 3b2d5bf7f9
3 changed files with 91 additions and 7 deletions

View file

@ -56,6 +56,34 @@ collection.embed("hound", "my happy hound", metadata={"name": "Hound"}, store=Tr
```
This additional metadata will be stored as JSON in the `metadata` column of the embeddings database table.
(embeddings-python-bulk)=
### Storing embeddings in bulk
The `collection.embed_multi()` method can be used to store embeddings for multiple strings at once. This can be more efficient for some embedding models.
```python
collection.embed_multi(
[
("hound", "my happy hound"),
("cat", "my dissatisfied cat"),
],
# Add this to store the strings in the content column:
store=True,
)
```
To include metadata to be stored with each item, call `embed_multi_with_metadata()`:
```python
collection.embed_multi_with_metadata(
[
("hound", "my happy hound", {"name": "Hound"}),
("cat", "my dissatisfied cat", {"name": "Cat"}),
],
# This can also take the store=True argument:
store=True,
)
```
(embeddings-python-similar)=
## Retrieving similar items

View file

@ -1,10 +1,11 @@
from .models import EmbeddingModel
from .embeddings_migrations import embeddings_migrations
from dataclasses import dataclass
from itertools import islice
import json
from sqlite_utils import Database
from sqlite_utils.db import Table
from typing import cast, Any, Dict, List, Tuple, Optional, Union
from typing import cast, Any, Dict, Iterable, List, Tuple, Optional, Union
@dataclass
@ -16,6 +17,8 @@ class Entry:
class Collection:
max_batch_size: int = 100
def __init__(
self,
db: Database,
@ -136,27 +139,57 @@ class Collection:
}
)
def embed_multi(
    self, entries: Iterable[Tuple[str, str]], store: bool = False
) -> None:
    """
    Embed multiple texts and store them in the collection with given IDs.

    This is a thin wrapper: every (id, text) pair is passed on to
    embed_multi_with_metadata() with None as its metadata. The entries
    are forwarded as a generator, so they are not materialized here.

    Args:
        entries (iterable): Iterable of (id: str, text: str) tuples
        store (bool, optional): Whether to store the text in the content column
    """
    self.embed_multi_with_metadata(
        ((entry_id, text, None) for entry_id, text in entries), store=store
    )
def embed_multi_with_metadata(
    self,
    entries: Iterable[Tuple[str, str, Optional[Dict[str, Any]]]],
    store: bool = False,
) -> None:
    """
    Embed multiple texts along with metadata and store them in the collection with given IDs.

    Entries are pulled from the iterable one batch at a time, so a
    generator can be passed without it being fully materialized. Each
    batch is embedded with a single call to the model's embed_multi()
    and then written with one insert_all() inside a
    ``with self.db.conn:`` block.

    Args:
        entries (iterable): Iterable of (id: str, text: str, metadata: None or dict)
        store (bool, optional): Whether to store the text in the content column
    """
    import llm

    model = self.model()
    # Cap at max_batch_size; a falsy model.batch_size (None or 0) means
    # the model states no preference, so max_batch_size is used as-is.
    batch_size = min(self.max_batch_size, (model.batch_size or self.max_batch_size))
    iterator = iter(entries)
    collection_id = self.id()

    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            break
        # Materialize the embeddings before the database block is entered
        embeddings = list(model.embed_multi(item[1] for item in batch))
        with self.db.conn:
            cast(Table, self.db["embeddings"]).insert_all(
                (
                    {
                        "collection_id": collection_id,
                        "id": entry_id,
                        "embedding": llm.encode(embedding),
                        "content": text if store else None,
                        # Falsy metadata (None or an empty dict) is stored as NULL
                        "metadata": json.dumps(metadata) if metadata else None,
                    }
                    for (embedding, (entry_id, text, metadata)) in zip(
                        embeddings, batch
                    )
                )
            )
def similar_by_vector(
self, vector: List[float], number: int = 10, skip_id: Optional[str] = None

View file

@ -90,3 +90,26 @@ def test_similar_by_id(collection):
assert results == [
Entry(id="2", score=pytest.approx(0.9863939238321437)),
]
@pytest.mark.parametrize("with_metadata", (False, True))
def test_embed_multi(with_metadata):
    """
    Bulk-embed 1000 generated entries and check what was written.

    With metadata, embed_multi_with_metadata() is called without store,
    so every row must have metadata and no content. Without metadata,
    embed_multi() is called with store=True, so every row must have
    content and no metadata.
    """
    total = 1000
    database = sqlite_utils.Database(memory=True)
    coll = llm.Collection(database, "test", model_id="embed-demo")
    # Generators on purpose: the bulk methods accept any iterable
    pairs = ((str(n), "hello {}".format(n)) for n in range(total))

    if with_metadata:
        triples = ((key, text, {"meta": key}) for key, text in pairs)
        coll.embed_multi_with_metadata(triples)
    else:
        # Exercise store=True here too
        coll.embed_multi(pairs, store=True)

    stored = list(database["embeddings"].rows)
    assert len(stored) == total

    metadata_count = sum(1 for row in stored if row["metadata"] is not None)
    content_count = sum(1 for row in stored if row["content"] is not None)
    expected_metadata, expected_content = (total, 0) if with_metadata else (0, total)
    assert metadata_count == expected_metadata
    assert content_count == expected_content