mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-20 14:40:25 +00:00
123 lines
4.1 KiB
Python
123 lines
4.1 KiB
Python
import json
|
|
import llm
|
|
from llm.embeddings import Entry
|
|
import pytest
|
|
import sqlite_utils
|
|
from unittest.mock import ANY
|
|
|
|
|
|
def test_demo_plugin():
    """The "embed-demo" model embeds "hello world" as [5, 5] plus 14 zeros."""
    demo_model = llm.get_embedding_model("embed-demo")
    expected_vector = [5, 5] + [0] * 14
    assert demo_model.embed("hello world") == expected_vector
|
|
|
|
|
|
def test_embed_huge_list():
    """embed_multi() over 1000 generated strings returns a generator.

    The results are tallied by their first two vector components, and the
    model is expected to report that the work was done in 100 batches.
    """
    embedder = llm.get_embedding_model("embed-demo")
    texts = ("hello {}".format(n) for n in range(1000))
    vectors = embedder.embed_multi(texts)
    # The return value must itself be a generator, not a materialized list
    assert repr(type(vectors)) == "<class 'generator'>"
    # Tally how often each (first, second) pair of vector components occurs
    pair_counts = {}
    for vector in vectors:
        pair = (vector[0], vector[1])
        pair_counts.setdefault(pair, 0)
        pair_counts[pair] += 1
    assert pair_counts == {(5, 1): 10, (5, 2): 90, (5, 3): 900}
    # Should have happened in 100 batches
    assert embedder.batch_count == 100
|
|
|
|
|
|
def test_embed_store(collection):
    """Embedding with store=True keeps the original text in the content column."""
    collection.embed("3", "hello world", store=True)
    embeddings_table = collection.db["embeddings"]
    # The new entry brings the table to three rows
    assert embeddings_table.count == 3
    stored_row = next(embeddings_table.rows_where("id = ?", ["3"]))
    assert stored_row["content"] == "hello world"
|
|
|
|
|
|
def test_embed_metadata(collection):
    """Metadata passed to embed() is stored as JSON and returned by similar()."""
    collection.embed("3", "hello yet again", metadata={"foo": "bar"}, store=True)
    embeddings_table = collection.db["embeddings"]
    assert embeddings_table.count == 3
    # The metadata column holds a JSON document that decodes to the original dict
    stored_row = next(embeddings_table.rows_where("id = ?", ["3"]))
    assert json.loads(stored_row["metadata"]) == {"foo": "bar"}
    # The top match for the same text should be the entry just stored
    best_match = collection.similar("hello yet again")[0]
    assert best_match.id == "3"
    assert best_match.metadata == {"foo": "bar"}
    assert best_match.content == "hello yet again"
|
|
|
|
|
|
def test_collection(collection):
    """The collection has id 1 and two embedding rows with the expected columns."""
    assert collection.id == 1
    assert collection.count() == 2

    def expected_row(row_id, vector, text):
        # Expected "embeddings" row: encoded vector and content hash present,
        # no stored content or metadata, any value accepted for "updated"
        return {
            "collection_id": 1,
            "id": row_id,
            "embedding": llm.encode(vector),
            "content": None,
            "content_hash": collection.content_hash(text),
            "metadata": None,
            "updated": ANY,
        }

    # Check that the embeddings are there
    rows = list(collection.db["embeddings"].rows)
    assert rows == [
        expected_row(
            "1", [5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "hello world"
        ),
        expected_row(
            "2", [7, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "goodbye world"
        ),
    ]
    # "updated" was matched with ANY above, so check its type and sign here
    first_updated = rows[0]["updated"]
    assert isinstance(first_updated, int) and first_updated > 0
|
|
|
|
|
|
def test_similar(collection):
    """similar("hello world") returns entry "1" then "2" with the expected scores."""
    expected = [
        Entry(id="1", score=pytest.approx(0.9999999999999999)),
        Entry(id="2", score=pytest.approx(0.9863939238321437)),
    ]
    matches = list(collection.similar("hello world"))
    assert matches == expected
|
|
|
|
|
|
def test_similar_by_id(collection):
    """similar_by_id("1") returns only entry "2" with the expected score."""
    expected = [Entry(id="2", score=pytest.approx(0.9863939238321437))]
    matches = list(collection.similar_by_id("1"))
    assert matches == expected
|
|
|
|
|
|
@pytest.mark.parametrize("with_metadata", (False, True))
def test_embed_multi(with_metadata):
    """Bulk-embed 1000 items into an in-memory collection.

    With metadata, embed_multi_with_metadata() is used and every row should
    carry metadata but no content. Otherwise embed_multi(store=True) is used
    and every row should carry content but no metadata.
    """
    db = sqlite_utils.Database(memory=True)
    collection = llm.Collection("test", db, model_id="embed-demo")
    pairs = ((str(n), "hello {}".format(n)) for n in range(1000))
    if not with_metadata:
        # Exercise store=True here too
        collection.embed_multi(pairs, store=True)
    else:
        triples = ((item_id, text, {"meta": item_id}) for item_id, text in pairs)
        collection.embed_multi_with_metadata(triples)
    rows = list(db["embeddings"].rows)
    assert len(rows) == 1000
    metadata_count = sum(1 for row in rows if row["metadata"] is not None)
    content_count = sum(1 for row in rows if row["content"] is not None)
    expected_metadata, expected_content = (1000, 0) if with_metadata else (0, 1000)
    assert metadata_count == expected_metadata
    assert content_count == expected_content
    # Every row should have content_hash set
    assert all(row["content_hash"] is not None for row in rows)
|
|
|
|
|
|
def test_collection_delete(collection):
    """delete() leaves both the embeddings and collections tables empty."""
    database = collection.db
    embeddings_table = database["embeddings"]
    collections_table = database["collections"]
    # State before deletion: two embeddings in one collection
    assert embeddings_table.count == 2
    assert collections_table.count == 1
    collection.delete()
    # Both tables should now have no rows
    assert embeddings_table.count == 0
    assert collections_table.count == 0
|