diff --git a/llm/cli.py b/llm/cli.py index 9ed6892..af37feb 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1109,6 +1109,8 @@ def embed( model_obj = collection_obj.model() if model_obj is None: + if model is None: + model = get_default_embedding_model() try: model_obj = get_embedding_model(model) except UnknownModelError: diff --git a/tests/test_embed_cli.py b/tests/test_embed_cli.py index 8088a4a..e9f3d2d 100644 --- a/tests/test_embed_cli.py +++ b/tests/test_embed_cli.py @@ -558,6 +558,12 @@ def test_default_embedding_model(): result5 = runner.invoke(cli, ["embed-models", "default"]) assert result5.exit_code == 0 assert result5.output == "\n" + # Now set the default and actually use it + result6 = runner.invoke(cli, ["embed-models", "default", "embed-demo"]) + assert result6.exit_code == 0 + result7 = runner.invoke(cli, ["embed", "-c", "hello world"]) + assert result7.exit_code == 0 + assert result7.output == "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n" @pytest.mark.parametrize("default_is_set", (False, True))