Refactor get_key() with docs and better varible names

The actual logic is unchanged, but it is a lot easier to understand what it does now.

Refs #158
This commit is contained in:
Simon Willison 2023-08-17 15:03:55 -07:00
parent a4f55e9987
commit d31d97e06f

View file

@ -14,7 +14,7 @@ from .models import (
from .templates import Template
from .plugins import pm
import click
from typing import Dict, List
from typing import Dict, List, Optional
import json
import os
import pathlib
@ -94,15 +94,28 @@ def get_model(name):
raise UnknownModelError("Unknown model: " + name)
def get_key(key_arg, default_key, env_var=None):
keys = load_keys()
if key_arg in keys:
return keys[key_arg]
if key_arg:
return key_arg
def get_key(
explicit_key: Optional[str], key_alias: str, env_var: Optional[str] = None
) -> Optional[str]:
"""
Return an API key based on a hierarchy of potential sources.
:param provided_key: A key provided by the user. This may be the key, or an alias of a key in keys.json.
:param key_alias: The alias used to retrieve the key from the keys.json file.
:param env_var: Name of the environment variable to check for the key.
"""
stored_keys = load_keys()
# If user specified an alias, use the key stored for that alias
if explicit_key in stored_keys:
return stored_keys[explicit_key]
if explicit_key:
# User specified a key that's not an alias, use that
return explicit_key
# Environment variables over-ride the default key
if env_var and os.environ.get(env_var):
return os.environ[env_var]
return keys.get(default_key)
# Return the key stored for the default alias
return stored_keys.get(key_alias)
def load_keys():