@pytest.mark.parametrize("use_stdin", (False, True))
@pytest.mark.parametrize(
    "filename,content",
    (
        ("phrases.csv", "id,phrase\n1,hello world\n2,goodbye world"),
        ("phrases.tsv", "id\tphrase\n1\thello world\n2\tgoodbye world"),
        (
            "phrases.jsonl",
            '{"id": 1, "phrase": "hello world"}\n{"id": 2, "phrase": "goodbye world"}',
        ),
        (
            "phrases.json",
            '[{"id": 1, "phrase": "hello world"}, {"id": 2, "phrase": "goodbye world"}]',
        ),
    ),
)
def test_embed_multi_file_input(tmpdir, use_stdin, filename, content):
    """Run ``embed-multi`` on CSV, TSV, JSON-nl and JSON input and expect two rows.

    Each format is exercised twice: once with the content written to a file
    under ``tmpdir`` and passed by path, and once with the content supplied
    as the runner's input while ``-`` is passed in place of the path.
    """
    db_path = tmpdir / "embeddings.db"
    args = ["embed-multi", "phrases", "-d", str(db_path), "-m", "embed-demo"]
    # Named stdin_text rather than ``input`` so the builtin is not shadowed.
    stdin_text = None
    if use_stdin:
        stdin_text = content
        args.append("-")
    else:
        path = tmpdir / filename
        path.write_text(content, "utf-8")
        args.append(str(path))
    # Auto-detection can't detect JSON-nl, so make that explicit
    if filename.endswith(".jsonl"):
        args.extend(("--format", "nl"))
    runner = CliRunner()
    result = runner.invoke(cli, args, input=stdin_text)
    # Attach the CLI output so a non-zero exit explains itself in the report.
    assert result.exit_code == 0, result.output
    # Check that everything was embedded correctly
    db = sqlite_utils.Database(str(db_path))
    assert db["embeddings"].count == 2