Mirror of https://github.com/Hopiu/llm.git (synced 2026-05-05 04:14:53 +00:00)

Implemented PaLM 2, to test out new plugin hook - refs #20

Parent: 5103f77c40
Commit: 5e056fad8a

5 changed files with 61 additions and 5 deletions
llm/cli.py (10 changes)
@@ -192,11 +192,17 @@ def prompt(
     if model.needs_key and not model.key:
         model.key = get_key(key, model.needs_key, model.key_env_var)

+    prompt_kwargs = {}
+    if model.can_stream:
+        prompt_kwargs = {"stream": not no_stream}
+    else:
+        no_stream = False
+
     if no_stream:
-        chunk = list(model.prompt(prompt, system, stream=False))[0]
+        chunk = list(model.prompt(prompt, system, **prompt_kwargs))[0]
         print(chunk)
     else:
-        for chunk in model.prompt(prompt, system):
+        for chunk in model.prompt(prompt, system, **prompt_kwargs):
             print(chunk, end="")
             sys.stdout.flush()
         print("")
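A note on the control flow above: models that do not declare can_stream receive no stream keyword at all, and no_stream is forced to False so their output still goes through the chunk-iteration path (they simply yield the whole response as one chunk). A minimal standalone sketch of that gate; FakeModel is illustrative, not part of the codebase:

    # Hedged sketch of the streaming gate in the cli.py hunk above.
    class FakeModel:
        can_stream = False  # models must opt in to streaming explicitly

        def prompt(self, prompt, system=None, **kwargs):
            yield f"echo: {prompt}"  # non-streaming: one chunk with everything


    model = FakeModel()
    no_stream = True  # pretend the user passed --no-stream

    prompt_kwargs = {}
    if model.can_stream:
        prompt_kwargs = {"stream": not no_stream}
    else:
        no_stream = False  # non-streaming models always take the loop path

    if no_stream:
        chunk = list(model.prompt("hi", None, **prompt_kwargs))[0]
        print(chunk)
    else:
        for chunk in model.prompt("hi", None, **prompt_kwargs):
            print(chunk, end="")
        print("")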

llm/models.py
@@ -56,6 +56,7 @@ class Model(ABC):
     model_id: str
     needs_key: Optional[str] = None
     key_env_var: Optional[str] = None
+    can_stream: bool = False

     class Options(BaseModel):
         class Config:
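Worth noting (not part of the diff): because can_stream is a class attribute with a False default, existing Model subclasses that predate this commit inherit it unchanged. A hedged two-line check, assuming Model is importable from the llm package as the imports in llm/vertex_models.py below suggest:

    from llm import Model  # the package re-exports Model, per the vertex module's imports

    class LegacyModel(Model):
        model_id = "legacy-example"  # hypothetical pre-existing plugin model

    assert LegacyModel.can_stream is False  # inherited default, no code changes needed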

llm/openai_models.py
@@ -14,10 +14,9 @@ def register_models(register):

 class ChatResponse(Response):
     def __init__(self, prompt, stream, key):
-        self.prompt = prompt
+        super().__init__(prompt)
         self.stream = stream
         self.key = key
-        super().__init__(prompt)

     def iter_prompt(self):
         messages = []
@@ -52,6 +51,7 @@ class ChatResponse(Response):
 class Chat(Model):
     needs_key = "openai"
     key_env_var = "OPENAI_API_KEY"
+    can_stream: bool = True

     def __init__(self, model_id, key=None, stream=True):
         self.model_id = model_id
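The reorder in ChatResponse.__init__ deserves a gloss: the redundant self.prompt = prompt line is dropped because the base class already stores the prompt, and super().__init__(prompt) now runs first, so base-class state exists before the subclass attributes are set and cannot clobber them afterwards. A sketch with a hypothetical base initializer (the real Response.__init__ may set more state):

    class Response:
        def __init__(self, prompt):
            self.prompt = prompt
            self._done = False  # base state the subclass relies on


    class ChatResponse(Response):
        def __init__(self, prompt, stream, key):
            super().__init__(prompt)  # base init first: safe ordering
            self.stream = stream      # subclass attributes set afterwards
            self.key = key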

llm/plugins.py
@@ -5,7 +5,7 @@ from typing import Dict, List
 from . import hookspecs
 from .models import ModelWithAliases, Model

-DEFAULT_PLUGINS = ("llm.openai_models",)
+DEFAULT_PLUGINS = ("llm.openai_models", "llm.vertex_models")

 pm = pluggy.PluginManager("llm")
 pm.add_hookspecs(hookspecs)
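For context on how the new entry takes effect (a sketch, not the exact loading code in llm/plugins.py): each module named in DEFAULT_PLUGINS is imported and registered with the pluggy PluginManager, which is what gets the @hookimpl-decorated register_models in llm/vertex_models.py called:

    # Hedged sketch of default-plugin loading; llm's actual wiring may differ.
    import importlib
    import pluggy

    pm = pluggy.PluginManager("llm")
    DEFAULT_PLUGINS = ("llm.openai_models", "llm.vertex_models")

    for name in DEFAULT_PLUGINS:
        module = importlib.import_module(name)
        pm.register(module, name)  # picks up @hookimpl functions in the module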

llm/vertex_models.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+from . import Model, Prompt, Response, hookimpl
+from .errors import NeedsKeyException
+import requests
+
+
+@hookimpl
+def register_models(register):
+    register(Vertex("text-bison-001"), aliases=("palm2",))
+
+
+class VertexResponse(Response):
+    def __init__(self, prompt, key):
+        self.key = key
+        super().__init__(prompt)
+
+    def iter_prompt(self):
+        url = (
+            f"https://generativelanguage.googleapis.com/v1beta2/models/{self.prompt.model.model_id}:generateText"
+            f"?key={self.key}"
+        )
+        response = requests.post(
+            url,
+            headers={"Content-Type": "application/json"},
+            json={"prompt": {"text": self.prompt.prompt}},
+        )
+        data = response.json()
+        candidate = data["candidates"][0]
+        self._debug = {"safetyRatings": candidate["safetyRatings"]}
+        self._done = True
+        yield candidate["output"]
+
+
+class Vertex(Model):
+    needs_key = "vertex"
+
+    def __init__(self, model_id, key=None):
+        self.model_id = model_id
+        self.key = key
+
+    def execute(self, prompt: Prompt, stream: bool) -> VertexResponse:
+        # ignore stream, since we cannot stream
+        if self.key is None:
+            raise NeedsKeyException(
+                "{} needs an API key, label={}".format(str(self), self.needs_key)
+            )
+        return VertexResponse(prompt, key=self.key)
+
+    def __str__(self):
+        return "Vertex Chat: {}".format(self.model_id)
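For illustration (not part of the commit): the HTTP request VertexResponse.iter_prompt builds can be reproduced standalone as below. PALM_API_KEY is a placeholder environment variable for this sketch, not one llm itself reads; the endpoint and payload come straight from iter_prompt above.

    # Standalone reproduction of the generateText call made above.
    import os
    import requests

    key = os.environ["PALM_API_KEY"]  # placeholder; supply your own key
    url = (
        "https://generativelanguage.googleapis.com/v1beta2/models/"
        f"text-bison-001:generateText?key={key}"
    )
    response = requests.post(
        url,
        headers={"Content-Type": "application/json"},
        json={"prompt": {"text": "Three names for a pet pelican"}},
    )
    data = response.json()
    print(data["candidates"][0]["output"])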