Pass model to the Response

This commit is contained in:
Simon Willison 2023-07-01 11:29:41 -07:00
parent 6ef52172b0
commit 9afc758cd7
3 changed files with 8 additions and 7 deletions

View file

@ -61,8 +61,8 @@ def register_commands(cli):
class ChatResponse(Response):
def __init__(self, prompt, stream, key):
super().__init__(prompt)
def __init__(self, prompt, model, stream, key):
super().__init__(prompt, model)
self.stream = stream
self.key = key
@ -111,7 +111,7 @@ class Chat(Model):
raise NeedsKeyException(
"{} needs an API key, label={}".format(str(self), self.needs_key)
)
return ChatResponse(prompt, stream, key=key)
return ChatResponse(prompt, self, stream, key=key)
def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)

View file

@ -9,9 +9,9 @@ def register_models(register):
class VertexResponse(Response):
def __init__(self, prompt, key):
def __init__(self, prompt, model, key):
self.key = key
super().__init__(prompt)
super().__init__(prompt, model)
def iter_prompt(self):
url = (
@ -43,7 +43,7 @@ class Vertex(Model):
raise NeedsKeyException(
"{} needs an API key, label={}".format(str(self), self.needs_key)
)
return VertexResponse(prompt, key=self.key)
return VertexResponse(prompt, self, key=self.key)
def __str__(self):
return "Vertex Chat: {}".format(self.model_id)

View file

@ -26,8 +26,9 @@ class OptionsError(Exception):
class Response(ABC):
def __init__(self, prompt: Prompt):
def __init__(self, prompt: Prompt, model: "Model"):
self.prompt = prompt
self.model = model
self._chunks = []
self._debug = {}
self._done = False