Improved how keys work, execute() now has default implementation

This commit is contained in:
Simon Willison 2023-07-05 18:25:57 -07:00
parent f193468f76
commit 6ef6b343a9
4 changed files with 79 additions and 31 deletions

View file

@ -199,8 +199,8 @@ def prompt(
except KeyError:
raise click.ClickException("'{}' is not a known model".format(model_id))
# Provide the API key, if one is needed
if model.needs_key and not model.key:
# Provide the API key, if one is needed and has been provided
if model.needs_key:
model.key = get_key(key, model.needs_key, model.key_env_var)
# Validate options
@ -220,15 +220,17 @@ def prompt(
if not should_stream:
validated_options["stream"] = False
response = model.prompt(prompt, system, **validated_options)
if should_stream:
for chunk in response:
print(chunk, end="")
sys.stdout.flush()
print("")
else:
print(response.text())
try:
response = model.prompt(prompt, system, **validated_options)
if should_stream:
for chunk in response:
print(chunk, end="")
sys.stdout.flush()
print("")
else:
print(response.text())
except Exception as ex:
raise click.ClickException(str(ex))
# Log to the database
if no_log:
@ -543,13 +545,7 @@ def get_key(key_arg, default_key, env_var=None):
return key_arg
if env_var and os.environ.get(env_var):
return os.environ[env_var]
default = keys.get(default_key)
if not default:
message = "No key found - add one using 'llm keys set {}'".format(default_key)
if env_var:
message += " or set the {} environment variable".format(env_var)
raise click.ClickException(message)
return default
return keys.get(default_key)
def load_keys():

View file

@ -0,0 +1,49 @@
from llm import Model, Prompt, hookimpl
import llm
from collections import defaultdict
import random
import time
@hookimpl
def register_models(register):
register(Markov())
class Markov(Model):
can_stream = True
model_id = "markov"
class Options(Model.Options):
length: int = 100
class Response(llm.Response):
def iter_prompt(self):
self._prompt_json = {"input": self.prompt.prompt}
length = self.prompt.options.length
transitions = defaultdict(list)
all_words = self.prompt.prompt.split()
for i in range(len(all_words) - 1):
transitions[all_words[i]].append(all_words[i + 1])
result = [all_words[0]]
for _ in range(length - 1):
if transitions[result[-1]]:
token = random.choice(transitions[result[-1]])
else:
token = random.choice(all_words)
yield token + " "
time.sleep(0.02)
result.append(token)
self._response_json = {
"generated": " ".join(result),
"transitions": dict(transitions),
}
def execute(self, prompt: Prompt, stream: bool = True) -> Response:
return self.Response(prompt, self, stream)
def __str__(self):
return "Markov: {}".format(self.model_id)

View file

@ -1,6 +1,5 @@
from llm import Model, Prompt, hookimpl
from llm import Model, hookimpl
import llm
from llm.errors import NeedsKeyException
from llm.utils import dicts_to_table_string
import click
import datetime
@ -140,14 +139,6 @@ class Chat(Model):
self.model_id = model_id
self.key = key
def execute(self, prompt: Prompt, stream: bool = True) -> Response:
key = self.get_key()
if key is None:
raise NeedsKeyException(
"{} needs an API key, label={}".format(str(self), self.needs_key)
)
return self.Response(prompt, self, stream, key=key)
def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)

View file

@ -1,5 +1,6 @@
from dataclasses import dataclass, asdict
import datetime
from .errors import NeedsKeyException
import time
from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Set
from abc import ABC, abstractmethod
@ -141,8 +142,16 @@ class Model(ABC):
if self.key is not None:
return self.key
if self.key_env_var is not None:
return os.environ.get(self.key_env_var)
return None
key = os.environ.get(self.key_env_var)
if key:
return key
message = "No key found - add one using 'llm keys set {}'".format(
self.needs_key
)
if self.key_env_var:
message += " or set the {} environment variable".format(self.key_env_var)
raise NeedsKeyException(message)
def prompt(
self,
@ -158,7 +167,10 @@ class Model(ABC):
def execute(self, prompt: Prompt, stream: bool = True) -> Response:
r = cast(Callable, getattr(self, "Response"))
return r(prompt, self, stream)
kwargs = {}
if self.needs_key:
kwargs["key"] = self.get_key()
return r(prompt, self, stream, **kwargs)
@abstractmethod
def __str__(self) -> str: