diff --git a/docs/python-api.md b/docs/python-api.md index c81df56..bee4785 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -96,3 +96,20 @@ print(response2.text()) You will get back five fun facts about skunks. Access `conversation.responses` for a list of all of the responses that have so far been returned during the conversation. + +## Other functions + +The `llm` top level package includes some useful utility functions. + +### set_alias(alias, model_id) + +The `llm.set_alias()` function can be used to define a new alias: + +```python +import llm + +llm.set_alias("turbo", "gpt-3.5-turbo") +``` +The second argument can be a model identifier or another alias, in which case that alias will be resolved. + +If the `aliases.json` file does not exist or contains invalid JSON it will be created or overwritten. diff --git a/llm/__init__.py b/llm/__init__.py index 190b9d8..1cbbfe9 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -131,3 +131,27 @@ def user_dir(): if llm_user_path: return pathlib.Path(llm_user_path) return pathlib.Path(click.get_app_dir("io.datasette.llm")) + + +def set_alias(alias, model_id_or_alias): + """ + Set an alias to point to the specified model. + """ + path = user_dir() / "aliases.json" + path.parent.mkdir(parents=True, exist_ok=True) + if not path.exists(): + path.write_text("{}\n") + try: + current = json.loads(path.read_text()) + except json.decoder.JSONDecodeError: + # We're going to write a valid JSON file in a moment: + current = {} + # Resolve model_id_or_alias to a model_id + try: + model = get_model(model_id_or_alias) + model_id = model.model_id + except UnknownModelError: + # Set the alias to the exact string they provided instead + model_id = model_id_or_alias + current[alias] = model_id + path.write_text(json.dumps(current, indent=4) + "\n") diff --git a/tests/test_aliases.py b/tests/test_aliases.py index 6c256b5..8fc2d8d 100644 --- a/tests/test_aliases.py +++ b/tests/test_aliases.py @@ -1,11 +1,19 @@ from click.testing import CliRunner from llm.cli import cli +import llm import json import pytest +def test_set_alias(): + with pytest.raises(llm.UnknownModelError): + llm.get_model("this-is-a-new-alias") + llm.set_alias("this-is-a-new-alias", "gpt-3.5-turbo") + assert llm.get_model("this-is-a-new-alias").model_id == "gpt-3.5-turbo" + + @pytest.mark.parametrize("args", (["aliases", "list"], ["aliases"])) -def test_aliases_list(args): +def test_cli_aliases_list(args): runner = CliRunner() result = runner.invoke(cli, args) assert result.exit_code == 0 @@ -21,7 +29,7 @@ def test_aliases_list(args): @pytest.mark.parametrize("args", (["aliases", "list"], ["aliases"])) -def test_aliases_list_json(args): +def test_cli_aliases_list_json(args): runner = CliRunner() result = runner.invoke(cli, args + ["--json"]) assert result.exit_code == 0 @@ -36,7 +44,7 @@ def test_aliases_list_json(args): } -def test_aliases_set(user_path): +def test_cli_aliases_set(user_path): # Should be not aliases.json at start assert not (user_path / "aliases.json").exists() runner = CliRunner() @@ -46,14 +54,14 @@ def test_aliases_set(user_path): assert json.loads((user_path / "aliases.json").read_text("utf-8")) == {"foo": "bar"} -def test_aliases_path(user_path): +def test_cli_aliases_path(user_path): runner = CliRunner() result = runner.invoke(cli, ["aliases", "path"]) assert result.exit_code == 0 assert result.output.strip() == str(user_path / "aliases.json") -def test_aliases_remove(user_path): +def test_cli_aliases_remove(user_path): (user_path / "aliases.json").write_text(json.dumps({"foo": "bar"}), "utf-8") runner = CliRunner() result = runner.invoke(cli, ["aliases", "remove", "foo"]) @@ -61,7 +69,7 @@ def test_aliases_remove(user_path): assert json.loads((user_path / "aliases.json").read_text("utf-8")) == {} -def test_aliases_remove_invalid(user_path): +def test_cli_aliases_remove_invalid(user_path): (user_path / "aliases.json").write_text(json.dumps({"foo": "bar"}), "utf-8") runner = CliRunner() result = runner.invoke(cli, ["aliases", "remove", "invalid"]) @@ -70,7 +78,7 @@ def test_aliases_remove_invalid(user_path): @pytest.mark.parametrize("args", (["models"], ["models", "list"])) -def test_aliases_are_registered(user_path, args): +def test_cli_aliases_are_registered(user_path, args): (user_path / "aliases.json").write_text( json.dumps({"foo": "bar", "turbo": "gpt-3.5-turbo"}), "utf-8" )