From 6ef6b343a9eaee1651bacd7517a9b50549509160 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 5 Jul 2023 18:25:57 -0700 Subject: [PATCH] Improved how keys work, execute() now has default implementation --- llm/cli.py | 32 ++++++++---------- llm/default_plugins/markov.py | 49 ++++++++++++++++++++++++++++ llm/default_plugins/openai_models.py | 11 +------ llm/models.py | 18 ++++++++-- 4 files changed, 79 insertions(+), 31 deletions(-) create mode 100644 llm/default_plugins/markov.py diff --git a/llm/cli.py b/llm/cli.py index ea3e161..2164367 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -199,8 +199,8 @@ def prompt( except KeyError: raise click.ClickException("'{}' is not a known model".format(model_id)) - # Provide the API key, if one is needed - if model.needs_key and not model.key: + # Provide the API key, if one is needed and has been provided + if model.needs_key: model.key = get_key(key, model.needs_key, model.key_env_var) # Validate options @@ -220,15 +220,17 @@ def prompt( if not should_stream: validated_options["stream"] = False - response = model.prompt(prompt, system, **validated_options) - - if should_stream: - for chunk in response: - print(chunk, end="") - sys.stdout.flush() - print("") - else: - print(response.text()) + try: + response = model.prompt(prompt, system, **validated_options) + if should_stream: + for chunk in response: + print(chunk, end="") + sys.stdout.flush() + print("") + else: + print(response.text()) + except Exception as ex: + raise click.ClickException(str(ex)) # Log to the database if no_log: @@ -543,13 +545,7 @@ def get_key(key_arg, default_key, env_var=None): return key_arg if env_var and os.environ.get(env_var): return os.environ[env_var] - default = keys.get(default_key) - if not default: - message = "No key found - add one using 'llm keys set {}'".format(default_key) - if env_var: - message += " or set the {} environment variable".format(env_var) - raise click.ClickException(message) - return default + return keys.get(default_key) def load_keys(): diff --git a/llm/default_plugins/markov.py b/llm/default_plugins/markov.py new file mode 100644 index 0000000..9ae46b9 --- /dev/null +++ b/llm/default_plugins/markov.py @@ -0,0 +1,49 @@ +from llm import Model, Prompt, hookimpl +import llm +from collections import defaultdict +import random +import time + + +@hookimpl +def register_models(register): + register(Markov()) + + +class Markov(Model): + can_stream = True + model_id = "markov" + + class Options(Model.Options): + length: int = 100 + + class Response(llm.Response): + def iter_prompt(self): + self._prompt_json = {"input": self.prompt.prompt} + + length = self.prompt.options.length + + transitions = defaultdict(list) + all_words = self.prompt.prompt.split() + for i in range(len(all_words) - 1): + transitions[all_words[i]].append(all_words[i + 1]) + + result = [all_words[0]] + for _ in range(length - 1): + if transitions[result[-1]]: + token = random.choice(transitions[result[-1]]) + else: + token = random.choice(all_words) + yield token + " " + time.sleep(0.02) + result.append(token) + self._response_json = { + "generated": " ".join(result), + "transitions": dict(transitions), + } + + def execute(self, prompt: Prompt, stream: bool = True) -> Response: + return self.Response(prompt, self, stream) + + def __str__(self): + return "Markov: {}".format(self.model_id) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index a1d0076..3c6202f 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -1,6 +1,5 @@ -from llm import Model, Prompt, hookimpl +from llm import Model, hookimpl import llm -from llm.errors import NeedsKeyException from llm.utils import dicts_to_table_string import click import datetime @@ -140,14 +139,6 @@ class Chat(Model): self.model_id = model_id self.key = key - def execute(self, prompt: Prompt, stream: bool = True) -> Response: - key = self.get_key() - if key is None: - raise NeedsKeyException( - "{} needs an API key, label={}".format(str(self), self.needs_key) - ) - return self.Response(prompt, self, stream, key=key) - def __str__(self): return "OpenAI Chat: {}".format(self.model_id) diff --git a/llm/models.py b/llm/models.py index b3d1d30..bdecb2a 100644 --- a/llm/models.py +++ b/llm/models.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, asdict import datetime +from .errors import NeedsKeyException import time from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Set from abc import ABC, abstractmethod @@ -141,8 +142,16 @@ class Model(ABC): if self.key is not None: return self.key if self.key_env_var is not None: - return os.environ.get(self.key_env_var) - return None + key = os.environ.get(self.key_env_var) + if key: + return key + + message = "No key found - add one using 'llm keys set {}'".format( + self.needs_key + ) + if self.key_env_var: + message += " or set the {} environment variable".format(self.key_env_var) + raise NeedsKeyException(message) def prompt( self, @@ -158,7 +167,10 @@ class Model(ABC): def execute(self, prompt: Prompt, stream: bool = True) -> Response: r = cast(Callable, getattr(self, "Response")) - return r(prompt, self, stream) + kwargs = {} + if self.needs_key: + kwargs["key"] = self.get_key() + return r(prompt, self, stream, **kwargs) @abstractmethod def __str__(self) -> str: