Move new truncate_string to llm.utils, add tests

Refs #759
This commit is contained in:
Simon Willison 2025-03-22 16:30:34 -07:00
parent 8244091854
commit 4a45bc6efc
3 changed files with 148 additions and 26 deletions

View file

@ -40,17 +40,18 @@ from llm.models import _BaseConversation
from .migrations import migrate
from .plugins import pm, load_plugins
from .utils import (
extract_fenced_code_block,
find_unused_key,
make_schema_id,
mimetype_from_path,
mimetype_from_string,
token_usage_string,
extract_fenced_code_block,
make_schema_id,
multi_schema,
output_rows_as_json,
resolve_schema_input,
schema_summary,
multi_schema,
schema_dsl,
find_unused_key,
schema_summary,
token_usage_string,
truncate_string,
)
import base64
import httpx
@ -1169,8 +1170,8 @@ def logs_list(
for row in rows:
if truncate:
row["prompt"] = _truncate_string(row["prompt"])
row["response"] = _truncate_string(row["response"])
row["prompt"] = truncate_string(row["prompt"] or "")
row["response"] = truncate_string(row["response"] or "")
# Either decode or remove all JSON keys
keys = list(row.keys())
for key in keys:
@ -1208,9 +1209,11 @@ def logs_list(
should_show_conversation = True
for row in rows:
if short:
system = _truncate_string(row["system"], 120, normalize_whitespace=True)
prompt = _truncate_string(
row["prompt"], 120, normalize_whitespace=True, keep_end=True
system = truncate_string(
row["system"] or "", 120, normalize_whitespace=True
)
prompt = truncate_string(
row["prompt"] or "", 120, normalize_whitespace=True, keep_end=True
)
cid = row["conversation_id"]
attachments = attachments_by_id.get(row["id"])
@ -2353,20 +2356,6 @@ def template_dir():
return path
def _truncate_string(text, max_length=100, normalize_whitespace=False, keep_end=False):
if not text:
return text
if normalize_whitespace:
text = re.sub(r"\s+", " ", text)
if len(text) <= max_length:
return text
if keep_end:
# Find a reasonable cutoff for the start and end portions
cutoff = (max_length - 6) // 2
return text[:cutoff] + "... " + text[-cutoff:]
return text[: max_length - 3] + "..."
def logs_db_path():
    """Return the path of the "logs.db" entry under whatever user_dir() returns."""
    db_filename = "logs.db"
    return user_dir() / db_filename

View file

@ -391,3 +391,43 @@ def find_unused_key(item: dict, key: str) -> str:
while key in item:
key += "_"
return key
def truncate_string(
    text: str,
    max_length: int = 100,
    normalize_whitespace: bool = False,
    keep_end: bool = False,
) -> str:
    """
    Truncate a string to a maximum length, with options to normalize whitespace and keep both start and end.

    Args:
        text: The string to truncate. Falsy values (empty string, None) are
            returned unchanged.
        max_length: Maximum length of the result string
        normalize_whitespace: If True, replace each run of whitespace with a
            single space before measuring and truncating
        keep_end: If True, keep both beginning and end of string, joined by
            "... ". Ignored when max_length is below 9.

    Returns:
        The (optionally normalized) string if it already fits, otherwise a
        truncated string that is never longer than max_length (and is empty
        when max_length is zero or negative).
    """
    if not text:
        return text
    if normalize_whitespace:
        text = re.sub(r"\s+", " ", text)
    if len(text) <= max_length:
        return text

    # 9 is the smallest max_length for which the cutoff below keeps two
    # characters on each side (e.g. "12... 90"). At 5-6 the cutoff would be 0
    # and text[-0:] would return the whole string.
    min_keep_end_length = 9
    if keep_end and max_length >= min_keep_end_length:
        # Reserve 5 characters: 4 for the "... " separator plus one spare, so
        # the result is one or two characters shorter than max_length.
        cutoff = (max_length - 5) // 2
        return text[:cutoff] + "... " + text[-cutoff:]

    ellipsis = "..."
    if max_length < len(ellipsis):
        # No room for the ellipsis; a negative slice bound would keep almost
        # the whole string, so hard-cut to the requested length instead.
        return text[: max(max_length, 0)]
    # Simple truncation, also used as the fallback for small max_length
    return text[: max_length - len(ellipsis)] + ellipsis

View file

@ -1,5 +1,10 @@
import pytest
from llm.utils import simplify_usage_dict, extract_fenced_code_block, schema_dsl
from llm.utils import (
simplify_usage_dict,
extract_fenced_code_block,
truncate_string,
schema_dsl,
)
@pytest.mark.parametrize(
@ -234,3 +239,91 @@ def test_schema_dsl_multi():
},
"required": ["items"],
}
@pytest.mark.parametrize(
    "text, max_length, normalize_whitespace, keep_end, expected",
    [
        # Basic truncation tests
        ("Hello, world!", 100, False, False, "Hello, world!"),
        ("Hello, world!", 5, False, False, "He..."),
        ("", 10, False, False, ""),
        (None, 10, False, False, None),
        # Normalize whitespace tests
        ("Hello world!", 100, True, False, "Hello world!"),
        ("Hello \n\t world!", 100, True, False, "Hello world!"),
        ("Hello world!", 5, True, False, "He..."),
        # Keep end tests
        ("Hello, world!", 10, False, True, "He... d!"),
        # keep_end requested but max_length is below the threshold, so these
        # must fall back to regular truncation
        ("Hello, world!", 7, False, True, "Hell..."),
        ("1234567890", 7, False, True, "1234..."),
        # Combinations of parameters
        ("Hello world!", 10, True, True, "He... d!"),
        # Note: After normalization, "Hello world!" is exactly 12 chars, so no truncation
        ("Hello \n\t world!", 12, True, True, "Hello world!"),
        # Edge cases
        ("12345", 5, False, False, "12345"),
        ("123456", 5, False, False, "12..."),
        ("12345", 5, False, True, "12345"),  # Unchanged for exact fit
        # keep_end with a small max_length uses regular truncation
        ("123456", 5, False, True, "12..."),
        # Very long string
        ("A" * 200, 10, False, False, "AAAAAAA..."),
        ("A" * 200, 10, False, True, "AA... AA"),  # keep_end with adequate length
        # Exact boundary cases
        ("123456789", 9, False, False, "123456789"),  # Exact fit
        ("1234567890", 9, False, False, "123456..."),  # Simple truncation
        ("123456789", 9, False, True, "123456789"),  # Exact fit with keep_end
        ("1234567890", 9, False, True, "12... 90"),  # keep_end truncation
        # Minimum sensible length tests for keep_end
        (
            "1234567890",
            8,
            False,
            True,
            "12345...",
        ),  # Too small for keep_end, use regular
        ("1234567890", 9, False, True, "12... 90"),  # Just enough for keep_end
    ],
)
def test_truncate_string(text, max_length, normalize_whitespace, keep_end, expected):
    """Test the truncate_string function with various inputs and parameters."""
    result = truncate_string(
        text=text,
        max_length=max_length,
        normalize_whitespace=normalize_whitespace,
        keep_end=keep_end,
    )
    assert result == expected
@pytest.mark.parametrize(
    "text, max_length, keep_end, prefix_len, expected_full",
    [
        # The string fits exactly, so nothing is cut
        ("0123456789", 10, True, None, "0123456789"),
        # Plenty of room for both ends plus the separator
        ("012345678901234", 14, True, 4, "0123... 1234"),
        # Different max_length values give different cutoffs
        ("abcdefghijklmnopqrstuvwxyz", 10, True, 2, "ab... yz"),
        ("abcdefghijklmnopqrstuvwxyz", 12, True, 3, "abc... xyz"),
        # Below the keep_end threshold: plain truncation is expected
        ("abcdefghijklmnopqrstuvwxyz", 8, True, None, "abcde..."),
    ],
)
def test_test_truncate_string_keep_end(
    text, max_length, keep_end, prefix_len, expected_full
):
    """Check the output of truncate_string when keep_end is requested."""
    truncated = truncate_string(text=text, max_length=max_length, keep_end=keep_end)
    assert truncated == expected_full

    # The head/tail checks only make sense when the string was actually split
    # around the separator.
    was_split = prefix_len is not None and len(text) > max_length and max_length >= 9
    if not was_split:
        return

    expected_head = text[:prefix_len]
    expected_tail = text[-prefix_len:]
    assert truncated[:prefix_len] == expected_head
    assert truncated[-prefix_len:] == expected_tail
    assert "... " in truncated