Track usage on OpenAI stream requests, closes #591

This commit is contained in:
Simon Willison 2024-10-28 17:40:40 -07:00
parent ba1ccb3a4a
commit 389acdf52c

View file

@@ -346,7 +346,7 @@ class Chat(Model):
)
messages.append({"role": "user", "content": attachment_message})
kwargs = self.build_kwargs(prompt)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client()
if stream:
completion = client.chat.completions.create(
@@ -358,7 +358,10 @@ class Chat(Model):
chunks = []
for chunk in completion:
chunks.append(chunk)
content = chunk.choices[0].delta.content
try:
content = chunk.choices[0].delta.content
except IndexError:
content = None
if content is not None:
yield content
response.response_json = remove_dict_none_values(combine_chunks(chunks))
@@ -395,13 +398,15 @@ class Chat(Model):
kwargs["http_client"] = logging_client()
return openai.OpenAI(**kwargs)
def build_kwargs(self, prompt, stream=False):
    """Build the keyword arguments for an OpenAI ``create()`` call.

    Args:
        prompt: Prompt object; its ``options`` are passed through
            ``not_nulls`` and used as the starting set of arguments.
        stream: True when the response is going to be streamed.
            Defaults to False so that existing callers which still
            invoke ``build_kwargs(prompt)`` keep working.

    Returns:
        A dict of keyword arguments for the API call.
    """
    kwargs = dict(not_nulls(prompt.options))
    # "json_object" is removed from the arguments here and re-expressed
    # below as the "response_format" argument.
    json_object = kwargs.pop("json_object", None)
    # Only fall back to the model default when the prompt did not set
    # max_tokens itself.
    if "max_tokens" not in kwargs and self.default_max_tokens is not None:
        kwargs["max_tokens"] = self.default_max_tokens
    if json_object:
        kwargs["response_format"] = {"type": "json_object"}
    if stream:
        # Ask for token usage on streamed responses so it can be
        # tracked (see the OpenAI "stream_options" documentation).
        kwargs["stream_options"] = {"include_usage": True}
    return kwargs
@@ -431,7 +436,7 @@ class Completion(Chat):
messages.append(prev_response.prompt.prompt)
messages.append(prev_response.text())
messages.append(prompt.prompt)
kwargs = self.build_kwargs(prompt)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client()
if stream:
completion = client.completions.create(
@@ -443,7 +448,10 @@ class Completion(Chat):
chunks = []
for chunk in completion:
chunks.append(chunk)
content = chunk.choices[0].text
try:
content = chunk.choices[0].text
except IndexError:
content = None
if content is not None:
yield content
combined = combine_chunks(chunks)
@@ -472,8 +480,11 @@ def combine_chunks(chunks: List) -> dict:
# If any of them have log probability, we're going to persist
# those later on
logprobs = []
usage = {}
for item in chunks:
if item.usage:
usage = dict(item.usage)
for choice in item.choices:
if choice.logprobs and hasattr(choice.logprobs, "top_logprobs"):
logprobs.append(
@@ -497,6 +508,7 @@ def combine_chunks(chunks: List) -> dict:
"content": content,
"role": role,
"finish_reason": finish_reason,
"usage": usage,
}
if logprobs:
combined["logprobs"] = logprobs