From 591ad6f57139e5bda868728c266ef46aef31671d Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Mon, 11 Sep 2023 22:56:52 -0700 Subject: [PATCH] Revert "Reuse embeddings for hashed content, --store now works on second run - closes #224" This reverts commit 267e2ea999992bbd019406cb32653b732e64a561. It's broken, see: https://github.com/simonw/llm/issues/224#issuecomment-1715014393 --- llm/embeddings.py | 56 +++++++++++++---------------------------- tests/test_embed_cli.py | 53 +++----------------------------------- 2 files changed, 22 insertions(+), 87 deletions(-) diff --git a/llm/embeddings.py b/llm/embeddings.py index 69ea5bd..8059e10 100644 --- a/llm/embeddings.py +++ b/llm/embeddings.py @@ -130,20 +130,14 @@ class Collection: metadata (dict, optional): Metadata to be stored store (bool, optional): Whether to store the value in the content or content_blob column """ - from llm import decode, encode + from llm import encode content_hash = self.content_hash(value) - existing = list( - self.db["embeddings"].rows_where( - "content_hash = ?", [content_hash], limit=1 - ) - ) - if existing: - # Reuse the embedding from whatever record this is, it might even - # be in a different collection - embedding = decode(existing[0]["embedding"]) - else: - embedding = self.model().embed(value) + if self.db["embeddings"].count_where( + "content_hash = ? and collection_id = ?", [content_hash, self.id] + ): + return + embedding = self.model().embed(value) cast(Table, self.db["embeddings"]).insert( { "collection_id": self.id, @@ -198,39 +192,23 @@ class Collection: # Calculate hashes first items_and_hashes = [(item, self.content_hash(item[1])) for item in batch] # Any of those hashes already exist? - hashes_and_embeddings = { - row["content_hash"]: row["embedding"] + existing_ids = [ + row["id"] for row in self.db.query( """ - select content_hash, embedding from embeddings - where content_hash in ({}) + select id from embeddings + where collection_id = ? and content_hash in ({}) """.format( ",".join("?" for _ in items_and_hashes) ), - [item_and_hash[1] for item_and_hash in items_and_hashes], + [collection_id] + + [item_and_hash[1] for item_and_hash in items_and_hashes], ) - } - # We need to embed the ones that don't exist yet - items_to_embed = [ - item_and_hash[0] - for item_and_hash in items_and_hashes - if item_and_hash[1] not in hashes_and_embeddings ] - calculated_embeddings = list( - self.model().embed_multi(item[1] for item in items_to_embed) + filtered_batch = [item for item in batch if item[0] not in existing_ids] + embeddings = list( + self.model().embed_multi(item[1] for item in filtered_batch) ) - - # This should be a list of all the embeddings from both sources - embeddings = [] - calculated_embeddings_iterator = iter(calculated_embeddings) - for item_and_hash in items_and_hashes: - if item_and_hash[1] in hashes_and_embeddings: - embeddings.append( - llm.decode(hashes_and_embeddings[item_and_hash[1]]) - ) - else: - embeddings.append(next(calculated_embeddings_iterator)) - with self.db.conn: cast(Table, self.db["embeddings"]).insert_all( ( @@ -248,7 +226,9 @@ class Collection: "metadata": json.dumps(metadata) if metadata else None, "updated": int(time.time()), } - for (embedding, (id, value, metadata)) in zip(embeddings, batch) + for (embedding, (id, value, metadata)) in zip( + embeddings, filtered_batch + ) ), replace=True, ) diff --git a/tests/test_embed_cli.py b/tests/test_embed_cli.py index 8d021b4..1e2ad04 100644 --- a/tests/test_embed_cli.py +++ b/tests/test_embed_cli.py @@ -546,6 +546,7 @@ def test_default_embed_model_errors(user_path, default_is_set, command): def test_duplicate_content_embedded_only_once(embed_demo): # content_hash should avoid embedding the same content twice + # per collection db = sqlite_utils.Database(memory=True) assert len(embed_demo.embedded_content) == 0 collection = Collection("test", db, model_id="embed-demo") @@ -557,11 +558,11 @@ def test_duplicate_content_embedded_only_once(embed_demo): collection.embed("1", "hello world") assert db["embeddings"].count == 2 assert len(embed_demo.embedded_content) == 2 - # The same string in another collection should also not be embedded + # The same string in another collection should be embedded c2 = Collection("test2", db, model_id="embed-demo") c2.embed("1", "hello world") assert db["embeddings"].count == 3 - assert len(embed_demo.embedded_content) == 2 + assert len(embed_demo.embedded_content) == 3 # Same again for embed_multi collection.embed_multi( @@ -569,50 +570,4 @@ def test_duplicate_content_embedded_only_once(embed_demo): ) # Should have only embedded one more thing assert db["embeddings"].count == 4 - assert len(embed_demo.embedded_content) == 3 - - -def test_embed_again_with_store(embed_demo, multi_files): - # https://github.com/simonw/llm/issues/224 - db_path, files = multi_files - runner = CliRunner(mix_stderr=False) - # First embed all the files without storing them - args = [ - "embed-multi", - "files", - "-d", - db_path, - "-m", - "embed-demo", - "--files", - str(files), - "**/*.txt", - ] - result = runner.invoke(cli, args) - assert result.exit_code == 0 - assert not result.stderr - embeddings_db = sqlite_utils.Database(db_path) - rows = list(embeddings_db.query("select id, content from embeddings")) - assert rows == [ - {"id": "file2.txt", "content": None}, - {"id": "file1.txt", "content": None}, - {"id": "nested/two.txt", "content": None}, - {"id": "nested/one.txt", "content": None}, - {"id": "nested/more/three.txt", "content": None}, - ] - assert len(embed_demo.embedded_content) == 5 - # Now we run it again with the --store option - result2 = runner.invoke(cli, args + ["--store"]) - assert result2.exit_code == 0 - assert not result2.stderr - # The rows should have their content now - rows = list(embeddings_db.query("select id, content from embeddings")) - assert rows == [ - {"id": "file2.txt", "content": "goodbye world"}, - {"id": "file1.txt", "content": "hello world"}, - {"id": "nested/two.txt", "content": "two"}, - {"id": "nested/one.txt", "content": "one"}, - {"id": "nested/more/three.txt", "content": "three"}, - ] - # But it should not have run any more embedding tasks - assert len(embed_demo.embedded_content) == 5 + assert len(embed_demo.embedded_content) == 4