diff --git a/llm/embeddings.py b/llm/embeddings.py index 8059e10..69ea5bd 100644 --- a/llm/embeddings.py +++ b/llm/embeddings.py @@ -130,14 +130,20 @@ class Collection: metadata (dict, optional): Metadata to be stored store (bool, optional): Whether to store the value in the content or content_blob column """ - from llm import encode + from llm import decode, encode content_hash = self.content_hash(value) - if self.db["embeddings"].count_where( - "content_hash = ? and collection_id = ?", [content_hash, self.id] - ): - return - embedding = self.model().embed(value) + existing = list( + self.db["embeddings"].rows_where( + "content_hash = ?", [content_hash], limit=1 + ) + ) + if existing: + # Reuse the embedding from whatever record this is, it might even + # be in a different collection + embedding = decode(existing[0]["embedding"]) + else: + embedding = self.model().embed(value) cast(Table, self.db["embeddings"]).insert( { "collection_id": self.id, @@ -192,23 +198,39 @@ class Collection: # Calculate hashes first items_and_hashes = [(item, self.content_hash(item[1])) for item in batch] # Any of those hashes already exist? - existing_ids = [ - row["id"] + hashes_and_embeddings = { + row["content_hash"]: row["embedding"] for row in self.db.query( """ - select id from embeddings - where collection_id = ? and content_hash in ({}) + select content_hash, embedding from embeddings + where content_hash in ({}) """.format( ",".join("?" for _ in items_and_hashes) ), - [collection_id] - + [item_and_hash[1] for item_and_hash in items_and_hashes], + [item_and_hash[1] for item_and_hash in items_and_hashes], ) + } + # We need to embed the ones that don't exist yet + items_to_embed = [ + item_and_hash[0] + for item_and_hash in items_and_hashes + if item_and_hash[1] not in hashes_and_embeddings ] - filtered_batch = [item for item in batch if item[0] not in existing_ids] - embeddings = list( - self.model().embed_multi(item[1] for item in filtered_batch) + calculated_embeddings = list( + self.model().embed_multi(item[1] for item in items_to_embed) ) + + # This should be a list of all the embeddings from both sources + embeddings = [] + calculated_embeddings_iterator = iter(calculated_embeddings) + for item_and_hash in items_and_hashes: + if item_and_hash[1] in hashes_and_embeddings: + embeddings.append( + llm.decode(hashes_and_embeddings[item_and_hash[1]]) + ) + else: + embeddings.append(next(calculated_embeddings_iterator)) + with self.db.conn: cast(Table, self.db["embeddings"]).insert_all( ( @@ -226,9 +248,7 @@ class Collection: "metadata": json.dumps(metadata) if metadata else None, "updated": int(time.time()), } - for (embedding, (id, value, metadata)) in zip( - embeddings, filtered_batch - ) + for (embedding, (id, value, metadata)) in zip(embeddings, batch) ), replace=True, ) diff --git a/tests/test_embed_cli.py b/tests/test_embed_cli.py index 1e2ad04..8d021b4 100644 --- a/tests/test_embed_cli.py +++ b/tests/test_embed_cli.py @@ -546,7 +546,6 @@ def test_default_embed_model_errors(user_path, default_is_set, command): def test_duplicate_content_embedded_only_once(embed_demo): # content_hash should avoid embedding the same content twice - # per collection db = sqlite_utils.Database(memory=True) assert len(embed_demo.embedded_content) == 0 collection = Collection("test", db, model_id="embed-demo") @@ -558,11 +557,11 @@ def test_duplicate_content_embedded_only_once(embed_demo): collection.embed("1", "hello world") assert db["embeddings"].count == 2 assert len(embed_demo.embedded_content) == 2 - # The same string in another collection should be embedded + # The same string in another collection should also not be embedded c2 = Collection("test2", db, model_id="embed-demo") c2.embed("1", "hello world") assert db["embeddings"].count == 3 - assert len(embed_demo.embedded_content) == 3 + assert len(embed_demo.embedded_content) == 2 # Same again for embed_multi collection.embed_multi( @@ -570,4 +569,50 @@ def test_duplicate_content_embedded_only_once(embed_demo): ) # Should have only embedded one more thing assert db["embeddings"].count == 4 - assert len(embed_demo.embedded_content) == 4 + assert len(embed_demo.embedded_content) == 3 + + +def test_embed_again_with_store(embed_demo, multi_files): + # https://github.com/simonw/llm/issues/224 + db_path, files = multi_files + runner = CliRunner(mix_stderr=False) + # First embed all the files without storing them + args = [ + "embed-multi", + "files", + "-d", + db_path, + "-m", + "embed-demo", + "--files", + str(files), + "**/*.txt", + ] + result = runner.invoke(cli, args) + assert result.exit_code == 0 + assert not result.stderr + embeddings_db = sqlite_utils.Database(db_path) + rows = list(embeddings_db.query("select id, content from embeddings")) + assert rows == [ + {"id": "file2.txt", "content": None}, + {"id": "file1.txt", "content": None}, + {"id": "nested/two.txt", "content": None}, + {"id": "nested/one.txt", "content": None}, + {"id": "nested/more/three.txt", "content": None}, + ] + assert len(embed_demo.embedded_content) == 5 + # Now we run it again with the --store option + result2 = runner.invoke(cli, args + ["--store"]) + assert result2.exit_code == 0 + assert not result2.stderr + # The rows should have their content now + rows = list(embeddings_db.query("select id, content from embeddings")) + assert rows == [ + {"id": "file2.txt", "content": "goodbye world"}, + {"id": "file1.txt", "content": "hello world"}, + {"id": "nested/two.txt", "content": "two"}, + {"id": "nested/one.txt", "content": "one"}, + {"id": "nested/more/three.txt", "content": "three"}, + ] + # But it should not have run any more embedding tasks + assert len(embed_demo.embedded_content) == 5