conversation_name should not have newlines, closes #110

This commit is contained in:
Simon Willison 2023-07-15 21:28:35 -07:00
parent 178af27d95
commit cb41409e2b
2 changed files with 21 additions and 11 deletions

View file

@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
import datetime
from .errors import NeedsKeyException
import re
import time
from typing import Any, Dict, Iterator, List, Optional, Set
from abc import ABC, abstractmethod
@@ -9,6 +10,8 @@ import json
from pydantic import ConfigDict, BaseModel
from ulid import ULID
CONVERSATION_NAME_LENGTH = 32
@dataclass
class Prompt:
@@ -125,7 +128,9 @@ class Response(ABC):
db["conversations"].insert(
{
"id": conversation.id,
"name": _truncated(self.prompt.prompt or self.prompt.system or "", 32),
"name": _conversation_name(
self.prompt.prompt or self.prompt.system or ""
),
"model": conversation.model.model_id,
},
ignore=True,
@@ -273,7 +278,9 @@ class ModelWithAliases:
aliases: Set[str]
def _truncated(text, length):
if len(text) <= length:
def _conversation_name(text):
# Collapse whitespace, including newlines
text = re.sub(r"\s+", " ", text)
if len(text) <= CONVERSATION_NAME_LENGTH:
return text
return text[: length - 1] + ""
return text[: CONVERSATION_NAME_LENGTH - 1] + ""

View file

@@ -153,7 +153,7 @@ def test_llm_default_prompt(
# Run the prompt
runner = CliRunner()
prompt = "three names for a pet pelican"
prompt = "three names\nfor a pet pelican"
input = None
args = ["--no-stream"]
if use_stdin:
@@ -176,7 +176,7 @@ def test_llm_default_prompt(
assert len(rows) == 1
expected = {
"model": "gpt-3.5-turbo",
"prompt": "three names for a pet pelican",
"prompt": "three names\nfor a pet pelican",
"system": None,
"options_json": "{}",
"response": "Bob, Alice, Eve",
@@ -186,7 +186,7 @@ def test_llm_default_prompt(
assert isinstance(row["duration_ms"], int)
assert isinstance(row["datetime_utc"], str)
assert json.loads(row["prompt_json"]) == {
"messages": [{"role": "user", "content": "three names for a pet pelican"}]
"messages": [{"role": "user", "content": "three names\nfor a pet pelican"}]
}
assert json.loads(row["response_json"]) == {
"model": "gpt-3.5-turbo",
@@ -203,11 +203,11 @@ def test_llm_default_prompt(
log_json[0].items()
>= {
"model": "gpt-3.5-turbo",
"prompt": "three names for a pet pelican",
"prompt": "three names\nfor a pet pelican",
"system": None,
"prompt_json": {
"messages": [
{"role": "user", "content": "three names for a pet pelican"}
{"role": "user", "content": "three names\nfor a pet pelican"}
]
},
"options_json": {},
@@ -217,6 +217,9 @@ def test_llm_default_prompt(
"usage": {},
"choices": [{"message": {"content": "Bob, Alice, Eve"}}],
},
# This doesn't have the \n after three names:
"conversation_name": "three names for a pet pelican",
"conversation_model": "gpt-3.5-turbo",
}.items()
)
@@ -236,12 +239,12 @@ def test_openai_localai_configuration(mocked_localai, user_path):
config_path.write_text(EXTRA_MODELS_YAML, "utf-8")
# Run the prompt
runner = CliRunner()
prompt = "three names for a pet pelican"
prompt = "three names\nfor a pet pelican"
result = runner.invoke(cli, ["--no-stream", "--model", "orca", prompt])
assert result.exit_code == 0
assert result.output == "Bob, Alice, Eve\n"
assert json.loads(mocked_localai.last_request.text) == {
"model": "orca-mini-3b",
"messages": [{"role": "user", "content": "three names for a pet pelican"}],
"messages": [{"role": "user", "content": "three names\nfor a pet pelican"}],
"stream": False,
}