diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 956591d..a1d0076 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -1,4 +1,5 @@ -from llm import LogMessage, Model, Prompt, Response, hookimpl +from llm import Model, Prompt, hookimpl +import llm from llm.errors import NeedsKeyException from llm.utils import dicts_to_table_string import click @@ -61,56 +62,6 @@ def register_commands(cli): print("\n".join(done)) -class ChatResponse(Response): - def __init__(self, prompt, model, stream, key): - super().__init__(prompt, model, stream) - self.key = key - self._prompt_json = None - - def iter_prompt(self): - messages = [] - if self.prompt.system: - messages.append({"role": "system", "content": self.prompt.system}) - messages.append({"role": "user", "content": self.prompt.prompt}) - openai.api_key = self.key - self._prompt_json = {"messages": messages} - if self.stream: - completion = openai.ChatCompletion.create( - model=self.prompt.model.model_id, - messages=messages, - stream=True, - **not_nulls(self.prompt.options), - ) - chunks = [] - for chunk in completion: - chunks.append(chunk) - content = chunk["choices"][0].get("delta", {}).get("content") - if content is not None: - yield content - self._response_json = combine_chunks(chunks) - else: - response = openai.ChatCompletion.create( - model=self.prompt.model.model_id, - messages=messages, - stream=False, - ) - self._response_json = response.to_dict_recursive() - yield response.choices[0].message.content - - def log_message(self) -> LogMessage: - return LogMessage( - model=self.prompt.model.model_id, - prompt=self.prompt.prompt, - system=self.prompt.system, - prompt_json=self._prompt_json, - options_json=not_nulls(self.prompt.options), - response=self.text(), - response_json=self.json(), - reply_to_id=None, # TODO - chat_id=None, # TODO - ) - - class Chat(Model): needs_key = "openai" key_env_var = "OPENAI_API_KEY" @@ -150,17 +101,52 @@ class Chat(Model): return validated_logit_bias + class Response(llm.Response): + def __init__(self, prompt, model, stream, key): + super().__init__(prompt, model, stream) + self.key = key + + def iter_prompt(self): + messages = [] + if self.prompt.system: + messages.append({"role": "system", "content": self.prompt.system}) + messages.append({"role": "user", "content": self.prompt.prompt}) + openai.api_key = self.key + self._prompt_json = {"messages": messages} + if self.stream: + completion = openai.ChatCompletion.create( + model=self.prompt.model.model_id, + messages=messages, + stream=True, + **not_nulls(self.prompt.options), + ) + chunks = [] + for chunk in completion: + chunks.append(chunk) + content = chunk["choices"][0].get("delta", {}).get("content") + if content is not None: + yield content + self._response_json = combine_chunks(chunks) + else: + response = openai.ChatCompletion.create( + model=self.prompt.model.model_id, + messages=messages, + stream=False, + ) + self._response_json = response.to_dict_recursive() + yield response.choices[0].message.content + def __init__(self, model_id, key=None): self.model_id = model_id self.key = key - def execute(self, prompt: Prompt, stream: bool = True) -> ChatResponse: + def execute(self, prompt: Prompt, stream: bool = True) -> Response: key = self.get_key() if key is None: raise NeedsKeyException( "{} needs an API key, label={}".format(str(self), self.needs_key) ) - return ChatResponse(prompt, self, stream, key=key) + return self.Response(prompt, self, stream, key=key) def __str__(self): return "OpenAI Chat: {}".format(self.model_id) diff --git a/llm/models.py b/llm/models.py index 46c7114..00c1549 100644 --- a/llm/models.py +++ b/llm/models.py @@ -13,7 +13,7 @@ class Prompt: model: "Model" system: Optional[str] prompt_json: Optional[str] - options: Dict[str, Any] + options: "Model.Options" def __init__(self, prompt, model, system=None, prompt_json=None, options=None): self.prompt = prompt @@ -45,6 +45,7 @@ class LogMessage: class Response(ABC): def __init__(self, prompt: Prompt, model: "Model", stream: bool): self.prompt = prompt + self._prompt_json = None self.model = model self.stream = stream self._chunks: List[str] = [] @@ -99,10 +100,22 @@ class Response(ABC): self._force() return self._start_utcnow.isoformat() - @abstractmethod def log_message(self) -> LogMessage: - "Return a LogMessage of data to log" - pass + return LogMessage( + model=self.prompt.model.model_id, + prompt=self.prompt.prompt, + system=self.prompt.system, + prompt_json=self._prompt_json, + options_json={ + key: value + for key, value in self.prompt.options.model_dump().items() + if value is not None + }, + response=self.text(), + response_json=self.json(), + reply_to_id=None, # TODO + chat_id=None, # TODO + ) def log_to_db(self, db): message = self.log_message()