mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-05 14:00:59 +00:00
Reuse embeddings for hashed content, --store now works on second run - closes #224
This commit is contained in:
parent
52cec1304b
commit
267e2ea999
2 changed files with 87 additions and 22 deletions
|
|
@ -130,14 +130,20 @@ class Collection:
|
|||
metadata (dict, optional): Metadata to be stored
|
||||
store (bool, optional): Whether to store the value in the content or content_blob column
|
||||
"""
|
||||
from llm import encode
|
||||
from llm import decode, encode
|
||||
|
||||
content_hash = self.content_hash(value)
|
||||
if self.db["embeddings"].count_where(
|
||||
"content_hash = ? and collection_id = ?", [content_hash, self.id]
|
||||
):
|
||||
return
|
||||
embedding = self.model().embed(value)
|
||||
existing = list(
|
||||
self.db["embeddings"].rows_where(
|
||||
"content_hash = ?", [content_hash], limit=1
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
# Reuse the embedding from whatever record this is, it might even
|
||||
# be in a different collection
|
||||
embedding = decode(existing[0]["embedding"])
|
||||
else:
|
||||
embedding = self.model().embed(value)
|
||||
cast(Table, self.db["embeddings"]).insert(
|
||||
{
|
||||
"collection_id": self.id,
|
||||
|
|
@ -192,23 +198,39 @@ class Collection:
|
|||
# Calculate hashes first
|
||||
items_and_hashes = [(item, self.content_hash(item[1])) for item in batch]
|
||||
# Any of those hashes already exist?
|
||||
existing_ids = [
|
||||
row["id"]
|
||||
hashes_and_embeddings = {
|
||||
row["content_hash"]: row["embedding"]
|
||||
for row in self.db.query(
|
||||
"""
|
||||
select id from embeddings
|
||||
where collection_id = ? and content_hash in ({})
|
||||
select content_hash, embedding from embeddings
|
||||
where content_hash in ({})
|
||||
""".format(
|
||||
",".join("?" for _ in items_and_hashes)
|
||||
),
|
||||
[collection_id]
|
||||
+ [item_and_hash[1] for item_and_hash in items_and_hashes],
|
||||
[item_and_hash[1] for item_and_hash in items_and_hashes],
|
||||
)
|
||||
}
|
||||
# We need to embed the ones that don't exist yet
|
||||
items_to_embed = [
|
||||
item_and_hash[0]
|
||||
for item_and_hash in items_and_hashes
|
||||
if item_and_hash[1] not in hashes_and_embeddings
|
||||
]
|
||||
filtered_batch = [item for item in batch if item[0] not in existing_ids]
|
||||
embeddings = list(
|
||||
self.model().embed_multi(item[1] for item in filtered_batch)
|
||||
calculated_embeddings = list(
|
||||
self.model().embed_multi(item[1] for item in items_to_embed)
|
||||
)
|
||||
|
||||
# This should be a list of all the embeddings from both sources
|
||||
embeddings = []
|
||||
calculated_embeddings_iterator = iter(calculated_embeddings)
|
||||
for item_and_hash in items_and_hashes:
|
||||
if item_and_hash[1] in hashes_and_embeddings:
|
||||
embeddings.append(
|
||||
llm.decode(hashes_and_embeddings[item_and_hash[1]])
|
||||
)
|
||||
else:
|
||||
embeddings.append(next(calculated_embeddings_iterator))
|
||||
|
||||
with self.db.conn:
|
||||
cast(Table, self.db["embeddings"]).insert_all(
|
||||
(
|
||||
|
|
@ -226,9 +248,7 @@ class Collection:
|
|||
"metadata": json.dumps(metadata) if metadata else None,
|
||||
"updated": int(time.time()),
|
||||
}
|
||||
for (embedding, (id, value, metadata)) in zip(
|
||||
embeddings, filtered_batch
|
||||
)
|
||||
for (embedding, (id, value, metadata)) in zip(embeddings, batch)
|
||||
),
|
||||
replace=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -546,7 +546,6 @@ def test_default_embed_model_errors(user_path, default_is_set, command):
|
|||
|
||||
def test_duplicate_content_embedded_only_once(embed_demo):
|
||||
# content_hash should avoid embedding the same content twice
|
||||
# per collection
|
||||
db = sqlite_utils.Database(memory=True)
|
||||
assert len(embed_demo.embedded_content) == 0
|
||||
collection = Collection("test", db, model_id="embed-demo")
|
||||
|
|
@ -558,11 +557,11 @@ def test_duplicate_content_embedded_only_once(embed_demo):
|
|||
collection.embed("1", "hello world")
|
||||
assert db["embeddings"].count == 2
|
||||
assert len(embed_demo.embedded_content) == 2
|
||||
# The same string in another collection should be embedded
|
||||
# The same string in another collection should also not be embedded
|
||||
c2 = Collection("test2", db, model_id="embed-demo")
|
||||
c2.embed("1", "hello world")
|
||||
assert db["embeddings"].count == 3
|
||||
assert len(embed_demo.embedded_content) == 3
|
||||
assert len(embed_demo.embedded_content) == 2
|
||||
|
||||
# Same again for embed_multi
|
||||
collection.embed_multi(
|
||||
|
|
@ -570,4 +569,50 @@ def test_duplicate_content_embedded_only_once(embed_demo):
|
|||
)
|
||||
# Should have only embedded one more thing
|
||||
assert db["embeddings"].count == 4
|
||||
assert len(embed_demo.embedded_content) == 4
|
||||
assert len(embed_demo.embedded_content) == 3
|
||||
|
||||
|
||||
def test_embed_again_with_store(embed_demo, multi_files):
|
||||
# https://github.com/simonw/llm/issues/224
|
||||
db_path, files = multi_files
|
||||
runner = CliRunner(mix_stderr=False)
|
||||
# First embed all the files without storing them
|
||||
args = [
|
||||
"embed-multi",
|
||||
"files",
|
||||
"-d",
|
||||
db_path,
|
||||
"-m",
|
||||
"embed-demo",
|
||||
"--files",
|
||||
str(files),
|
||||
"**/*.txt",
|
||||
]
|
||||
result = runner.invoke(cli, args)
|
||||
assert result.exit_code == 0
|
||||
assert not result.stderr
|
||||
embeddings_db = sqlite_utils.Database(db_path)
|
||||
rows = list(embeddings_db.query("select id, content from embeddings"))
|
||||
assert rows == [
|
||||
{"id": "file2.txt", "content": None},
|
||||
{"id": "file1.txt", "content": None},
|
||||
{"id": "nested/two.txt", "content": None},
|
||||
{"id": "nested/one.txt", "content": None},
|
||||
{"id": "nested/more/three.txt", "content": None},
|
||||
]
|
||||
assert len(embed_demo.embedded_content) == 5
|
||||
# Now we run it again with the --store option
|
||||
result2 = runner.invoke(cli, args + ["--store"])
|
||||
assert result2.exit_code == 0
|
||||
assert not result2.stderr
|
||||
# The rows should have their content now
|
||||
rows = list(embeddings_db.query("select id, content from embeddings"))
|
||||
assert rows == [
|
||||
{"id": "file2.txt", "content": "goodbye world"},
|
||||
{"id": "file1.txt", "content": "hello world"},
|
||||
{"id": "nested/two.txt", "content": "two"},
|
||||
{"id": "nested/one.txt", "content": "one"},
|
||||
{"id": "nested/more/three.txt", "content": "three"},
|
||||
]
|
||||
# But it should not have run any more embedding tasks
|
||||
assert len(embed_demo.embedded_content) == 5
|
||||
|
|
|
|||
Loading…
Reference in a new issue