llm/tests/test_embed_cli.py
2023-09-03 16:40:00 -07:00

256 lines
8.7 KiB
Python

from click.testing import CliRunner
from llm.cli import cli
from llm import Collection
import json
import pytest
import sqlite_utils
from unittest.mock import ANY
@pytest.mark.parametrize(
    "format_,expected",
    (
        ("json", "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"),
        (
            "base64",
            (
                "AACgQAAAoEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
                "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\n"
            ),
        ),
        (
            "hex",
            (
                "0000a0400000a04000000000000000000000000000000000000000000"
                "000000000000000000000000000000000000000000000000000000000"
                "00000000000000\n"
            ),
        ),
        (
            "blob",
            (
                b"\x00\x00\xef\xbf\xbd@\x00\x00\xef\xbf\xbd@\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n"
            ).decode("utf-8"),
        ),
    ),
)
@pytest.mark.parametrize("scenario", ("argument", "file", "stdin"))
def test_embed_output_format(tmpdir, format_, expected, scenario):
    """Check each ``llm embed --format`` option for every way of passing content.

    With the ``embed-demo`` model, "hello world" must render as ``expected``
    whether the text is given via ``-c``, via a file passed to ``-i``, or via
    stdin with ``-i -``.
    """
    command = ["embed", "--format", format_, "-m", "embed-demo"]
    stdin_text = None

    if scenario == "argument":
        command += ["-c", "hello world"]
    elif scenario == "file":
        input_file = tmpdir / "input.txt"
        input_file.write_text("hello world", "utf-8")
        command += ["-i", str(input_file)]
    elif scenario == "stdin":
        # "-i -" tells the CLI to read the content from stdin
        stdin_text = "hello world"
        command += ["-i", "-"]

    outcome = CliRunner().invoke(cli, command, input=stdin_text)
    assert outcome.exit_code == 0
    assert outcome.output == expected
@pytest.mark.parametrize(
    "args,expected_error",
    ((["-c", "Content", "stories"], "Must provide both collection and id"),),
)
def test_embed_errors(args, expected_error):
    """Invalid ``llm embed`` argument combinations exit with status 1 and a message."""
    outcome = CliRunner().invoke(cli, ["embed", *args])
    assert outcome.exit_code == 1
    assert expected_error in outcome.output
@pytest.mark.parametrize(
    "metadata,metadata_error",
    (
        (None, None),
        ('{"foo": "bar"}', None),
        ('{"foo": [1, 2, 3]}', None),
        ("[1, 2, 3]", "Metadata must be a JSON object"),  # Must be a dictionary
        ('{"foo": "incomplete}', "Metadata must be valid JSON"),
    ),
)
def test_embed_store(user_path, metadata, metadata_error):
    """Store an embedding via ``llm embed``, list its collection, then delete it.

    ``metadata`` is passed through ``--metadata`` when not None; when
    ``metadata_error`` is set the store step must fail with exit code 2 and
    that message, and the rest of the test is skipped.
    """
    embeddings_db = user_path / "embeddings.db"
    assert not embeddings_db.exists()
    runner = CliRunner()

    # Without a collection and id the embedding is only printed, never stored
    printed = runner.invoke(cli, ["embed", "-c", "hello", "-m", "embed-demo"])
    assert printed.exit_code == 0
    assert not embeddings_db.exists()

    # With collection "items" and id "1" the embedding is persisted
    store_command = ["embed", "-c", "hello", "-m", "embed-demo", "items", "1"]
    if metadata is not None:
        store_command += ["--metadata", metadata]
    stored = runner.invoke(cli, store_command)

    if metadata_error:
        # Bad --metadata is rejected as a usage error
        assert stored.exit_code == 2
        assert metadata_error in stored.output
        return

    assert stored.exit_code == 0
    assert embeddings_db.exists()

    db = sqlite_utils.Database(str(embeddings_db))
    assert list(db["collections"].rows) == [
        {"id": 1, "name": "items", "model": "embed-demo"}
    ]

    # Packed vector for "hello"; the leading b"\x00\x00\xa0@" is 5.0 as a
    # little-endian float32, the rest is zero padding (64 bytes in total)
    hello_blob = (
        b"\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
        b"\x00\x00\x00\x00\x00\x00\x00"
    )
    expected_row = {
        "collection_id": 1,
        "id": "1",
        "embedding": hello_blob,
        "content": None,
        "content_hash": Collection.content_hash("hello"),
        # Only reached on the success path, so valid metadata is stored verbatim
        "metadata": metadata if metadata else None,
        "updated": ANY,
    }
    assert list(db["embeddings"].rows) == [expected_row]

    # The collection shows up in 'llm embed-db collections', plain then JSON
    plain_listing = runner.invoke(cli, ["embed-db", "collections"])
    assert plain_listing.exit_code == 0
    assert plain_listing.output == "items: embed-demo\n 1 embedding\n"

    json_listing = runner.invoke(cli, ["embed-db", "collections", "--json"])
    assert json_listing.exit_code == 0
    assert json.loads(json_listing.output) == [
        {"name": "items", "model": "embed-demo", "num_embeddings": 1}
    ]

    # Deleting the collection removes it and its embeddings
    deleted = runner.invoke(cli, ["embed-db", "delete-collection", "items"])
    assert deleted.exit_code == 0
    assert db["collections"].count == 0
    assert db["embeddings"].count == 0
