From b9c19a56661023792cd29ebc2f86b99fe1a14785 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 3 Sep 2023 16:17:10 -0700 Subject: [PATCH] Tests for multiple --files pairs --- tests/test_embed_cli.py | 43 ++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/tests/test_embed_cli.py b/tests/test_embed_cli.py index 4fb0987..71d1006 100644 --- a/tests/test_embed_cli.py +++ b/tests/test_embed_cli.py @@ -314,7 +314,8 @@ def test_embed_multi_sql(tmpdir, use_other_db, prefix): ] -def test_embed_multi_files(tmpdir): +@pytest.mark.parametrize("scenario", ("single", "multi")) +def test_embed_multi_files(tmpdir, scenario): db_path = str(tmpdir / "files.db") files = tmpdir / "files" for filename, content in ( @@ -329,6 +330,18 @@ def test_embed_multi_files(tmpdir): path.parent.mkdir(parents=True, exist_ok=True) path.write_text(content, "utf-8") + if scenario == "single": + extra_args = ["--files", str(files), "**/*.txt"] + else: + extra_args = [ + "--files", + str(files / "nested" / "more"), + "**/*.ini", + "--files", + str(files / "nested"), + "*.txt", + ] + runner = CliRunner() result = runner.invoke( cli, @@ -337,25 +350,29 @@ def test_embed_multi_files(tmpdir): "files", "-d", db_path, - "--files", - str(files), - "**/*.txt", "-m", "embed-demo", "--store", - ], + ] + + extra_args, ) assert result.exit_code == 0 embeddings_db = sqlite_utils.Database(db_path) - assert embeddings_db["embeddings"].count == 5 rows = list(embeddings_db.query("select id, content from embeddings")) - assert rows == [ - {"id": "file2.txt", "content": "goodbye world"}, - {"id": "file1.txt", "content": "hello world"}, - {"id": "nested/two.txt", "content": "two"}, - {"id": "nested/one.txt", "content": "one"}, - {"id": "nested/more/three.txt", "content": "three"}, - ] + if scenario == "single": + assert rows == [ + {"id": "file2.txt", "content": "goodbye world"}, + {"id": "file1.txt", "content": "hello world"}, + {"id": "nested/two.txt", "content": "two"}, + {"id": "nested/one.txt", "content": "one"}, + {"id": "nested/more/three.txt", "content": "three"}, + ] + else: + assert rows == [ + {"id": "ignored.ini", "content": "Does not match glob"}, + {"id": "two.txt", "content": "two"}, + {"id": "one.txt", "content": "one"}, + ] def test_default_embedding_model():