mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-17 05:00:25 +00:00
256 lines
8.7 KiB
Python
256 lines
8.7 KiB
Python
from click.testing import CliRunner
|
|
from llm.cli import cli
|
|
from llm import Collection
|
|
import json
|
|
import pytest
|
|
import sqlite_utils
|
|
from unittest.mock import ANY
|
|
|
|
|
|
@pytest.mark.parametrize(
    "format_,expected",
    [
        ("json", "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"),
        (
            "base64",
            (
                "AACgQAAAoEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
                "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\n"
            ),
        ),
        (
            "hex",
            (
                "0000a0400000a04000000000000000000000000000000000000000000"
                "000000000000000000000000000000000000000000000000000000000"
                "00000000000000\n"
            ),
        ),
        (
            "blob",
            (
                b"\x00\x00\xef\xbf\xbd@\x00\x00\xef\xbf\xbd@\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n"
            ).decode("utf-8"),
        ),
    ],
)
@pytest.mark.parametrize("scenario", ["argument", "file", "stdin"])
def test_embed_output_format(tmpdir, format_, expected, scenario):
    """Check the output of ``embed --format`` for each format and input source.

    The text "hello world" is supplied in one of three ways, selected by
    ``scenario``: the ``-c`` option, a file passed to ``-i``, or stdin via
    ``-i -``. The command must exit with 0 and print exactly ``expected``.
    """
    argv = ["embed", "--format", format_, "-m", "embed-demo"]
    stdin_text = None  # only set for the stdin scenario

    if scenario == "argument":
        argv += ["-c", "hello world"]
    elif scenario == "file":
        source = tmpdir / "input.txt"
        source.write_text("hello world", "utf-8")
        argv += ["-i", str(source)]
    elif scenario == "stdin":
        stdin_text = "hello world"
        argv += ["-i", "-"]

    outcome = CliRunner().invoke(cli, argv, input=stdin_text)
    assert outcome.exit_code == 0
    assert outcome.output == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "args,expected_error",
    [
        (["-c", "Content", "stories"], "Must provide both collection and id"),
    ],
)
def test_embed_errors(args, expected_error):
    """Invalid ``embed`` argument combinations exit with 1 and show the error."""
    outcome = CliRunner().invoke(cli, ["embed", *args])
    assert outcome.exit_code == 1
    assert expected_error in outcome.output
|
|
|
|
|
|
@pytest.mark.parametrize(
    "metadata,metadata_error",
    [
        (None, None),
        ('{"foo": "bar"}', None),
        ('{"foo": [1, 2, 3]}', None),
        # Valid JSON, but not a dictionary
        ("[1, 2, 3]", "Metadata must be a JSON object"),
        ('{"foo": "incomplete}', "Metadata must be valid JSON"),
    ],
)
def test_embed_store(user_path, metadata, metadata_error):
    """Store an embedding through the CLI, inspect the database, then delete it.

    Steps:
      1. ``embed`` without a collection and id succeeds and creates no
         ``embeddings.db``.
      2. ``embed ... items 1`` (optionally with ``--metadata``) either fails
         with exit code 2 and ``metadata_error`` in the output, or succeeds.
      3. On success, the ``collections`` and ``embeddings`` tables are checked,
         ``embed-db collections`` is checked in plain and ``--json`` form, and
         ``embed-db delete-collection`` is expected to empty both tables.
    """
    db_file = user_path / "embeddings.db"
    assert not db_file.exists()
    runner = CliRunner()

    # No collection/id given: the database file must not be created
    unsaved = runner.invoke(cli, ["embed", "-c", "hello", "-m", "embed-demo"])
    assert unsaved.exit_code == 0
    assert not db_file.exists()

    # Same content again, this time stored as id "1" in collection "items"
    store_args = ["embed", "-c", "hello", "-m", "embed-demo", "items", "1"]
    if metadata is not None:
        store_args += ["--metadata", metadata]
    stored = runner.invoke(cli, store_args)

    if metadata_error:
        # Bad metadata is reported with exit code 2 and nothing more is checked
        assert stored.exit_code == 2
        assert metadata_error in stored.output
        return

    # Valid (or absent) metadata: the embedding should now be on disk
    assert stored.exit_code == 0
    assert db_file.exists()

    db = sqlite_utils.Database(str(db_file))
    collection_rows = list(db["collections"].rows)
    assert collection_rows == [{"id": 1, "name": "items", "model": "embed-demo"}]

    # metadata_error is falsy on this path, so only metadata itself decides
    expected_metadata = metadata or None
    expected_embedding = (
        b"\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00"
    )
    embedding_rows = list(db["embeddings"].rows)
    assert embedding_rows == [
        {
            "collection_id": 1,
            "id": "1",
            "embedding": expected_embedding,
            "content": None,
            "content_hash": Collection.content_hash("hello"),
            "metadata": expected_metadata,
            "updated": ANY,
        }
    ]

    # The collection should be listed by 'llm embed-db collections' ...
    plain_listing = runner.invoke(cli, ["embed-db", "collections"])
    assert plain_listing.exit_code == 0
    assert plain_listing.output == "items: embed-demo\n 1 embedding\n"

    # ... and by its --json variant
    json_listing = runner.invoke(cli, ["embed-db", "collections", "--json"])
    assert json_listing.exit_code == 0
    assert json.loads(json_listing.output) == [
        {"name": "items", "model": "embed-demo", "num_embeddings": 1}
    ]

    # Deleting the collection should leave both tables empty
    deleted = runner.invoke(cli, ["embed-db", "delete-collection", "items"])
    assert deleted.exit_code == 0
    assert db["collections"].count == 0
    assert db["embeddings"].count == 0