def test_collection_delete_errors(user_path):
    """Deleting a nonexistent collection fails and leaves existing data intact.

    Seeds one collection ("items") holding one embedding, asks the CLI to
    delete a collection that does not exist, and checks for exit code 1 with
    an explanatory message and no change to the stored rows.
    """
    db = sqlite_utils.Database(str(user_path / "embeddings.db"))
    collection = Collection("items", db, model_id="embed-demo")
    collection.embed("1", "hello")
    assert db["collections"].count == 1
    assert db["embeddings"].count == 1

    runner = CliRunner()
    result = runner.invoke(cli, ["embed-db", "delete-collection", "does-not-exist"])
    assert result.exit_code == 1
    assert "Collection does not exist" in result.output

    # The failed delete must not touch the existing collection OR its
    # embeddings (previously only the collections table was re-checked)
    assert db["collections"].count == 1
    assert db["embeddings"].count == 1
@pytest.mark.parametrize(
    "args,expected_error",
    (
        ([], "Missing argument 'COLLECTION'"),
        (["badcollection", "-c", "content"], "Collection does not exist"),
        (["demo", "bad-id"], "ID not found in collection"),
    ),
)
def test_similar_errors(args, expected_error, user_path_with_embeddings):
    """``llm similar`` fails with a helpful message for bad collections or IDs."""
    outcome = CliRunner().invoke(cli, ["similar", *args], catch_exceptions=False)
    assert outcome.exit_code != 0
    assert expected_error in outcome.output
def test_similar_by_id_cli(user_path_with_embeddings):
    """``llm similar demo 1`` outputs a single JSON match: item "2"."""
    expected_match = {
        "id": "2",
        "score": pytest.approx(0.9863939238321437),
        "content": None,
        "metadata": None,
    }
    outcome = CliRunner().invoke(
        cli, ["similar", "demo", "1"], catch_exceptions=False
    )
    assert outcome.exit_code == 0
    assert json.loads(outcome.output) == expected_match
@pytest.mark.parametrize("scenario", ("argument", "file", "stdin"))
def test_similar_by_content_cli(tmpdir, user_path_with_embeddings, scenario):
    """``llm similar demo`` ranks stored items against free-text content.

    The content "hello world" is supplied via ``-c``, a file passed to ``-i``,
    or stdin (``-i -``); each variant must print the same two JSON lines.
    """
    command = ["similar", "demo"]
    stdin_text = None

    if scenario == "argument":
        command += ["-c", "hello world"]
    elif scenario == "file":
        content_file = tmpdir / "content.txt"
        content_file.write_text("hello world", "utf-8")
        command += ["-i", str(content_file)]
    elif scenario == "stdin":
        stdin_text = "hello world"
        command += ["-i", "-"]

    outcome = CliRunner().invoke(
        cli, command, input=stdin_text, catch_exceptions=False
    )
    assert outcome.exit_code == 0

    # One JSON document per non-blank output line
    json_lines = list(filter(str.strip, outcome.output.splitlines()))
    assert len(json_lines) == 2

    expected_matches = [
        {
            "id": "1",
            "score": pytest.approx(0.9999999999999999),
            "content": None,
            "metadata": None,
        },
        {
            "id": "2",
            "score": pytest.approx(0.9863939238321437),
            "content": None,
            "metadata": None,
        },
    ]
    for raw_line, expected_match in zip(json_lines, expected_matches):
        assert json.loads(raw_line) == expected_match
@pytest.mark.parametrize("use_stdin", (False, True))
@pytest.mark.parametrize(
    "filename,content",
    (
        ("phrases.csv", "id,phrase\n1,hello world\n2,goodbye world"),
        ("phrases.tsv", "id\tphrase\n1\thello world\n2\tgoodbye world"),
        (
            "phrases.jsonl",
            '{"id": 1, "phrase": "hello world"}\n{"id": 2, "phrase": "goodbye world"}',
        ),
        (
            "phrases.json",
            '[{"id": 1, "phrase": "hello world"}, {"id": 2, "phrase": "goodbye world"}]',
        ),
    ),
)
def test_embed_multi_file_input(tmpdir, use_stdin, filename, content):
    """Run ``llm embed-multi`` on CSV, TSV, JSON-lines and JSON inputs.

    Each input holds the same two rows (ids 1 and 2). The content is either
    piped through stdin (file argument ``-``) or written to ``tmpdir/filename``
    and passed by path. Embeddings go into a fresh database given with ``-d``.
    """
    db_path = tmpdir / "embeddings.db"
    args = ["embed-multi", "phrases", "-d", str(db_path), "-m", "embed-demo"]
    input = None  # stdin text for CliRunner; stays None when reading a file
    if use_stdin:
        input = content
        args.append("-")
    else:
        path = tmpdir / filename
        path.write_text(content, "utf-8")
        args.append(str(path))
    # Auto-detection can't detect JSON-nl, so make that explicit
    if filename.endswith(".jsonl"):
        args.extend(("--format", "nl"))
    runner = CliRunner()
    result = runner.invoke(cli, args, input=input)
    assert result.exit_code == 0
    # Check that everything was embedded correctly
    db = sqlite_utils.Database(str(db_path))
    # Both input rows should have produced an embedding
    assert db["embeddings"].count == 2