Revert "Reuse embeddings for hashed content, --store now works on second run - closes #224"

This reverts commit 267e2ea999.

It's broken, see:

https://github.com/simonw/llm/issues/224#issuecomment-1715014393
This commit is contained in:
Simon Willison 2023-09-11 22:56:52 -07:00
parent 267e2ea999
commit 591ad6f571
2 changed files with 22 additions and 87 deletions

View file

@ -130,20 +130,14 @@ class Collection:
metadata (dict, optional): Metadata to be stored
store (bool, optional): Whether to store the value in the content or content_blob column
"""
from llm import decode, encode
from llm import encode
content_hash = self.content_hash(value)
existing = list(
self.db["embeddings"].rows_where(
"content_hash = ?", [content_hash], limit=1
)
)
if existing:
# Reuse the embedding from whatever record this is, it might even
# be in a different collection
embedding = decode(existing[0]["embedding"])
else:
embedding = self.model().embed(value)
if self.db["embeddings"].count_where(
"content_hash = ? and collection_id = ?", [content_hash, self.id]
):
return
embedding = self.model().embed(value)
cast(Table, self.db["embeddings"]).insert(
{
"collection_id": self.id,
@ -198,39 +192,23 @@ class Collection:
# Calculate hashes first
items_and_hashes = [(item, self.content_hash(item[1])) for item in batch]
# Any of those hashes already exist?
hashes_and_embeddings = {
row["content_hash"]: row["embedding"]
existing_ids = [
row["id"]
for row in self.db.query(
"""
select content_hash, embedding from embeddings
where content_hash in ({})
select id from embeddings
where collection_id = ? and content_hash in ({})
""".format(
",".join("?" for _ in items_and_hashes)
),
[item_and_hash[1] for item_and_hash in items_and_hashes],
[collection_id]
+ [item_and_hash[1] for item_and_hash in items_and_hashes],
)
}
# We need to embed the ones that don't exist yet
items_to_embed = [
item_and_hash[0]
for item_and_hash in items_and_hashes
if item_and_hash[1] not in hashes_and_embeddings
]
calculated_embeddings = list(
self.model().embed_multi(item[1] for item in items_to_embed)
filtered_batch = [item for item in batch if item[0] not in existing_ids]
embeddings = list(
self.model().embed_multi(item[1] for item in filtered_batch)
)
# This should be a list of all the embeddings from both sources
embeddings = []
calculated_embeddings_iterator = iter(calculated_embeddings)
for item_and_hash in items_and_hashes:
if item_and_hash[1] in hashes_and_embeddings:
embeddings.append(
llm.decode(hashes_and_embeddings[item_and_hash[1]])
)
else:
embeddings.append(next(calculated_embeddings_iterator))
with self.db.conn:
cast(Table, self.db["embeddings"]).insert_all(
(
@ -248,7 +226,9 @@ class Collection:
"metadata": json.dumps(metadata) if metadata else None,
"updated": int(time.time()),
}
for (embedding, (id, value, metadata)) in zip(embeddings, batch)
for (embedding, (id, value, metadata)) in zip(
embeddings, filtered_batch
)
),
replace=True,
)

View file

@ -546,6 +546,7 @@ def test_default_embed_model_errors(user_path, default_is_set, command):
def test_duplicate_content_embedded_only_once(embed_demo):
# content_hash should avoid embedding the same content twice
# per collection
db = sqlite_utils.Database(memory=True)
assert len(embed_demo.embedded_content) == 0
collection = Collection("test", db, model_id="embed-demo")
@ -557,11 +558,11 @@ def test_duplicate_content_embedded_only_once(embed_demo):
collection.embed("1", "hello world")
assert db["embeddings"].count == 2
assert len(embed_demo.embedded_content) == 2
# The same string in another collection should also not be embedded
# The same string in another collection should be embedded
c2 = Collection("test2", db, model_id="embed-demo")
c2.embed("1", "hello world")
assert db["embeddings"].count == 3
assert len(embed_demo.embedded_content) == 2
assert len(embed_demo.embedded_content) == 3
# Same again for embed_multi
collection.embed_multi(
@ -569,50 +570,4 @@ def test_duplicate_content_embedded_only_once(embed_demo):
)
# Should have only embedded one more thing
assert db["embeddings"].count == 4
assert len(embed_demo.embedded_content) == 3
def test_embed_again_with_store(embed_demo, multi_files):
# https://github.com/simonw/llm/issues/224
db_path, files = multi_files
runner = CliRunner(mix_stderr=False)
# First embed all the files without storing them
args = [
"embed-multi",
"files",
"-d",
db_path,
"-m",
"embed-demo",
"--files",
str(files),
"**/*.txt",
]
result = runner.invoke(cli, args)
assert result.exit_code == 0
assert not result.stderr
embeddings_db = sqlite_utils.Database(db_path)
rows = list(embeddings_db.query("select id, content from embeddings"))
assert rows == [
{"id": "file2.txt", "content": None},
{"id": "file1.txt", "content": None},
{"id": "nested/two.txt", "content": None},
{"id": "nested/one.txt", "content": None},
{"id": "nested/more/three.txt", "content": None},
]
assert len(embed_demo.embedded_content) == 5
# Now we run it again with the --store option
result2 = runner.invoke(cli, args + ["--store"])
assert result2.exit_code == 0
assert not result2.stderr
# The rows should have their content now
rows = list(embeddings_db.query("select id, content from embeddings"))
assert rows == [
{"id": "file2.txt", "content": "goodbye world"},
{"id": "file1.txt", "content": "hello world"},
{"id": "nested/two.txt", "content": "two"},
{"id": "nested/one.txt", "content": "one"},
{"id": "nested/more/three.txt", "content": "three"},
]
# But it should not have run any more embedding tasks
assert len(embed_demo.embedded_content) == 5
assert len(embed_demo.embedded_content) == 4