mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
* Embeddings plugin hook + OpenAI implementation * llm.get_embedding_model(name) function * llm embed command, for returning embeddings or saving them to SQLite * Tests using an EmbedDemo embedding model * llm embed-models list and embed-models default commands * llm embed-db path and llm embed-db collections commands
105 lines
3.5 KiB
Python
105 lines
3.5 KiB
Python
from click.testing import CliRunner
|
|
from llm.cli import cli
|
|
import json
|
|
import pytest
|
|
import sqlite_utils
|
|
|
|
|
|
@pytest.mark.parametrize(
    "format_,expected",
    (
        # embed-demo turns "hello world" into a 16-value vector starting
        # 5, 5 (spelled out by the "json" case); the remaining cases are the
        # same vector rendered through the other --format encodings.
        ("json", "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"),
        (
            "base64",
            (
                "AACgQAAAoEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
                "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\n"
            ),
        ),
        (
            "hex",
            (
                "0000a0400000a04000000000000000000000000000000000000000000"
                "000000000000000000000000000000000000000000000000000000000"
                "00000000000000\n"
            ),
        ),
        (
            "blob",
            # The raw bytes include 0xa0 (visible in the "hex" case), which is
            # not valid UTF-8; in the captured text output each such byte
            # presumably becomes U+FFFD, i.e. b"\xef\xbf\xbd" once encoded,
            # which is why the expectation is built from bytes and decoded.
            (
                b"\x00\x00\xef\xbf\xbd@\x00\x00\xef\xbf\xbd@\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n"
            ).decode("utf-8"),
        ),
    ),
)
def test_embed_output_format(format_, expected):
    """``llm embed --format <format_>`` prints the embedding in that encoding.

    Embeds the content "hello world" with the ``embed-demo`` model and
    compares the full captured CLI output against ``expected``.
    """
    command = ["embed", "--format", format_]
    command += ["-c", "hello world", "-m", "embed-demo"]
    outcome = CliRunner().invoke(cli, command)
    assert outcome.exit_code == 0
    assert outcome.output == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "args,expected_error",
    ((["-c", "Content", "stories"], "Must provide both collection and id"),),
)
def test_embed_errors(args, expected_error):
    """Invalid ``llm embed`` argument combinations fail with exit code 1.

    Each case appends ``args`` after ``embed`` and checks that
    ``expected_error`` appears in the captured output. The single current case
    passes a collection name ("stories") with no id.
    """
    outcome = CliRunner().invoke(cli, ["embed", *args])
    assert outcome.exit_code == 1
    assert expected_error in outcome.output
|
|
|
|
|
|
def test_embed_store(user_path):
    """``llm embed`` writes to embeddings.db only when given a collection and id.

    Runs the command once without a collection (nothing should be stored),
    then with collection "items" and id "1". It then checks the rows in the
    ``collections`` and ``embeddings`` tables. Finally it checks that
    ``llm embed-db collections`` reports the new collection in both plain-text
    and ``--json`` output.
    """
    embeddings_db = user_path / "embeddings.db"
    assert not embeddings_db.exists()
    runner = CliRunner()

    result = runner.invoke(cli, ["embed", "-c", "hello", "-m", "embed-demo"])
    assert result.exit_code == 0, result.output
    # Without a collection and id nothing is stored, so the database file
    # should not have been created at all
    assert not embeddings_db.exists()

    # Now run it again with collection "items" and id "1" so it is stored
    result = runner.invoke(
        cli, ["embed", "-c", "hello", "-m", "embed-demo", "items", "1"]
    )
    assert result.exit_code == 0, result.output
    assert embeddings_db.exists()

    # Check the contents. The connection is closed afterwards so the test
    # does not leave the database file held open (an open handle can block
    # deleting the file on Windows).
    db = sqlite_utils.Database(str(embeddings_db))
    try:
        assert list(db["collections"].rows) == [
            {"id": 1, "name": "items", "model": "embed-demo"}
        ]
        assert list(db["embeddings"].rows) == [
            {
                "collection_id": 1,
                "id": "1",
                # 64 bytes: b"\x00\x00\xa0@" is the bit pattern of 5.0 as a
                # little-endian float32, followed by zero bytes
                "embedding": (
                    b"\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                    b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                    b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                    b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                    b"\x00\x00\x00\x00\x00\x00\x00"
                ),
                "content": None,
                "metadata": None,
            }
        ]
    finally:
        db.conn.close()

    # Should show up in 'llm embed-db collections', first as plain text...
    result = runner.invoke(cli, ["embed-db", "collections"])
    assert result.exit_code == 0, result.output
    assert result.output == "items: embed-demo\n 1 embedding\n"

    # ...and then as JSON
    result = runner.invoke(cli, ["embed-db", "collections", "--json"])
    assert result.exit_code == 0, result.output
    assert json.loads(result.output) == [
        {"name": "items", "model": "embed-demo", "num_embeddings": 1}
    ]
|