Duplicate content is only embedded once, closes #217

This commit is contained in:
Simon Willison 2023-09-03 17:39:11 -07:00
parent 0eda99e91c
commit 3bf781fba2
4 changed files with 71 additions and 5 deletions

View file

@ -40,6 +40,8 @@ Embeddings are much more useful if you store them somewhere, so you can calculat
LLM includes the concept of a "collection" of embeddings. A collection groups together a set of stored embeddings created using the same model, each with a unique ID within that collection.
Embeddings also store a hash of the content that was embedded. This hash is later used to avoid calculating duplicate embeddings for the same content.
First, we'll set a default model so we don't have to keep repeating it:
```bash
llm embed-models default ada-002

View file

@ -132,6 +132,11 @@ class Collection:
"""
from llm import encode
content_hash = self.content_hash(text)
if self.db["embeddings"].count_where(
"content_hash = ? and collection_id = ?", [content_hash, self.id]
):
return
embedding = self.model().embed(text)
cast(Table, self.db["embeddings"]).insert(
{
@ -139,7 +144,7 @@ class Collection:
"id": id,
"embedding": encode(embedding),
"content": text if store else None,
"content_hash": self.content_hash(text),
"content_hash": content_hash,
"metadata": json.dumps(metadata) if metadata else None,
"updated": int(time.time()),
},
@ -183,7 +188,26 @@ class Collection:
batch = list(islice(iterator, batch_size))
if not batch:
break
embeddings = list(self.model().embed_multi(item[1] for item in batch))
# Calculate hashes first
items_and_hashes = [(item, self.content_hash(item[1])) for item in batch]
# Any of those hashes already exist?
existing_ids = [
row["id"]
for row in self.db.query(
"""
select id from embeddings
where collection_id = ? and content_hash in ({})
""".format(
",".join("?" for _ in items_and_hashes)
),
[collection_id]
+ [item_and_hash[1] for item_and_hash in items_and_hashes],
)
]
filtered_batch = [item for item in batch if item[0] not in existing_ids]
embeddings = list(
self.model().embed_multi(item[1] for item in filtered_batch)
)
with self.db.conn:
cast(Table, self.db["embeddings"]).insert_all(
(
@ -196,7 +220,9 @@ class Collection:
"metadata": json.dumps(metadata) if metadata else None,
"updated": int(time.time()),
}
for (embedding, (id, text, metadata)) in zip(embeddings, batch)
for (embedding, (id, text, metadata)) in zip(
embeddings, filtered_batch
)
),
replace=True,
)

View file

@ -42,11 +42,15 @@ class EmbedDemo(llm.EmbeddingModel):
model_id = "embed-demo"
batch_size = 10
def __init__(self):
self.embedded_content = []
def embed_batch(self, texts):
if not hasattr(self, "batch_count"):
self.batch_count = 0
self.batch_count += 1
for text in texts:
self.embedded_content.append(text)
words = text.split()[:16]
embedding = [len(word) for word in words]
# Pad with 0 up to 16 words
@ -54,14 +58,19 @@ class EmbedDemo(llm.EmbeddingModel):
yield embedding
@pytest.fixture
def embed_demo():
return EmbedDemo()
@pytest.fixture(autouse=True)
def register_embed_demo_model():
def register_embed_demo_model(embed_demo):
class EmbedDemoPlugin:
__name__ = "EmbedDemoPlugin"
@llm.hookimpl
def register_embedding_models(self, register):
register(EmbedDemo())
register(embed_demo)
pm.register(EmbedDemoPlugin(), name="undo-embed-demo-plugin")
try:

View file

@ -420,3 +420,32 @@ def test_default_embed_model_errors(user_path, default_is_set, command):
# At the end of this, there should be 2 embeddings
db = sqlite_utils.Database(str(user_path / "embeddings.db"))
assert db["embeddings"].count == 1
def test_duplicate_content_embedded_only_once(embed_demo):
# content_hash should avoid embedding the same content twice
# per collection
db = sqlite_utils.Database(memory=True)
assert len(embed_demo.embedded_content) == 0
collection = Collection("test", db, model_id="embed-demo")
collection.embed("1", "hello world")
assert len(embed_demo.embedded_content) == 1
collection.embed("2", "goodbye world")
assert db["embeddings"].count == 2
assert len(embed_demo.embedded_content) == 2
collection.embed("1", "hello world")
assert db["embeddings"].count == 2
assert len(embed_demo.embedded_content) == 2
# The same string in another collection should be embedded
c2 = Collection("test2", db, model_id="embed-demo")
c2.embed("1", "hello world")
assert db["embeddings"].count == 3
assert len(embed_demo.embedded_content) == 3
# Same again for embed_multi
collection.embed_multi(
(("1", "hello world"), ("2", "goodbye world"), ("3", "this is new"))
)
# Should have only embedded one more thing
assert db["embeddings"].count == 4
assert len(embed_demo.embedded_content) == 4