diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index 4ec7230..aff8c26 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -61,8 +61,8 @@ def register_commands(cli):
 
 
 class ChatResponse(Response):
-    def __init__(self, prompt, stream, key):
-        super().__init__(prompt)
+    def __init__(self, prompt, model, stream, key):
+        super().__init__(prompt, model)
         self.stream = stream
         self.key = key
 
@@ -111,7 +111,7 @@ class Chat(Model):
             raise NeedsKeyException(
                 "{} needs an API key, label={}".format(str(self), self.needs_key)
             )
-        return ChatResponse(prompt, stream, key=key)
+        return ChatResponse(prompt, self, stream, key=key)
 
     def __str__(self):
         return "OpenAI Chat: {}".format(self.model_id)
diff --git a/llm/default_plugins/vertex_models.py b/llm/default_plugins/vertex_models.py
index db40400..c59130c 100644
--- a/llm/default_plugins/vertex_models.py
+++ b/llm/default_plugins/vertex_models.py
@@ -9,9 +9,9 @@ def register_models(register):
 
 
 class VertexResponse(Response):
-    def __init__(self, prompt, key):
+    def __init__(self, prompt, model, key):
         self.key = key
-        super().__init__(prompt)
+        super().__init__(prompt, model)
 
     def iter_prompt(self):
         url = (
@@ -43,7 +43,7 @@ class Vertex(Model):
             raise NeedsKeyException(
                 "{} needs an API key, label={}".format(str(self), self.needs_key)
             )
-        return VertexResponse(prompt, key=self.key)
+        return VertexResponse(prompt, self, key=self.key)
 
     def __str__(self):
         return "Vertex Chat: {}".format(self.model_id)
diff --git a/llm/models.py b/llm/models.py
index 0c0f9f8..01baae8 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -26,8 +26,9 @@ class OptionsError(Exception):
 
 
 class Response(ABC):
-    def __init__(self, prompt: Prompt):
+    def __init__(self, prompt: Prompt, model: "Model"):
         self.prompt = prompt
+        self.model = model
         self._chunks = []
         self._debug = {}
         self._done = False