Model.stream() and .get_key() methods

This commit is contained in:
Simon Willison 2023-07-01 09:03:07 -07:00
parent 2911975548
commit ffe4b6706d
2 changed files with 23 additions and 6 deletions

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Any, Dict, Generator, Optional, Set
from abc import ABC, abstractmethod
import os
from pydantic import BaseModel
@ -54,6 +55,7 @@ class Response(ABC):
class Model(ABC):
model_id: str
key: Optional[str] = None
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
can_stream: bool = False
@ -62,10 +64,25 @@ class Model(ABC):
class Config:
extra = "forbid"
def prompt(self, prompt, system=None, stream=True, **options):
def get_key(self):
if self.needs_key is None:
return None
if self.key is not None:
return self.key
if self.key_env_var is not None:
return os.environ.get(self.key_env_var)
return None
def prompt(self, prompt, system=None, **options):
return self.execute(
Prompt(prompt, system=system, model=self, options=self.Options(**options)),
stream=stream,
stream=False,
)
def stream(self, prompt, system=None, **options):
return self.execute(
Prompt(prompt, system=system, model=self, options=self.Options(**options)),
stream=True,
)
@abstractmethod

View file

@ -53,17 +53,17 @@ class Chat(Model):
key_env_var = "OPENAI_API_KEY"
can_stream: bool = True
def __init__(self, model_id, key=None, stream=True):
def __init__(self, model_id, key=None):
self.model_id = model_id
self.stream = stream
self.key = key
def execute(self, prompt: Prompt, stream: bool = True) -> ChatResponse:
if self.key is None:
key = self.get_key()
if key is None:
raise NeedsKeyException(
"{} needs an API key, label={}".format(str(self), self.needs_key)
)
return ChatResponse(prompt, stream, key=self.key)
return ChatResponse(prompt, stream, key=key)
def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)