Reuse embeddings for hashed content, --store now works on second run - closes #224

This commit is contained in:
Simon Willison 2023-09-11 22:44:22 -07:00
parent 52cec1304b
commit 267e2ea999
2 changed files with 87 additions and 22 deletions

View file

@ -130,14 +130,20 @@ class Collection:
metadata (dict, optional): Metadata to be stored
store (bool, optional): Whether to store the value in the content or content_blob column
"""
from llm import encode
from llm import decode, encode
content_hash = self.content_hash(value)
if self.db["embeddings"].count_where(
"content_hash = ? and collection_id = ?", [content_hash, self.id]
):
return
embedding = self.model().embed(value)
existing = list(
self.db["embeddings"].rows_where(
"content_hash = ?", [content_hash], limit=1
)
)
if existing:
# Reuse the embedding from whatever record this is, it might even
# be in a different collection
embedding = decode(existing[0]["embedding"])
else:
embedding = self.model().embed(value)
cast(Table, self.db["embeddings"]).insert(
{
"collection_id": self.id,
@ -192,23 +198,39 @@ class Collection:
# Calculate hashes first
items_and_hashes = [(item, self.content_hash(item[1])) for item in batch]
# Any of those hashes already exist?
existing_ids = [
row["id"]
hashes_and_embeddings = {
row["content_hash"]: row["embedding"]
for row in self.db.query(
"""
select id from embeddings
where collection_id = ? and content_hash in ({})
select content_hash, embedding from embeddings
where content_hash in ({})
""".format(
",".join("?" for _ in items_and_hashes)
),
[collection_id]
+ [item_and_hash[1] for item_and_hash in items_and_hashes],
[item_and_hash[1] for item_and_hash in items_and_hashes],
)
}
# We need to embed the ones that don't exist yet
items_to_embed = [
item_and_hash[0]
for item_and_hash in items_and_hashes
if item_and_hash[1] not in hashes_and_embeddings
]
filtered_batch = [item for item in batch if item[0] not in existing_ids]
embeddings = list(
self.model().embed_multi(item[1] for item in filtered_batch)
calculated_embeddings = list(
self.model().embed_multi(item[1] for item in items_to_embed)
)
# This should be a list of all the embeddings from both sources
embeddings = []
calculated_embeddings_iterator = iter(calculated_embeddings)
for item_and_hash in items_and_hashes:
if item_and_hash[1] in hashes_and_embeddings:
embeddings.append(
llm.decode(hashes_and_embeddings[item_and_hash[1]])
)
else:
embeddings.append(next(calculated_embeddings_iterator))
with self.db.conn:
cast(Table, self.db["embeddings"]).insert_all(
(
@ -226,9 +248,7 @@ class Collection:
"metadata": json.dumps(metadata) if metadata else None,
"updated": int(time.time()),
}
for (embedding, (id, value, metadata)) in zip(
embeddings, filtered_batch
)
for (embedding, (id, value, metadata)) in zip(embeddings, batch)
),
replace=True,
)

View file

@ -546,7 +546,6 @@ def test_default_embed_model_errors(user_path, default_is_set, command):
def test_duplicate_content_embedded_only_once(embed_demo):
# content_hash should avoid embedding the same content twice
# per collection
db = sqlite_utils.Database(memory=True)
assert len(embed_demo.embedded_content) == 0
collection = Collection("test", db, model_id="embed-demo")
@ -558,11 +557,11 @@ def test_duplicate_content_embedded_only_once(embed_demo):
collection.embed("1", "hello world")
assert db["embeddings"].count == 2
assert len(embed_demo.embedded_content) == 2
# The same string in another collection should be embedded
# The same string in another collection should also not be embedded
c2 = Collection("test2", db, model_id="embed-demo")
c2.embed("1", "hello world")
assert db["embeddings"].count == 3
assert len(embed_demo.embedded_content) == 3
assert len(embed_demo.embedded_content) == 2
# Same again for embed_multi
collection.embed_multi(
@ -570,4 +569,50 @@ def test_duplicate_content_embedded_only_once(embed_demo):
)
# Should have only embedded one more thing
assert db["embeddings"].count == 4
assert len(embed_demo.embedded_content) == 4
assert len(embed_demo.embedded_content) == 3
def test_embed_again_with_store(embed_demo, multi_files):
# https://github.com/simonw/llm/issues/224
db_path, files = multi_files
runner = CliRunner(mix_stderr=False)
# First embed all the files without storing them
args = [
"embed-multi",
"files",
"-d",
db_path,
"-m",
"embed-demo",
"--files",
str(files),
"**/*.txt",
]
result = runner.invoke(cli, args)
assert result.exit_code == 0
assert not result.stderr
embeddings_db = sqlite_utils.Database(db_path)
rows = list(embeddings_db.query("select id, content from embeddings"))
assert rows == [
{"id": "file2.txt", "content": None},
{"id": "file1.txt", "content": None},
{"id": "nested/two.txt", "content": None},
{"id": "nested/one.txt", "content": None},
{"id": "nested/more/three.txt", "content": None},
]
assert len(embed_demo.embedded_content) == 5
# Now we run it again with the --store option
result2 = runner.invoke(cli, args + ["--store"])
assert result2.exit_code == 0
assert not result2.stderr
# The rows should have their content now
rows = list(embeddings_db.query("select id, content from embeddings"))
assert rows == [
{"id": "file2.txt", "content": "goodbye world"},
{"id": "file1.txt", "content": "hello world"},
{"id": "nested/two.txt", "content": "two"},
{"id": "nested/one.txt", "content": "one"},
{"id": "nested/more/three.txt", "content": "three"},
]
# But it should not have run any more embedding tasks
assert len(embed_demo.embedded_content) == 5