Implemented PaLM 2, to test out new plugin hook - refs #20

This commit is contained in:
Simon Willison 2023-06-26 08:25:01 -07:00
parent 5103f77c40
commit 5e056fad8a
5 changed files with 61 additions and 5 deletions

View file

@ -192,11 +192,17 @@ def prompt(
if model.needs_key and not model.key:
model.key = get_key(key, model.needs_key, model.key_env_var)
prompt_kwargs = {}
if model.can_stream:
prompt_kwargs = {"stream": not no_stream}
else:
no_stream = False
if no_stream:
chunk = list(model.prompt(prompt, system, stream=False))[0]
chunk = list(model.prompt(prompt, system, **prompt_kwargs))[0]
print(chunk)
else:
for chunk in model.prompt(prompt, system):
for chunk in model.prompt(prompt, system, **prompt_kwargs):
print(chunk, end="")
sys.stdout.flush()
print("")

View file

@ -56,6 +56,7 @@ class Model(ABC):
model_id: str
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
can_stream: bool = False
class Options(BaseModel):
class Config:

View file

@ -14,10 +14,9 @@ def register_models(register):
class ChatResponse(Response):
def __init__(self, prompt, stream, key):
self.prompt = prompt
super().__init__(prompt)
self.stream = stream
self.key = key
super().__init__(prompt)
def iter_prompt(self):
messages = []
@ -52,6 +51,7 @@ class ChatResponse(Response):
class Chat(Model):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
can_stream: bool = True
def __init__(self, model_id, key=None, stream=True):
self.model_id = model_id

View file

@ -5,7 +5,7 @@ from typing import Dict, List
from . import hookspecs
from .models import ModelWithAliases, Model
DEFAULT_PLUGINS = ("llm.openai_models",)
DEFAULT_PLUGINS = ("llm.openai_models", "llm.vertex_models")
pm = pluggy.PluginManager("llm")
pm.add_hookspecs(hookspecs)

49
llm/vertex_models.py Normal file
View file

@ -0,0 +1,49 @@
from . import Model, Prompt, Response, hookimpl
from .errors import NeedsKeyException
import requests
@hookimpl
def register_models(register):
    """Plugin hook: expose the PaLM 2 text model under the 'palm2' alias."""
    model = Vertex("text-bison-001")
    register(model, aliases=("palm2",))
class VertexResponse(Response):
    """Response for the PaLM 2 ``generateText`` REST endpoint.

    The endpoint has no streaming mode, so the whole completion arrives in
    one HTTP round trip and is yielded as a single chunk.
    """

    def __init__(self, prompt, key):
        # Keep the API key on the instance before the base class init runs.
        self.key = key
        super().__init__(prompt)

    def iter_prompt(self):
        """Yield the single completion text returned by the API."""
        model_id = self.prompt.model.model_id
        endpoint = (
            f"https://generativelanguage.googleapis.com/v1beta2/models/{model_id}:generateText"
            f"?key={self.key}"
        )
        payload = {"prompt": {"text": self.prompt.prompt}}
        api_response = requests.post(
            endpoint,
            headers={"Content-Type": "application/json"},
            json=payload,
        )
        candidate = api_response.json()["candidates"][0]
        # Preserve the safety metadata for debugging, then mark completion.
        self._debug = {"safetyRatings": candidate["safetyRatings"]}
        self._done = True
        yield candidate["output"]
class Vertex(Model):
    """Model plugin for Google's PaLM 2 text generation API.

    The API cannot stream, so ``execute`` always returns a complete,
    non-streaming :class:`VertexResponse`.
    """

    # Label under which the API key is stored/looked up.
    needs_key = "vertex"

    def __init__(self, model_id, key=None):
        self.model_id = model_id
        self.key = key

    def execute(self, prompt: Prompt, stream: bool) -> VertexResponse:
        # The underlying endpoint has no streaming mode; `stream` is ignored.
        if self.key is not None:
            return VertexResponse(prompt, key=self.key)
        raise NeedsKeyException(
            "{} needs an API key, label={}".format(str(self), self.needs_key)
        )

    def __str__(self):
        return "Vertex Chat: {}".format(self.model_id)