mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-19 12:41:13 +00:00
Duplicate content is only embedded once, closes #217
This commit is contained in:
parent
0eda99e91c
commit
3bf781fba2
4 changed files with 71 additions and 5 deletions
|
|
@ -40,6 +40,8 @@ Embeddings are much more useful if you store them somewhere, so you can calculat
|
|||
|
||||
LLM includes the concept of a "collection" of embeddings. A collection groups together a set of stored embeddings created using the same model, each with a unique ID within that collection.
|
||||
|
||||
Embeddings also store a hash of the content that was embedded. This hash is later used to avoid calculating duplicate embeddings for the same content.
|
||||
|
||||
First, we'll set a default model so we don't have to keep repeating it:
|
||||
```bash
|
||||
llm embed-models default ada-002
|
||||
|
|
|
|||
|
|
@ -132,6 +132,11 @@ class Collection:
|
|||
"""
|
||||
from llm import encode
|
||||
|
||||
content_hash = self.content_hash(text)
|
||||
if self.db["embeddings"].count_where(
|
||||
"content_hash = ? and collection_id = ?", [content_hash, self.id]
|
||||
):
|
||||
return
|
||||
embedding = self.model().embed(text)
|
||||
cast(Table, self.db["embeddings"]).insert(
|
||||
{
|
||||
|
|
@ -139,7 +144,7 @@ class Collection:
|
|||
"id": id,
|
||||
"embedding": encode(embedding),
|
||||
"content": text if store else None,
|
||||
"content_hash": self.content_hash(text),
|
||||
"content_hash": content_hash,
|
||||
"metadata": json.dumps(metadata) if metadata else None,
|
||||
"updated": int(time.time()),
|
||||
},
|
||||
|
|
@ -183,7 +188,26 @@ class Collection:
|
|||
batch = list(islice(iterator, batch_size))
|
||||
if not batch:
|
||||
break
|
||||
embeddings = list(self.model().embed_multi(item[1] for item in batch))
|
||||
# Calculate hashes first
|
||||
items_and_hashes = [(item, self.content_hash(item[1])) for item in batch]
|
||||
# Any of those hashes already exist?
|
||||
existing_ids = [
|
||||
row["id"]
|
||||
for row in self.db.query(
|
||||
"""
|
||||
select id from embeddings
|
||||
where collection_id = ? and content_hash in ({})
|
||||
""".format(
|
||||
",".join("?" for _ in items_and_hashes)
|
||||
),
|
||||
[collection_id]
|
||||
+ [item_and_hash[1] for item_and_hash in items_and_hashes],
|
||||
)
|
||||
]
|
||||
filtered_batch = [item for item in batch if item[0] not in existing_ids]
|
||||
embeddings = list(
|
||||
self.model().embed_multi(item[1] for item in filtered_batch)
|
||||
)
|
||||
with self.db.conn:
|
||||
cast(Table, self.db["embeddings"]).insert_all(
|
||||
(
|
||||
|
|
@ -196,7 +220,9 @@ class Collection:
|
|||
"metadata": json.dumps(metadata) if metadata else None,
|
||||
"updated": int(time.time()),
|
||||
}
|
||||
for (embedding, (id, text, metadata)) in zip(embeddings, batch)
|
||||
for (embedding, (id, text, metadata)) in zip(
|
||||
embeddings, filtered_batch
|
||||
)
|
||||
),
|
||||
replace=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -42,11 +42,15 @@ class EmbedDemo(llm.EmbeddingModel):
|
|||
model_id = "embed-demo"
|
||||
batch_size = 10
|
||||
|
||||
def __init__(self):
|
||||
self.embedded_content = []
|
||||
|
||||
def embed_batch(self, texts):
|
||||
if not hasattr(self, "batch_count"):
|
||||
self.batch_count = 0
|
||||
self.batch_count += 1
|
||||
for text in texts:
|
||||
self.embedded_content.append(text)
|
||||
words = text.split()[:16]
|
||||
embedding = [len(word) for word in words]
|
||||
# Pad with 0 up to 16 words
|
||||
|
|
@ -54,14 +58,19 @@ class EmbedDemo(llm.EmbeddingModel):
|
|||
yield embedding
|
||||
|
||||
|
||||
@pytest.fixture
def embed_demo():
    """Return a newly constructed EmbedDemo embedding model."""
    demo_model = EmbedDemo()
    return demo_model
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def register_embed_demo_model():
|
||||
def register_embed_demo_model(embed_demo):
|
||||
class EmbedDemoPlugin:
|
||||
__name__ = "EmbedDemoPlugin"
|
||||
|
||||
@llm.hookimpl
|
||||
def register_embedding_models(self, register):
|
||||
register(EmbedDemo())
|
||||
register(embed_demo)
|
||||
|
||||
pm.register(EmbedDemoPlugin(), name="undo-embed-demo-plugin")
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -420,3 +420,32 @@ def test_default_embed_model_errors(user_path, default_is_set, command):
|
|||
# At the end of this, there should be 1 embedding
|
||||
db = sqlite_utils.Database(str(user_path / "embeddings.db"))
|
||||
assert db["embeddings"].count == 1
|
||||
|
||||
|
||||
def test_duplicate_content_embedded_only_once(embed_demo):
    """Check that content_hash stops the same text being embedded twice.

    The de-duplication is expected to apply per collection: text already
    stored in one collection must still be embedded when it is added to a
    different collection.
    """
    database = sqlite_utils.Database(memory=True)

    def embedded_total():
        # Number of texts the demo model has actually been asked to embed.
        return len(embed_demo.embedded_content)

    def stored_rows():
        # Number of rows currently held in the embeddings table.
        return database["embeddings"].count

    assert embedded_total() == 0

    first = Collection("test", database, model_id="embed-demo")
    first.embed("1", "hello world")
    assert embedded_total() == 1

    first.embed("2", "goodbye world")
    assert stored_rows() == 2
    assert embedded_total() == 2

    # Repeating content the collection already holds must be a no-op
    first.embed("1", "hello world")
    assert stored_rows() == 2
    assert embedded_total() == 2

    # The same string in another collection should be embedded
    second = Collection("test2", database, model_id="embed-demo")
    second.embed("1", "hello world")
    assert stored_rows() == 3
    assert embedded_total() == 3

    # Same again for embed_multi
    first.embed_multi(
        (("1", "hello world"), ("2", "goodbye world"), ("3", "this is new"))
    )
    # Should have only embedded one more thing
    assert stored_rows() == 4
    assert embedded_total() == 4
|
||||
|
|
|
|||
Loading…
Reference in a new issue