|
|
|
|
|
|
def test_collection_delete_errors(user_path):
    """Deleting an unknown collection fails and leaves the existing one alone."""
    db = sqlite_utils.Database(str(user_path / "embeddings.db"))

    # Seed one collection holding a single embedding
    Collection("items", db, model_id="embed-demo").embed("1", "hello")
    assert db["collections"].count == 1
    assert db["embeddings"].count == 1

    outcome = CliRunner().invoke(
        cli, ["embed-db", "delete-collection", "does-not-exist"]
    )
    assert outcome.exit_code == 1
    assert "Collection does not exist" in outcome.output
    # The seeded collection must still be there after the failed delete
    assert db["collections"].count == 1
|
|
|
|
|
|
@pytest.mark.parametrize(
    "args,expected_error",
    [
        ([], "Missing argument 'COLLECTION'"),
        (["badcollection", "-c", "content"], "Collection does not exist"),
        (["demo", "bad-id"], "ID not found in collection"),
    ],
)
def test_similar_errors(args, expected_error, user_path_with_embeddings):
    """Bad ``similar`` invocations exit non-zero and show the expected error.

    ``user_path_with_embeddings`` is requested only for its fixture setup; the
    test body never reads it.
    """
    command = ["similar", *args]
    outcome = CliRunner().invoke(cli, command, catch_exceptions=False)
    assert outcome.exit_code != 0
    assert expected_error in outcome.output
|
|
|
|
|
|
def test_similar_by_id_cli(user_path_with_embeddings):
    """``similar demo 1`` prints one JSON object describing item "2"."""
    outcome = CliRunner().invoke(
        cli, ["similar", "demo", "1"], catch_exceptions=False
    )
    assert outcome.exit_code == 0

    expected_match = {
        "id": "2",
        "score": pytest.approx(0.9863939238321437),
        "content": None,
        "metadata": None,
    }
    assert json.loads(outcome.output) == expected_match
|
|
|
|
|
|
@pytest.mark.parametrize("scenario", ["argument", "file", "stdin"])
def test_similar_by_content_cli(tmpdir, user_path_with_embeddings, scenario):
    """``similar demo`` with content prints two JSON lines: item "1" then "2".

    The text "hello world" is supplied via ``-c``, a file passed to ``-i``, or
    stdin (``-i -``), depending on ``scenario``.
    """
    argv = ["similar", "demo"]
    stdin_text = None  # only set for the stdin scenario

    if scenario == "argument":
        argv += ["-c", "hello world"]
    elif scenario == "file":
        source = tmpdir / "content.txt"
        source.write_text("hello world", "utf-8")
        argv += ["-i", str(source)]
    elif scenario == "stdin":
        stdin_text = "hello world"
        argv += ["-i", "-"]

    outcome = CliRunner().invoke(
        cli, argv, input=stdin_text, catch_exceptions=False
    )
    assert outcome.exit_code == 0

    # Ignore blank lines; each remaining line is parsed as its own JSON document
    json_lines = [row for row in outcome.output.splitlines() if row.strip()]
    assert len(json_lines) == 2

    first_match, second_match = (json.loads(row) for row in json_lines)
    assert first_match == {
        "id": "1",
        "score": pytest.approx(0.9999999999999999),
        "content": None,
        "metadata": None,
    }
    assert second_match == {
        "id": "2",
        "score": pytest.approx(0.9863939238321437),
        "content": None,
        "metadata": None,
    }
|
|
|
|
|
|
@pytest.mark.parametrize("use_stdin", [False, True])
@pytest.mark.parametrize(
    "filename,content",
    [
        ("phrases.csv", "id,phrase\n1,hello world\n2,goodbye world"),
        ("phrases.tsv", "id\tphrase\n1\thello world\n2\tgoodbye world"),
        (
            "phrases.jsonl",
            '{"id": 1, "phrase": "hello world"}\n{"id": 2, "phrase": "goodbye world"}',
        ),
        (
            "phrases.json",
            '[{"id": 1, "phrase": "hello world"}, {"id": 2, "phrase": "goodbye world"}]',
        ),
    ],
)
def test_embed_multi_file_input(tmpdir, use_stdin, filename, content):
    """``embed-multi`` stores two embeddings from CSV, TSV, JSON-nl or JSON input.

    ``content`` is either piped on stdin (with ``-`` as the path) or written to
    ``filename`` under ``tmpdir`` and passed by path. Afterwards the
    ``embeddings`` table of the target database must hold exactly two rows.
    """
    db_path = tmpdir / "embeddings.db"
    argv = ["embed-multi", "phrases", "-d", str(db_path), "-m", "embed-demo"]

    if use_stdin:
        stdin_text = content
        argv.append("-")
    else:
        stdin_text = None
        source = tmpdir / filename
        source.write_text(content, "utf-8")
        argv.append(str(source))

    # Auto-detection can't detect JSON-nl, so make that explicit
    if filename.endswith(".jsonl"):
        argv += ["--format", "nl"]

    outcome = CliRunner().invoke(cli, argv, input=stdin_text)
    assert outcome.exit_code == 0

    # Check that everything was embedded correctly
    stored = sqlite_utils.Database(str(db_path))
    assert stored["embeddings"].count == 2
|