Tests for CSV/TSV/JSON/NL, refs #215

This commit is contained in:
Simon Willison 2023-09-03 15:15:19 -07:00
parent 2440eb4f48
commit 5e686fe8b3

View file

@ -216,3 +216,41 @@ def test_similar_by_content_cli(tmpdir, user_path_with_embeddings, scenario):
"content": None,
"metadata": None,
}
@pytest.mark.parametrize("use_stdin", (False, True))
@pytest.mark.parametrize(
"filename,content",
(
("phrases.csv", "id,phrase\n1,hello world\n2,goodbye world"),
("phrases.tsv", "id\tphrase\n1\thello world\n2\tgoodbye world"),
(
"phrases.jsonl",
'{"id": 1, "phrase": "hello world"}\n{"id": 2, "phrase": "goodbye world"}',
),
(
"phrases.json",
'[{"id": 1, "phrase": "hello world"}, {"id": 2, "phrase": "goodbye world"}]',
),
),
)
def test_embed_multi_file_input(tmpdir, use_stdin, filename, content):
db_path = tmpdir / "embeddings.db"
args = ["embed-multi", "phrases", "-d", str(db_path), "-m", "embed-demo"]
input = None
if use_stdin:
input = content
args.append("-")
else:
path = tmpdir / filename
path.write_text(content, "utf-8")
args.append(str(path))
# Auto-detection can't detect JSON-nl, so make that explicit
if filename.endswith(".jsonl"):
args.extend(("--format", "nl"))
runner = CliRunner()
result = runner.invoke(cli, args, input=input)
assert result.exit_code == 0
# Check that everything was embedded correctly
db = sqlite_utils.Database(str(db_path))
assert db["embeddings"].count == 2