mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-27 09:50:23 +00:00
embed_multi and embed_multi_with_metadata, closes #202
This commit is contained in:
parent
4be89facb5
commit
3b2d5bf7f9
3 changed files with 91 additions and 7 deletions
|
|
@ -56,6 +56,34 @@ collection.embed("hound", "my happy hound", metadata={"name": "Hound"}, store=Tr
|
|||
```
|
||||
This additional metadata will be stored as JSON in the `metadata` column of the embeddings database table.
|
||||
|
||||
(embeddings-python-bulk)=
|
||||
### Storing embeddings in bulk
|
||||
|
||||
The `collection.embed_multi()` method can be used to store embeddings for multiple strings at once. This can be more efficient for some embedding models.
|
||||
|
||||
```python
|
||||
collection.embed_multi(
|
||||
[
|
||||
("hound", "my happy hound"),
|
||||
("cat", "my dissatisfied cat"),
|
||||
],
|
||||
# Add this to store the strings in the content column:
|
||||
store=True,
|
||||
)
|
||||
```
|
||||
To include metadata to be stored with each item, call `embed_multi_with_metadata()`:
|
||||
|
||||
```python
|
||||
collection.embed_multi_with_metadata(
|
||||
[
|
||||
("hound", "my happy hound", {"name": "Hound"}),
|
||||
("cat", "my dissatisfied cat", {"name": "Cat"}),
|
||||
],
|
||||
# This can also take the store=True argument:
|
||||
store=True,
|
||||
)
|
||||
```
|
||||
|
||||
(embeddings-python-similar)=
|
||||
## Retrieving similar items
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from .models import EmbeddingModel
|
||||
from .embeddings_migrations import embeddings_migrations
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
import json
|
||||
from sqlite_utils import Database
|
||||
from sqlite_utils.db import Table
|
||||
from typing import cast, Any, Dict, List, Tuple, Optional, Union
|
||||
from typing import cast, Any, Dict, Iterable, List, Tuple, Optional, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -16,6 +17,8 @@ class Entry:
|
|||
|
||||
|
||||
class Collection:
|
||||
max_batch_size: int = 100
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Database,
|
||||
|
|
@ -136,27 +139,57 @@ class Collection:
|
|||
}
|
||||
)
|
||||
|
||||
def embed_multi(
    self, entries: Iterable[Tuple[str, str]], store: bool = False
) -> None:
    """
    Embed multiple texts and store them in the collection with given IDs.

    Args:
        entries (iterable): Iterable of (id: str, text: str) tuples
        store (bool, optional): Whether to store the text in the content column

    Note: entries may be a generator; it is consumed lazily and never
    materialized in full.
    """
    # Delegate to the metadata-aware variant, attaching None metadata to
    # every entry. A generator expression keeps this lazy so very large
    # iterables are still processed in batches downstream.
    self.embed_multi_with_metadata(
        ((entry_id, text, None) for entry_id, text in entries), store=store
    )
|
||||
|
||||
def embed_multi_with_metadata(
    self,
    entries: Iterable[Tuple[str, str, Optional[Dict[str, Any]]]],
    store: bool = False,
) -> None:
    """
    Embed multiple texts along with metadata and store them in the collection with given IDs.

    Args:
        entries (iterable): Iterable of (id: str, text: str, metadata: None or dict)
        store (bool, optional): Whether to store the text in the content column
    """
    import llm  # NOTE(review): local import — presumably avoids a circular import; confirm

    # Cap each batch at the smaller of our own limit and the model's
    # declared batch size (a model batch_size of None/0 means "no limit").
    batch_size = min(
        self.max_batch_size, (self.model().batch_size or self.max_batch_size)
    )
    iterator = iter(entries)
    collection_id = self.id()
    while True:
        # islice drains the iterator in fixed-size chunks without ever
        # materializing the whole input.
        batch = list(islice(iterator, batch_size))
        if not batch:
            break
        # Embed the whole batch in a single call — more efficient for
        # models that support batched embedding.
        embeddings = list(self.model().embed_multi(item[1] for item in batch))
        # One transaction per batch keeps each chunk of inserts atomic.
        with self.db.conn:
            cast(Table, self.db["embeddings"]).insert_all(
                (
                    {
                        "collection_id": collection_id,
                        "id": entry_id,
                        "embedding": llm.encode(embedding),
                        # Only persist the raw text when store=True
                        "content": text if store else None,
                        # Metadata is serialized as JSON; empty/None stores NULL
                        "metadata": json.dumps(metadata) if metadata else None,
                    }
                    for (embedding, (entry_id, text, metadata)) in zip(
                        embeddings, batch
                    )
                )
            )
|
||||
|
||||
def similar_by_vector(
|
||||
self, vector: List[float], number: int = 10, skip_id: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -90,3 +90,26 @@ def test_similar_by_id(collection):
|
|||
assert results == [
|
||||
Entry(id="2", score=pytest.approx(0.9863939238321437)),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("with_metadata", (False, True))
def test_embed_multi(with_metadata):
    """Bulk-embedding 1000 items works both with and without metadata."""
    db = sqlite_utils.Database(memory=True)
    collection = llm.Collection(db, "test", model_id="embed-demo")
    pairs = ((str(num), f"hello {num}") for num in range(1000))
    if with_metadata:
        triples = ((key, text, {"meta": key}) for key, text in pairs)
        collection.embed_multi_with_metadata(triples)
    else:
        # Exercise store=True here too
        collection.embed_multi(pairs, store=True)
    stored = list(db["embeddings"].rows)
    assert len(stored) == 1000
    metadata_rows = [row for row in stored if row["metadata"] is not None]
    content_rows = [row for row in stored if row["content"] is not None]
    expected_meta, expected_content = (1000, 0) if with_metadata else (0, 1000)
    assert len(metadata_rows) == expected_meta
    assert len(content_rows) == expected_content
|
|
|
|||
Loading…
Reference in a new issue