mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-17 05:00:25 +00:00
llm.set_alias() function, refs #154
This commit is contained in:
parent
0cd9333054
commit
8cdcf1e689
3 changed files with 56 additions and 7 deletions
|
|
@ -96,3 +96,20 @@ print(response2.text())
|
|||
You will get back five fun facts about skunks.
|
||||
|
||||
Access `conversation.responses` for a list of all of the responses that have so far been returned during the conversation.
|
||||
|
||||
## Other functions
|
||||
|
||||
The `llm` top level package includes some useful utility functions.
|
||||
|
||||
### set_alias(alias, model_id_or_alias)
|
||||
|
||||
The `llm.set_alias()` function can be used to define a new alias:
|
||||
|
||||
```python
|
||||
import llm
|
||||
|
||||
llm.set_alias("turbo", "gpt-3.5-turbo")
|
||||
```
|
||||
The second argument can be a model identifier or another alias, in which case that alias will be resolved.
|
||||
|
||||
If the `aliases.json` file does not exist or contains invalid JSON, it will be created or overwritten with valid JSON.
|
||||
|
|
|
|||
|
|
@ -131,3 +131,27 @@ def user_dir():
|
|||
if llm_user_path:
|
||||
return pathlib.Path(llm_user_path)
|
||||
return pathlib.Path(click.get_app_dir("io.datasette.llm"))
|
||||
|
||||
|
||||
def set_alias(alias, model_id_or_alias):
    """
    Set an alias to point to the specified model.

    The alias is persisted to ``aliases.json`` in the user directory,
    which is created (and rewritten with valid JSON) if missing or corrupt.

    :param alias: the alias name to define
    :param model_id_or_alias: a model ID, or an existing alias — if it
        resolves to a known model, the alias is stored against that
        model's ID; otherwise the string is stored as provided
    """
    path = user_dir() / "aliases.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    try:
        current = json.loads(path.read_text())
    except (FileNotFoundError, json.decoder.JSONDecodeError):
        # Missing or invalid file — we write a valid JSON file below,
        # so no need to pre-create a placeholder here.
        current = {}
    # Resolve model_id_or_alias to a concrete model_id where possible
    try:
        model_id = get_model(model_id_or_alias).model_id
    except UnknownModelError:
        # Not a known model or alias: store the exact string they provided
        model_id = model_id_or_alias
    current[alias] = model_id
    path.write_text(json.dumps(current, indent=4) + "\n")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,19 @@
|
|||
from click.testing import CliRunner
|
||||
from llm.cli import cli
|
||||
import llm
|
||||
import json
|
||||
import pytest
|
||||
|
||||
|
||||
def test_set_alias():
    """An alias defined via llm.set_alias() should resolve through get_model()."""
    new_alias = "this-is-a-new-alias"
    # The alias must not already resolve before we define it
    with pytest.raises(llm.UnknownModelError):
        llm.get_model(new_alias)
    llm.set_alias(new_alias, "gpt-3.5-turbo")
    resolved = llm.get_model(new_alias)
    assert resolved.model_id == "gpt-3.5-turbo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("args", (["aliases", "list"], ["aliases"]))
|
||||
def test_aliases_list(args):
|
||||
def test_cli_aliases_list(args):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args)
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -21,7 +29,7 @@ def test_aliases_list(args):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("args", (["aliases", "list"], ["aliases"]))
|
||||
def test_aliases_list_json(args):
|
||||
def test_cli_aliases_list_json(args):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args + ["--json"])
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -36,7 +44,7 @@ def test_aliases_list_json(args):
|
|||
}
|
||||
|
||||
|
||||
def test_aliases_set(user_path):
|
||||
def test_cli_aliases_set(user_path):
|
||||
# Should be not aliases.json at start
|
||||
assert not (user_path / "aliases.json").exists()
|
||||
runner = CliRunner()
|
||||
|
|
@ -46,14 +54,14 @@ def test_aliases_set(user_path):
|
|||
assert json.loads((user_path / "aliases.json").read_text("utf-8")) == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_cli_aliases_path(user_path):
    """The "aliases path" command prints the location of aliases.json."""
    outcome = CliRunner().invoke(cli, ["aliases", "path"])
    assert outcome.exit_code == 0
    expected = str(user_path / "aliases.json")
    assert outcome.output.strip() == expected
|
||||
|
||||
|
||||
def test_aliases_remove(user_path):
|
||||
def test_cli_aliases_remove(user_path):
|
||||
(user_path / "aliases.json").write_text(json.dumps({"foo": "bar"}), "utf-8")
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["aliases", "remove", "foo"])
|
||||
|
|
@ -61,7 +69,7 @@ def test_aliases_remove(user_path):
|
|||
assert json.loads((user_path / "aliases.json").read_text("utf-8")) == {}
|
||||
|
||||
|
||||
def test_aliases_remove_invalid(user_path):
|
||||
def test_cli_aliases_remove_invalid(user_path):
|
||||
(user_path / "aliases.json").write_text(json.dumps({"foo": "bar"}), "utf-8")
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["aliases", "remove", "invalid"])
|
||||
|
|
@ -70,7 +78,7 @@ def test_aliases_remove_invalid(user_path):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("args", (["models"], ["models", "list"]))
|
||||
def test_aliases_are_registered(user_path, args):
|
||||
def test_cli_aliases_are_registered(user_path, args):
|
||||
(user_path / "aliases.json").write_text(
|
||||
json.dumps({"foo": "bar", "turbo": "gpt-3.5-turbo"}), "utf-8"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue