From 3b2d5bf7f9b3ae8628e7fee628b8bd044a4e437b Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 1 Sep 2023 20:15:28 -0700 Subject: [PATCH] embed_multi and embed_multi_with_metadata, closes #202 --- docs/embeddings/python-api.md | 28 +++++++++++++++++++++ llm/embeddings.py | 47 +++++++++++++++++++++++++++++------ tests/test_embed.py | 23 +++++++++++++++++ 3 files changed, 91 insertions(+), 7 deletions(-) diff --git a/docs/embeddings/python-api.md b/docs/embeddings/python-api.md index 9221e54..00e3368 100644 --- a/docs/embeddings/python-api.md +++ b/docs/embeddings/python-api.md @@ -56,6 +56,34 @@ collection.embed("hound", "my happy hound", metadata={"name": "Hound"}, store=Tr ``` This additional metadata will be stored as JSON in the `metadata` column of the embeddings database table. +(embeddings-python-bulk)= +### Storing embeddings in bulk + +The `collection.embed_multi()` method can be used to store embeddings for multiple strings at once. This can be more efficient for some embedding models. 
+ +```python +collection.embed_multi( + [ + ("hound", "my happy hound"), + ("cat", "my dissatisfied cat"), + ], + # Add this to store the strings in the content column: + store=True, +) +``` +To include metadata to be stored with each item, call `embed_multi_with_metadata()`: + +```python +collection.embed_multi_with_metadata( + [ + ("hound", "my happy hound", {"name": "Hound"}), + ("cat", "my dissatisfied cat", {"name": "Cat"}), + ], + # This can also take the store=True argument: + store=True, +) +``` + (embeddings-python-similar)= ## Retrieving similar items diff --git a/llm/embeddings.py b/llm/embeddings.py index 05f0aee..0a2820b 100644 --- a/llm/embeddings.py +++ b/llm/embeddings.py @@ -1,10 +1,11 @@ from .models import EmbeddingModel from .embeddings_migrations import embeddings_migrations from dataclasses import dataclass +from itertools import islice import json from sqlite_utils import Database from sqlite_utils.db import Table -from typing import cast, Any, Dict, List, Tuple, Optional, Union +from typing import cast, Any, Dict, Iterable, List, Tuple, Optional, Union @@ -16,6 +17,8 @@ class Entry: class Collection: + max_batch_size: int = 100 + def __init__( self, db: Database, @@ -136,27 +139,57 @@ class Collection: } ) - def embed_multi(self, id_text_map: Dict[str, str], store: bool = False) -> None: + def embed_multi( + self, entries: Iterable[Tuple[str, str]], store: bool = False + ) -> None: """ Embed multiple texts and store them in the collection with given IDs. 
 Args: - id_text_map (dict): Dictionary mapping IDs to texts + entries (iterable): Iterable of (id: str, text: str) tuples store (bool, optional): Whether to store the text in the content column """ - raise NotImplementedError + self.embed_multi_with_metadata( + ((id, text, None) for id, text in entries), store=store + ) def embed_multi_with_metadata( self, - id_text_metadata_map: Dict[str, Tuple[str, Dict[str, Union[str, int, float]]]], + entries: Iterable[Tuple[str, str, Optional[Dict[str, Any]]]], + store: bool = False, ) -> None: """ Embed multiple texts along with metadata and store them in the collection with given IDs. Args: - id_text_metadata_map (dict): Dictionary mapping IDs to (text, metadata) tuples + entries (iterable): Iterable of (id: str, text: str, metadata: None or dict) + store (bool, optional): Whether to store the text in the content column """ - raise NotImplementedError + import llm + + batch_size = min( + self.max_batch_size, (self.model().batch_size or self.max_batch_size) + ) + iterator = iter(entries) + collection_id = self.id() + while True: + batch = list(islice(iterator, batch_size)) + if not batch: + break + embeddings = list(self.model().embed_multi(item[1] for item in batch)) + with self.db.conn: + cast(Table, self.db["embeddings"]).insert_all( + ( + { + "collection_id": collection_id, + "id": id, + "embedding": llm.encode(embedding), + "content": text if store else None, + "metadata": json.dumps(metadata) if metadata else None, + } + for (embedding, (id, text, metadata)) in zip(embeddings, batch) + ) + ) def similar_by_vector( self, vector: List[float], number: int = 10, skip_id: Optional[str] = None diff --git a/tests/test_embed.py b/tests/test_embed.py index d7f2537..502f006 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -90,3 +90,26 @@ def test_similar_by_id(collection): assert results == [ Entry(id="2", score=pytest.approx(0.9863939238321437)), ] + + +@pytest.mark.parametrize("with_metadata", (False, True)) +def 
test_embed_multi(with_metadata): + db = sqlite_utils.Database(memory=True) + collection = llm.Collection(db, "test", model_id="embed-demo") + ids_and_texts = ((str(i), "hello {}".format(i)) for i in range(1000)) + if with_metadata: + ids_and_texts = ((id, text, {"meta": id}) for id, text in ids_and_texts) + collection.embed_multi_with_metadata(ids_and_texts) + else: + # Exercise store=True here too + collection.embed_multi(ids_and_texts, store=True) + rows = list(db["embeddings"].rows) + assert len(rows) == 1000 + rows_with_metadata = [row for row in rows if row["metadata"] is not None] + rows_with_content = [row for row in rows if row["content"] is not None] + if with_metadata: + assert len(rows_with_metadata) == 1000 + assert len(rows_with_content) == 0 + else: + assert len(rows_with_metadata) == 0 + assert len(rows_with_content) == 1000