llm.set_alias() function, refs #154

This commit is contained in:
Simon Willison 2023-08-19 22:00:38 -07:00
parent 0cd9333054
commit 8cdcf1e689
3 changed files with 56 additions and 7 deletions

View file

@ -96,3 +96,20 @@ print(response2.text())
You will get back five fun facts about skunks.
Access `conversation.responses` for a list of all of the responses that have so far been returned during the conversation.
## Other functions
The `llm` top level package includes some useful utility functions.
### set_alias(alias, model_id)
The `llm.set_alias()` function can be used to define a new alias:
```python
import llm
llm.set_alias("turbo", "gpt-3.5-turbo")
```
The second argument can be a model identifier or another alias, in which case that alias will be resolved.
If the `aliases.json` file does not exist or contains invalid JSON it will be created or overwritten.

View file

@ -131,3 +131,27 @@ def user_dir():
if llm_user_path:
return pathlib.Path(llm_user_path)
return pathlib.Path(click.get_app_dir("io.datasette.llm"))
def set_alias(alias, model_id_or_alias):
"""
Set an alias to point to the specified model.
"""
path = user_dir() / "aliases.json"
path.parent.mkdir(parents=True, exist_ok=True)
if not path.exists():
path.write_text("{}\n")
try:
current = json.loads(path.read_text())
except json.decoder.JSONDecodeError:
# We're going to write a valid JSON file in a moment:
current = {}
# Resolve model_id_or_alias to a model_id
try:
model = get_model(model_id_or_alias)
model_id = model.model_id
except UnknownModelError:
# Set the alias to the exact string they provided instead
model_id = model_id_or_alias
current[alias] = model_id
path.write_text(json.dumps(current, indent=4) + "\n")

View file

@ -1,11 +1,19 @@
from click.testing import CliRunner
from llm.cli import cli
import llm
import json
import pytest
def test_set_alias():
with pytest.raises(llm.UnknownModelError):
llm.get_model("this-is-a-new-alias")
llm.set_alias("this-is-a-new-alias", "gpt-3.5-turbo")
assert llm.get_model("this-is-a-new-alias").model_id == "gpt-3.5-turbo"
@pytest.mark.parametrize("args", (["aliases", "list"], ["aliases"]))
def test_aliases_list(args):
def test_cli_aliases_list(args):
runner = CliRunner()
result = runner.invoke(cli, args)
assert result.exit_code == 0
@ -21,7 +29,7 @@ def test_aliases_list(args):
@pytest.mark.parametrize("args", (["aliases", "list"], ["aliases"]))
def test_aliases_list_json(args):
def test_cli_aliases_list_json(args):
runner = CliRunner()
result = runner.invoke(cli, args + ["--json"])
assert result.exit_code == 0
@ -36,7 +44,7 @@ def test_aliases_list_json(args):
}
def test_aliases_set(user_path):
def test_cli_aliases_set(user_path):
# Should be not aliases.json at start
assert not (user_path / "aliases.json").exists()
runner = CliRunner()
@ -46,14 +54,14 @@ def test_aliases_set(user_path):
assert json.loads((user_path / "aliases.json").read_text("utf-8")) == {"foo": "bar"}
def test_aliases_path(user_path):
def test_cli_aliases_path(user_path):
runner = CliRunner()
result = runner.invoke(cli, ["aliases", "path"])
assert result.exit_code == 0
assert result.output.strip() == str(user_path / "aliases.json")
def test_aliases_remove(user_path):
def test_cli_aliases_remove(user_path):
(user_path / "aliases.json").write_text(json.dumps({"foo": "bar"}), "utf-8")
runner = CliRunner()
result = runner.invoke(cli, ["aliases", "remove", "foo"])
@ -61,7 +69,7 @@ def test_aliases_remove(user_path):
assert json.loads((user_path / "aliases.json").read_text("utf-8")) == {}
def test_aliases_remove_invalid(user_path):
def test_cli_aliases_remove_invalid(user_path):
(user_path / "aliases.json").write_text(json.dumps({"foo": "bar"}), "utf-8")
runner = CliRunner()
result = runner.invoke(cli, ["aliases", "remove", "invalid"])
@ -70,7 +78,7 @@ def test_aliases_remove_invalid(user_path):
@pytest.mark.parametrize("args", (["models"], ["models", "list"]))
def test_aliases_are_registered(user_path, args):
def test_cli_aliases_are_registered(user_path, args):
(user_path / "aliases.json").write_text(
json.dumps({"foo": "bar", "turbo": "gpt-3.5-turbo"}), "utf-8"
)