Actually register aliases from aliases.json, closes #151

This commit is contained in:
Simon Willison 2023-08-12 09:15:17 -07:00
parent eef2ab0d79
commit 458a59a8f4
2 changed files with 23 additions and 1 deletions

View file

@ -54,10 +54,22 @@ def get_plugins():
def get_models_with_aliases() -> List["ModelWithAliases"]:
model_aliases = []
# Include aliases from aliases.json
aliases_path = user_dir() / "aliases.json"
extra_model_aliases = {}
if aliases_path.exists():
configured_aliases = json.loads(aliases_path.read_text())
for alias, model_id in configured_aliases.items():
extra_model_aliases.setdefault(model_id, []).append(alias)
def register(model, aliases=None):
model_aliases.append(ModelWithAliases(model, aliases or set()))
alias_list = list(aliases or [])
if model.model_id in extra_model_aliases:
alias_list.extend(extra_model_aliases[model.model_id])
model_aliases.append(ModelWithAliases(model, alias_list))
pm.hook.register_models(register=register)
return model_aliases

View file

@ -64,3 +64,13 @@ def test_aliases_remove_invalid(user_path):
result = runner.invoke(cli, ["aliases", "remove", "invalid"])
assert result.exit_code == 1
assert result.output == "Error: Alias not found: invalid\n"
def test_aliases_are_registered(user_path):
(user_path / "aliases.json").write_text(
json.dumps({"foo": "bar", "turbo": "gpt-3.5-turbo"}), "utf-8"
)
runner = CliRunner()
result = runner.invoke(cli, ["models", "list"])
assert result.exit_code == 0
assert "gpt-3.5-turbo (aliases: 3.5, chatgpt, turbo)" in result.output