Log input tokens, output tokens and token details (#642)

* Store input_tokens, output_tokens, token_details on Response, closes #610
* llm prompt -u/--usage option
* llm logs -u/--usage option
* Docs on tracking token usage in plugins
* OpenAI default plugin logs usage
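As a quick illustration (a hypothetical session: the model name, prompt and printed values here are invented, not from this commit), the usage data can be read back through the Python API:

```python
import llm

model = llm.get_model("gpt-4o-mini")
response = model.prompt("Say hello")
print(response.text())
# New in this commit: counts recorded via response.set_usage()
print(response.token_usage())  # e.g. "10 input, 2 output"
print(response.input_tokens, response.output_tokens, response.token_details)
```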
Simon Willison 2024-11-19 20:21:59 -08:00 committed by GitHub
parent 4a059d722b
commit cfb10f4afd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 224 additions and 13 deletions

View file

@@ -122,6 +122,7 @@ Options:
  --key TEXT                 API key to use
  --save TEXT                Save prompt with this template name
  --async                    Run prompt asynchronously
  -u, --usage                Show token usage
  --help                     Show this message and exit.
```
@@ -292,6 +293,7 @@ Options:
  -m, --model TEXT            Filter by model or model alias
  -q, --query TEXT            Search for logs matching this string
  -t, --truncate              Truncate long strings in output
  -u, --usage                 Include token usage
  -r, --response              Just output the last response
  -c, --current               Show logs from the current conversation
  --cid, --conversation TEXT  Show logs for this conversation ID

View file

@@ -159,7 +159,10 @@ CREATE TABLE [responses] (
   [response_json] TEXT,
   [conversation_id] TEXT REFERENCES [conversations]([id]),
   [duration_ms] INTEGER,
   [datetime_utc] TEXT
   [datetime_utc] TEXT,
   [input_tokens] INTEGER,
   [output_tokens] INTEGER,
   [token_details] TEXT
);
CREATE VIRTUAL TABLE [responses_fts] USING FTS5 (
    [prompt],

View file

@@ -167,3 +167,19 @@ for prev_response in conversation.responses:
The `response.text_or_raise()` method used there will return the text from the response, or raise a `ValueError` exception if the response is an `AsyncResponse` instance that has not yet been fully resolved.

This is a slightly weird hack to work around the common need to share logic for building up the `messages` list across both sync and async models.

(advanced-model-plugins-usage)=
## Tracking token usage

Models that charge by the token should track the number of tokens used by each prompt. The `response.set_usage()` method can be used to record the number of tokens used by a response; these will then be made available through the Python API and logged to the SQLite database for command-line users.

`response` here is the response object that is passed to `.execute()` as an argument.

Call `response.set_usage()` at the end of your `.execute()` method. It accepts keyword arguments `input=`, `output=` and `details=`; all three are optional. `input` and `output` should be integers, and `details` should be a dictionary that provides additional information beyond the input and output token counts.

This example logs 15 input tokens, 340 output tokens, and notes that 37 tokens were cached:
```python
response.set_usage(input=15, output=340, details={"cached": 37})
```
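For orientation, here is a minimal sketch of where that call sits in a plugin's `.execute()` method (`call_my_api()` and the model ID are hypothetical stand-ins, not part of llm):

```python
import llm


def call_my_api(prompt_text):
    # Stand-in for a real model API; returns text plus token counts
    return {
        "text": f"Echo: {prompt_text}",
        "usage": {"input_tokens": len(prompt_text.split()), "output_tokens": 2},
    }


class MyModel(llm.Model):
    model_id = "my-model"  # hypothetical

    def execute(self, prompt, stream, response, conversation):
        api_response = call_my_api(prompt.prompt)
        yield api_response["text"]
        # Record usage at the end of execute(), as described above
        response.set_usage(
            input=api_response["usage"]["input_tokens"],
            output=api_response["usage"]["output_tokens"],
        )
```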

View file

@@ -33,7 +33,7 @@ from llm import (
from .migrations import migrate
from .plugins import pm, load_plugins
from .utils import mimetype_from_path, mimetype_from_string
from .utils import mimetype_from_path, mimetype_from_string, token_usage_string
import base64
import httpx
import pathlib
@@ -203,6 +203,7 @@ def cli():
@click.option("--key", help="API key to use")
@click.option("--save", help="Save prompt with this template name")
@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously")
@click.option("-u", "--usage", is_flag=True, help="Show token usage")
def prompt(
    prompt,
    system,
@@ -220,6 +221,7 @@ def prompt(
    key,
    save,
    async_,
    usage,
):
    """
    Execute a prompt
@@ -426,14 +428,24 @@ def prompt(
    except Exception as ex:
        raise click.ClickException(str(ex))
    if isinstance(response, AsyncResponse):
        response = asyncio.run(response.to_sync_response())
    if usage:
        # Show token usage to stderr in yellow
        click.echo(
            click.style(
                "Token usage: {}".format(response.token_usage()), fg="yellow", bold=True
            ),
            err=True,
        )
    # Log to the database
    if (logs_on() or log) and not no_log:
        log_path = logs_db_path()
        (log_path.parent).mkdir(parents=True, exist_ok=True)
        db = sqlite_utils.Database(log_path)
        migrate(db)
        if isinstance(response, AsyncResponse):
            response = asyncio.run(response.to_sync_response())
        response.log_to_db(db)
@@ -754,6 +766,9 @@ LOGS_COLUMNS = """    responses.id,
    responses.conversation_id,
    responses.duration_ms,
    responses.datetime_utc,
    responses.input_tokens,
    responses.output_tokens,
    responses.token_details,
    conversations.name as conversation_name,
    conversations.model as conversation_model"""
@@ -809,6 +824,7 @@ order by prompt_attachments."order"
@click.option("-m", "--model", help="Filter by model or model alias")
@click.option("-q", "--query", help="Search for logs matching this string")
@click.option("-t", "--truncate", is_flag=True, help="Truncate long strings in output")
@click.option("-u", "--usage", is_flag=True, help="Include token usage")
@click.option("-r", "--response", is_flag=True, help="Just output the last response")
@click.option(
    "current_conversation",
@@ -836,6 +852,7 @@ def logs_list(
    model,
    query,
    truncate,
    usage,
    response,
    current_conversation,
    conversation_id,
@@ -998,6 +1015,14 @@ def logs_list(
                )
            click.echo("\n## Response:\n\n{}\n".format(row["response"]))
            if usage:
                token_usage = token_usage_string(
                    row["input_tokens"],
                    row["output_tokens"],
                    json.loads(row["token_details"]) if row["token_details"] else None,
                )
                if token_usage:
                    click.echo("## Token usage:\n\n{}\n".format(token_usage))


@cli.group(

View file

@@ -1,6 +1,11 @@
from llm import AsyncModel, EmbeddingModel, Model, hookimpl
import llm
from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
from llm.utils import (
    dicts_to_table_string,
    remove_dict_none_values,
    logging_client,
    simplify_usage_dict,
)
import click
import datetime
import httpx
@@ -391,6 +396,16 @@ class _Shared:
            messages.append({"role": "user", "content": attachment_message})
        return messages

    def set_usage(self, response, usage):
        if not usage:
            return
        input_tokens = usage.pop("prompt_tokens")
        output_tokens = usage.pop("completion_tokens")
        usage.pop("total_tokens")
        response.set_usage(
            input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage)
        )

    def get_client(self, async_=False):
        kwargs = {}
        if self.api_base:
@@ -445,6 +460,7 @@ class Chat(_Shared, Model):
        messages = self.build_messages(prompt, conversation)
        kwargs = self.build_kwargs(prompt, stream)
        client = self.get_client()
        usage = None
        if stream:
            completion = client.chat.completions.create(
                model=self.model_name or self.model_id,
@@ -455,6 +471,8 @@ class Chat(_Shared, Model):
            chunks = []
            for chunk in completion:
                chunks.append(chunk)
                if chunk.usage:
                    usage = chunk.usage.model_dump()
                try:
                    content = chunk.choices[0].delta.content
                except IndexError:
@@ -469,8 +487,10 @@ class Chat(_Shared, Model):
                stream=False,
                **kwargs,
            )
            usage = completion.usage.model_dump()
            response.response_json = remove_dict_none_values(completion.model_dump())
            yield completion.choices[0].message.content
        self.set_usage(response, usage)
        response._prompt_json = redact_data({"messages": messages})
@@ -493,6 +513,7 @@ class AsyncChat(_Shared, AsyncModel):
        messages = self.build_messages(prompt, conversation)
        kwargs = self.build_kwargs(prompt, stream)
        client = self.get_client(async_=True)
        usage = None
        if stream:
            completion = await client.chat.completions.create(
                model=self.model_name or self.model_id,
@@ -502,6 +523,8 @@ class AsyncChat(_Shared, AsyncModel):
            )
            chunks = []
            async for chunk in completion:
                if chunk.usage:
                    usage = chunk.usage.model_dump()
                chunks.append(chunk)
                try:
                    content = chunk.choices[0].delta.content
@@ -518,7 +541,9 @@ class AsyncChat(_Shared, AsyncModel):
                **kwargs,
            )
            response.response_json = remove_dict_none_values(completion.model_dump())
            usage = completion.usage.model_dump()
            yield completion.choices[0].message.content
        self.set_usage(response, usage)
        response._prompt_json = redact_data({"messages": messages})

View file

@@ -227,3 +227,10 @@ def m012_attachments_tables(db):
        ),
        pk=("response_id", "attachment_id"),
    )


@migration
def m013_usage(db):
    db["responses"].add_column("input_tokens", int)
    db["responses"].add_column("output_tokens", int)
    db["responses"].add_column("token_details", str)

View file

@@ -18,7 +18,7 @@ from typing import (
    Set,
    Union,
)
from .utils import mimetype_from_path, mimetype_from_string
from .utils import mimetype_from_path, mimetype_from_string, token_usage_string
from abc import ABC, abstractmethod
import json
from pydantic import BaseModel
@@ -208,6 +208,20 @@ class _BaseResponse:
        self._start: Optional[float] = None
        self._end: Optional[float] = None
        self._start_utcnow: Optional[datetime.datetime] = None
        self.input_tokens: Optional[int] = None
        self.output_tokens: Optional[int] = None
        self.token_details: Optional[dict] = None

    def set_usage(
        self,
        *,
        input: Optional[int] = None,
        output: Optional[int] = None,
        details: Optional[dict] = None,
    ):
        self.input_tokens = input
        self.output_tokens = output
        self.token_details = details

    @classmethod
    def from_row(cls, db, row):
@@ -246,6 +260,11 @@ class _BaseResponse:
        ]
        return response

    def token_usage(self) -> str:
        return token_usage_string(
            self.input_tokens, self.output_tokens, self.token_details
        )

    def log_to_db(self, db):
        conversation = self.conversation
        if not conversation:
@@ -272,11 +291,16 @@ class _BaseResponse:
                for key, value in dict(self.prompt.options).items()
                if value is not None
            },
            "response": self.text(),
            "response": self.text_or_raise(),
            "response_json": self.json(),
            "conversation_id": conversation.id,
            "duration_ms": self.duration_ms(),
            "datetime_utc": self.datetime_utc(),
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "token_details": (
                json.dumps(self.token_details) if self.token_details else None
            ),
        }
        db["responses"].insert(response)
        # Persist any attachments - loop through with index
@@ -439,6 +463,9 @@ class AsyncResponse(_BaseResponse):
        response._end = self._end
        response._start = self._start
        response._start_utcnow = self._start_utcnow
        response.input_tokens = self.input_tokens
        response.output_tokens = self.output_tokens
        response.token_details = self.token_details
        return response

    @classmethod

View file

@@ -127,3 +127,29 @@ def logging_client() -> httpx.Client:
        transport=_LogTransport(httpx.HTTPTransport()),
        event_hooks={"request": [_no_accept_encoding], "response": [_log_response]},
    )


def simplify_usage_dict(d):
    # Recursively remove keys with value 0 and empty dictionaries
    def remove_empty_and_zero(obj):
        if isinstance(obj, dict):
            cleaned = {
                k: remove_empty_and_zero(v)
                for k, v in obj.items()
                if v != 0 and v != {}
            }
            return {k: v for k, v in cleaned.items() if v is not None and v != {}}
        return obj

    return remove_empty_and_zero(d) or {}


def token_usage_string(input_tokens, output_tokens, token_details) -> str:
    bits = []
    if input_tokens is not None:
        bits.append(f"{format(input_tokens, ',')} input")
    if output_tokens is not None:
        bits.append(f"{format(output_tokens, ',')} output")
    if token_details:
        bits.append(json.dumps(token_details))
    return ", ".join(bits)

View file

@@ -66,13 +66,17 @@ class MockModel(llm.Model):
    def execute(self, prompt, stream, response, conversation):
        self.history.append((prompt, stream, response, conversation))
        gathered = []
        while True:
            try:
                messages = self._queue.pop(0)
                yield from messages
                for message in messages:
                    gathered.append(message)
                    yield message
                break
            except IndexError:
                break
        response.set_usage(input=len(prompt.prompt.split()), output=len(gathered))
class AsyncMockModel(llm.AsyncModel):

View file

@@ -62,6 +62,9 @@ def test_chat_basic(mock_model, logs_db):
            "conversation_id": conversation_id,
            "duration_ms": ANY,
            "datetime_utc": ANY,
            "input_tokens": 1,
            "output_tokens": 1,
            "token_details": None,
        },
        {
            "id": ANY,
@@ -75,6 +78,9 @@ def test_chat_basic(mock_model, logs_db):
            "conversation_id": conversation_id,
            "duration_ms": ANY,
            "datetime_utc": ANY,
            "input_tokens": 2,
            "output_tokens": 1,
            "token_details": None,
        },
    ]
    # Now continue that conversation
@@ -116,6 +122,9 @@ def test_chat_basic(mock_model, logs_db):
            "conversation_id": conversation_id,
            "duration_ms": ANY,
            "datetime_utc": ANY,
            "input_tokens": 1,
            "output_tokens": 1,
            "token_details": None,
        }
    ]
@@ -153,6 +162,9 @@ def test_chat_system(mock_model, logs_db):
            "conversation_id": ANY,
            "duration_ms": ANY,
            "datetime_utc": ANY,
            "input_tokens": 1,
            "output_tokens": 1,
            "token_details": None,
        }
    ]
@@ -181,6 +193,9 @@ def test_chat_options(mock_model, logs_db):
            "conversation_id": ANY,
            "duration_ms": ANY,
            "datetime_utc": ANY,
            "input_tokens": 1,
            "output_tokens": 1,
            "token_details": None,
        }
    ]

View file

@@ -147,7 +147,8 @@ def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype):
@pytest.mark.parametrize("async_", (False, True))
def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_):
@pytest.mark.parametrize("usage", (None, "-u", "--usage"))
def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_, usage):
    user_path = tmpdir / "user_dir"
    log_db = user_path / "logs.db"
    monkeypatch.setenv("LLM_USER_PATH", str(user_path))
@@ -173,21 +174,25 @@ def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_):
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 2,
                "prompt_tokens": 1000,
                "completion_tokens": 2000,
                "total_tokens": 12,
            },
            "system_fingerprint": "fp_49254d0e9b",
        },
        headers={"Content-Type": "application/json"},
    )
    runner = CliRunner()
    runner = CliRunner(mix_stderr=False)
    args = ["-m", "gpt-4o-mini", "--key", "x", "--no-stream"]
    if usage:
        args.append(usage)
    if async_:
        args.append("--async")
    result = runner.invoke(cli, args, catch_exceptions=False)
    assert result.exit_code == 0
    assert result.output == "Ho ho ho\n"
    if usage:
        assert result.stderr == "Token usage: 1,000 input, 2,000 output\n"
    # Confirm it was correctly logged
    assert log_db.exists()
    db = sqlite_utils.Database(str(log_db))

View file

@@ -37,6 +37,8 @@ def log_path(user_path):
            "model": "davinci",
            "datetime_utc": (start + datetime.timedelta(seconds=i)).isoformat(),
            "conversation_id": "abc123",
            "input_tokens": 2,
            "output_tokens": 5,
        }
        for i in range(100)
    )
@@ -46,9 +48,12 @@ def log_path(user_path):
datetime_re = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")


def test_logs_text(log_path):
@pytest.mark.parametrize("usage", (False, True))
def test_logs_text(log_path, usage):
    runner = CliRunner()
    args = ["logs", "-p", str(log_path)]
    if usage:
        args.append("-u")
    result = runner.invoke(cli, args, catch_exceptions=False)
    assert result.exit_code == 0
    output = result.output
@@ -64,18 +69,24 @@ def test_logs_text(log_path):
        "system\n\n"
        "## Response:\n\n"
        "response\n\n"
    ) + ("## Token usage:\n\n2 input, 5 output\n\n" if usage else "") + (
        "# YYYY-MM-DDTHH:MM:SS conversation: abc123\n\n"
        "Model: **davinci**\n\n"
        "## Prompt:\n\n"
        "prompt\n\n"
        "## Response:\n\n"
        "response\n\n"
    ) + (
        "## Token usage:\n\n2 input, 5 output\n\n" if usage else ""
    ) + (
        "# YYYY-MM-DDTHH:MM:SS conversation: abc123\n\n"
        "Model: **davinci**\n\n"
        "## Prompt:\n\n"
        "prompt\n\n"
        "## Response:\n\n"
        "response\n\n"
    ) + (
        "## Token usage:\n\n2 input, 5 output\n\n" if usage else ""
    )

View file

@@ -17,6 +17,9 @@ EXPECTED = {
    "conversation_id": str,
    "duration_ms": int,
    "datetime_utc": str,
    "input_tokens": int,
    "output_tokens": int,
    "token_details": str,
}

tests/test_utils.py (new file, 42 lines)
View file

@@ -0,0 +1,42 @@
import pytest
from llm.utils import simplify_usage_dict


@pytest.mark.parametrize(
    "input_data,expected_output",
    [
        (
            {
                "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
                "completion_tokens_details": {
                    "reasoning_tokens": 0,
                    "audio_tokens": 1,
                    "accepted_prediction_tokens": 0,
                    "rejected_prediction_tokens": 0,
                },
            },
            {"completion_tokens_details": {"audio_tokens": 1}},
        ),
        (
            {
                "details": {"tokens": 5, "audio_tokens": 2},
                "more_details": {"accepted_tokens": 3},
            },
            {
                "details": {"tokens": 5, "audio_tokens": 2},
                "more_details": {"accepted_tokens": 3},
            },
        ),
        ({"details": {"tokens": 0, "audio_tokens": 0}, "more_details": {}}, {}),
        ({"level1": {"level2": {"value": 0, "another_value": {}}}}, {}),
        (
            {
                "level1": {"level2": {"value": 0, "another_value": 1}},
                "level3": {"empty_dict": {}, "valid_token": 10},
            },
            {"level1": {"level2": {"another_value": 1}}, "level3": {"valid_token": 10}},
        ),
    ],
)
def test_simplify_usage_dict(input_data, expected_output):
    assert simplify_usage_dict(input_data) == expected_output