mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-15 17:23:09 +00:00
Improved how keys work, execute() now has default implementation
This commit is contained in:
parent
f193468f76
commit
6ef6b343a9
4 changed files with 79 additions and 31 deletions
32
llm/cli.py
32
llm/cli.py
|
|
@ -199,8 +199,8 @@ def prompt(
|
|||
except KeyError:
|
||||
raise click.ClickException("'{}' is not a known model".format(model_id))
|
||||
|
||||
# Provide the API key, if one is needed
|
||||
if model.needs_key and not model.key:
|
||||
# Provide the API key, if one is needed and has been provided
|
||||
if model.needs_key:
|
||||
model.key = get_key(key, model.needs_key, model.key_env_var)
|
||||
|
||||
# Validate options
|
||||
|
|
@ -220,15 +220,17 @@ def prompt(
|
|||
if not should_stream:
|
||||
validated_options["stream"] = False
|
||||
|
||||
response = model.prompt(prompt, system, **validated_options)
|
||||
|
||||
if should_stream:
|
||||
for chunk in response:
|
||||
print(chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print("")
|
||||
else:
|
||||
print(response.text())
|
||||
try:
|
||||
response = model.prompt(prompt, system, **validated_options)
|
||||
if should_stream:
|
||||
for chunk in response:
|
||||
print(chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print("")
|
||||
else:
|
||||
print(response.text())
|
||||
except Exception as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
# Log to the database
|
||||
if no_log:
|
||||
|
|
@ -543,13 +545,7 @@ def get_key(key_arg, default_key, env_var=None):
|
|||
return key_arg
|
||||
if env_var and os.environ.get(env_var):
|
||||
return os.environ[env_var]
|
||||
default = keys.get(default_key)
|
||||
if not default:
|
||||
message = "No key found - add one using 'llm keys set {}'".format(default_key)
|
||||
if env_var:
|
||||
message += " or set the {} environment variable".format(env_var)
|
||||
raise click.ClickException(message)
|
||||
return default
|
||||
return keys.get(default_key)
|
||||
|
||||
|
||||
def load_keys():
|
||||
|
|
|
|||
49
llm/default_plugins/markov.py
Normal file
49
llm/default_plugins/markov.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from llm import Model, Prompt, hookimpl
import llm
from collections import defaultdict
import random
import time


@hookimpl
def register_models(register):
    # Expose the demo Markov model to the llm plugin registry.
    register(Markov())


class Markov(Model):
    """Toy model that generates text with a first-order Markov chain
    built from the words of the prompt itself. No API key required."""

    can_stream = True
    model_id = "markov"

    class Options(Model.Options):
        # Total number of words to generate (the seed word counts as one).
        length: int = 100

    class Response(llm.Response):
        def iter_prompt(self):
            """Yield generated tokens one at a time, each with a trailing
            space, recording prompt/response JSON for later logging.

            The chain is seeded with the first word of the prompt; when a
            word has no recorded successor, generation restarts from a
            random prompt word.
            """
            self._prompt_json = {"input": self.prompt.prompt}

            length = self.prompt.options.length

            # Build word -> [observed next words] transition table.
            transitions = defaultdict(list)
            all_words = self.prompt.prompt.split()
            # Guard: an empty/whitespace-only prompt has no seed word —
            # previously this raised IndexError on all_words[0].
            if not all_words:
                self._response_json = {"generated": "", "transitions": {}}
                return
            for i in range(len(all_words) - 1):
                transitions[all_words[i]].append(all_words[i + 1])

            result = [all_words[0]]
            for _ in range(length - 1):
                if transitions[result[-1]]:
                    token = random.choice(transitions[result[-1]])
                else:
                    # Dead end: pick any word from the prompt to continue.
                    token = random.choice(all_words)
                yield token + " "
                time.sleep(0.02)  # simulate streaming latency
                result.append(token)
            self._response_json = {
                "generated": " ".join(result),
                "transitions": dict(transitions),
            }

    def execute(self, prompt: Prompt, stream: bool = True) -> Response:
        """Return a Response for *prompt*; local model, so no key handling."""
        return self.Response(prompt, self, stream)

    def __str__(self):
        return "Markov: {}".format(self.model_id)
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
from llm import Model, Prompt, hookimpl
|
||||
from llm import Model, hookimpl
|
||||
import llm
|
||||
from llm.errors import NeedsKeyException
|
||||
from llm.utils import dicts_to_table_string
|
||||
import click
|
||||
import datetime
|
||||
|
|
@ -140,14 +139,6 @@ class Chat(Model):
|
|||
self.model_id = model_id
|
||||
self.key = key
|
||||
|
||||
def execute(self, prompt: Prompt, stream: bool = True) -> Response:
|
||||
key = self.get_key()
|
||||
if key is None:
|
||||
raise NeedsKeyException(
|
||||
"{} needs an API key, label={}".format(str(self), self.needs_key)
|
||||
)
|
||||
return self.Response(prompt, self, stream, key=key)
|
||||
|
||||
def __str__(self):
|
||||
return "OpenAI Chat: {}".format(self.model_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from dataclasses import dataclass, asdict
|
||||
import datetime
|
||||
from .errors import NeedsKeyException
|
||||
import time
|
||||
from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Set
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
@ -141,8 +142,16 @@ class Model(ABC):
|
|||
if self.key is not None:
|
||||
return self.key
|
||||
if self.key_env_var is not None:
|
||||
return os.environ.get(self.key_env_var)
|
||||
return None
|
||||
key = os.environ.get(self.key_env_var)
|
||||
if key:
|
||||
return key
|
||||
|
||||
message = "No key found - add one using 'llm keys set {}'".format(
|
||||
self.needs_key
|
||||
)
|
||||
if self.key_env_var:
|
||||
message += " or set the {} environment variable".format(self.key_env_var)
|
||||
raise NeedsKeyException(message)
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
|
|
@ -158,7 +167,10 @@ class Model(ABC):
|
|||
|
||||
def execute(self, prompt: Prompt, stream: bool = True) -> Response:
|
||||
r = cast(Callable, getattr(self, "Response"))
|
||||
return r(prompt, self, stream)
|
||||
kwargs = {}
|
||||
if self.needs_key:
|
||||
kwargs["key"] = self.get_key()
|
||||
return r(prompt, self, stream, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def __str__(self) -> str:
|
||||
|
|
|
|||
Loading…
Reference in a new issue