Moved things into inner classes, log_message is now defined on base Response

This commit is contained in:
Simon Willison 2023-07-03 21:25:19 -07:00
parent de81cc9a9e
commit 3136948408
2 changed files with 56 additions and 57 deletions

View file

@ -1,4 +1,5 @@
from llm import LogMessage, Model, Prompt, Response, hookimpl
from llm import Model, Prompt, hookimpl
import llm
from llm.errors import NeedsKeyException
from llm.utils import dicts_to_table_string
import click
@ -61,56 +62,6 @@ def register_commands(cli):
print("\n".join(done))
class ChatResponse(Response):
def __init__(self, prompt, model, stream, key):
super().__init__(prompt, model, stream)
self.key = key
self._prompt_json = None
def iter_prompt(self):
messages = []
if self.prompt.system:
messages.append({"role": "system", "content": self.prompt.system})
messages.append({"role": "user", "content": self.prompt.prompt})
openai.api_key = self.key
self._prompt_json = {"messages": messages}
if self.stream:
completion = openai.ChatCompletion.create(
model=self.prompt.model.model_id,
messages=messages,
stream=True,
**not_nulls(self.prompt.options),
)
chunks = []
for chunk in completion:
chunks.append(chunk)
content = chunk["choices"][0].get("delta", {}).get("content")
if content is not None:
yield content
self._response_json = combine_chunks(chunks)
else:
response = openai.ChatCompletion.create(
model=self.prompt.model.model_id,
messages=messages,
stream=False,
)
self._response_json = response.to_dict_recursive()
yield response.choices[0].message.content
def log_message(self) -> LogMessage:
return LogMessage(
model=self.prompt.model.model_id,
prompt=self.prompt.prompt,
system=self.prompt.system,
prompt_json=self._prompt_json,
options_json=not_nulls(self.prompt.options),
response=self.text(),
response_json=self.json(),
reply_to_id=None, # TODO
chat_id=None, # TODO
)
class Chat(Model):
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
@ -150,17 +101,52 @@ class Chat(Model):
return validated_logit_bias
class Response(llm.Response):
def __init__(self, prompt, model, stream, key):
super().__init__(prompt, model, stream)
self.key = key
def iter_prompt(self):
messages = []
if self.prompt.system:
messages.append({"role": "system", "content": self.prompt.system})
messages.append({"role": "user", "content": self.prompt.prompt})
openai.api_key = self.key
self._prompt_json = {"messages": messages}
if self.stream:
completion = openai.ChatCompletion.create(
model=self.prompt.model.model_id,
messages=messages,
stream=True,
**not_nulls(self.prompt.options),
)
chunks = []
for chunk in completion:
chunks.append(chunk)
content = chunk["choices"][0].get("delta", {}).get("content")
if content is not None:
yield content
self._response_json = combine_chunks(chunks)
else:
response = openai.ChatCompletion.create(
model=self.prompt.model.model_id,
messages=messages,
stream=False,
)
self._response_json = response.to_dict_recursive()
yield response.choices[0].message.content
def __init__(self, model_id, key=None):
self.model_id = model_id
self.key = key
def execute(self, prompt: Prompt, stream: bool = True) -> ChatResponse:
def execute(self, prompt: Prompt, stream: bool = True) -> Response:
key = self.get_key()
if key is None:
raise NeedsKeyException(
"{} needs an API key, label={}".format(str(self), self.needs_key)
)
return ChatResponse(prompt, self, stream, key=key)
return self.Response(prompt, self, stream, key=key)
def __str__(self):
return "OpenAI Chat: {}".format(self.model_id)

View file

@ -13,7 +13,7 @@ class Prompt:
model: "Model"
system: Optional[str]
prompt_json: Optional[str]
options: Dict[str, Any]
options: "Model.Options"
def __init__(self, prompt, model, system=None, prompt_json=None, options=None):
self.prompt = prompt
@ -45,6 +45,7 @@ class LogMessage:
class Response(ABC):
def __init__(self, prompt: Prompt, model: "Model", stream: bool):
self.prompt = prompt
self._prompt_json = None
self.model = model
self.stream = stream
self._chunks: List[str] = []
@ -99,10 +100,22 @@ class Response(ABC):
self._force()
return self._start_utcnow.isoformat()
@abstractmethod
def log_message(self) -> LogMessage:
"Return a LogMessage of data to log"
pass
return LogMessage(
model=self.prompt.model.model_id,
prompt=self.prompt.prompt,
system=self.prompt.system,
prompt_json=self._prompt_json,
options_json={
key: value
for key, value in self.prompt.options.model_dump().items()
if value is not None
},
response=self.text(),
response_json=self.json(),
reply_to_id=None, # TODO
chat_id=None, # TODO
)
def log_to_db(self, db):
message = self.log_message()