Fix for bug where embed did not use default model, closes #317

This commit is contained in:
Simon Willison 2023-10-31 21:19:59 -07:00
parent abcb457b20
commit 8b78ac6099
2 changed files with 8 additions and 0 deletions

View file

@ -1109,6 +1109,8 @@ def embed(
model_obj = collection_obj.model()
if model_obj is None:
if model is None:
model = get_default_embedding_model()
try:
model_obj = get_embedding_model(model)
except UnknownModelError:

View file

@ -558,6 +558,12 @@ def test_default_embedding_model():
result5 = runner.invoke(cli, ["embed-models", "default"])
assert result5.exit_code == 0
assert result5.output == "<No default embedding model set>\n"
# Now set the default and actually use it
result6 = runner.invoke(cli, ["embed-models", "default", "embed-demo"])
assert result6.exit_code == 0
result7 = runner.invoke(cli, ["embed", "-c", "hello world"])
assert result7.exit_code == 0
assert result7.output == "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
@pytest.mark.parametrize("default_is_set", (False, True))