mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-10 08:20:57 +00:00
Made load_conversation() async_ aware, closes #742
This commit is contained in:
parent
3445f9a112
commit
7bf1ea665e
4 changed files with 97 additions and 16 deletions
|
|
@ -4,6 +4,7 @@ from .errors import (
|
|||
NeedsKeyException,
|
||||
)
|
||||
from .models import (
|
||||
AsyncConversation,
|
||||
AsyncKeyModel,
|
||||
AsyncModel,
|
||||
AsyncResponse,
|
||||
|
|
@ -29,6 +30,7 @@ import pathlib
|
|||
import struct
|
||||
|
||||
__all__ = [
|
||||
"AsyncConversation",
|
||||
"AsyncKeyModel",
|
||||
"AsyncResponse",
|
||||
"Attachment",
|
||||
|
|
|
|||
14
llm/cli.py
14
llm/cli.py
|
|
@ -7,6 +7,7 @@ import json
|
|||
import re
|
||||
from llm import (
|
||||
Attachment,
|
||||
AsyncConversation,
|
||||
AsyncKeyModel,
|
||||
AsyncResponse,
|
||||
Collection,
|
||||
|
|
@ -32,6 +33,7 @@ from llm import (
|
|||
set_default_embedding_model,
|
||||
remove_alias,
|
||||
)
|
||||
from llm.models import _BaseConversation
|
||||
|
||||
from .migrations import migrate
|
||||
from .plugins import pm, load_plugins
|
||||
|
|
@ -363,7 +365,7 @@ def prompt(
|
|||
if conversation_id or _continue:
|
||||
# Load the conversation - loads most recent if no ID provided
|
||||
try:
|
||||
conversation = load_conversation(conversation_id)
|
||||
conversation = load_conversation(conversation_id, async_=async_)
|
||||
except UnknownModelError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
|
|
@ -656,7 +658,9 @@ def chat(
|
|||
print("")
|
||||
|
||||
|
||||
def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]:
|
||||
def load_conversation(
|
||||
conversation_id: Optional[str], async_=False
|
||||
) -> Optional[_BaseConversation]:
|
||||
db = sqlite_utils.Database(logs_db_path())
|
||||
migrate(db)
|
||||
if conversation_id is None:
|
||||
|
|
@ -673,11 +677,13 @@ def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]:
|
|||
"No conversation found with id={}".format(conversation_id)
|
||||
)
|
||||
# Inflate that conversation
|
||||
conversation = Conversation.from_row(row)
|
||||
conversation_class = AsyncConversation if async_ else Conversation
|
||||
response_class = AsyncResponse if async_ else Response
|
||||
conversation = conversation_class.from_row(row)
|
||||
for response in db["responses"].rows_where(
|
||||
"conversation_id = ?", [conversation_id]
|
||||
):
|
||||
conversation.responses.append(Response.from_row(db, response))
|
||||
conversation.responses.append(response_class.from_row(db, response))
|
||||
return conversation
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -131,14 +131,9 @@ class _BaseConversation:
|
|||
responses: List["_BaseResponse"] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row):
|
||||
from llm import get_model
|
||||
|
||||
return cls(
|
||||
model=get_model(row["model"]),
|
||||
id=row["id"],
|
||||
name=row["name"],
|
||||
)
|
||||
@abstractmethod
|
||||
def from_row(cls, row: Any) -> "_BaseConversation":
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -167,6 +162,16 @@ class Conversation(_BaseConversation):
|
|||
key=key,
|
||||
)
|
||||
|
||||
@classmethod
def from_row(cls, row):
    """Build a conversation from a stored row.

    Reads the ``"model"``, ``"id"`` and ``"name"`` keys of ``row``. The
    model name is resolved through ``llm.get_model`` and the result is
    passed to ``cls`` together with the id and name.
    """
    # Function-scope import, kept as in the original (presumably to avoid
    # a circular import with the ``llm`` package - confirm).
    from llm import get_model

    resolved_model = get_model(row["model"])
    conversation_id = row["id"]
    conversation_name = row["name"]
    return cls(
        model=resolved_model,
        id=conversation_id,
        name=conversation_name,
    )
|
||||
|
||||
def __repr__(self):
|
||||
count = len(self.responses)
|
||||
s = "s" if count == 1 else ""
|
||||
|
|
@ -199,6 +204,16 @@ class AsyncConversation(_BaseConversation):
|
|||
key=key,
|
||||
)
|
||||
|
||||
@classmethod
def from_row(cls, row):
    """Build an async conversation from a stored row.

    Reads the ``"model"``, ``"id"`` and ``"name"`` keys of ``row``. The
    model name is resolved through ``llm.get_async_model`` and the result
    is passed to ``cls`` together with the id and name.
    """
    # Function-scope import, kept as in the original (presumably to avoid
    # a circular import with the ``llm`` package - confirm).
    from llm import get_async_model

    resolved_model = get_async_model(row["model"])
    conversation_id = row["id"]
    conversation_name = row["name"]
    return cls(
        model=resolved_model,
        id=conversation_id,
        name=conversation_name,
    )
|
||||
|
||||
def __repr__(self):
|
||||
count = len(self.responses)
|
||||
s = "s" if count == 1 else ""
|
||||
|
|
@ -251,10 +266,13 @@ class _BaseResponse:
|
|||
self.token_details = details
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, db, row):
|
||||
from llm import get_model
|
||||
def from_row(cls, db, row, _async=False):
|
||||
from llm import get_model, get_async_model
|
||||
|
||||
model = get_model(row["model"])
|
||||
if _async:
|
||||
model = get_async_model(row["model"])
|
||||
else:
|
||||
model = get_model(row["model"])
|
||||
|
||||
response = cls(
|
||||
model=model,
|
||||
|
|
@ -427,7 +445,7 @@ class Response(_BaseResponse):
|
|||
yield chunk
|
||||
self._chunks.append(chunk)
|
||||
else:
|
||||
raise ValueError("self.model must be a Model or KeyModel")
|
||||
raise Exception("self.model must be a Model or KeyModel")
|
||||
|
||||
if self.conversation:
|
||||
self.conversation.responses.append(self)
|
||||
|
|
@ -446,6 +464,10 @@ class AsyncResponse(_BaseResponse):
|
|||
model: "AsyncModel"
|
||||
conversation: Optional["AsyncConversation"] = None
|
||||
|
||||
@classmethod
def from_row(cls, db, row, _async=False):
    """Delegate to the parent ``from_row`` with ``_async=True``.

    The ``_async`` argument received here is not forwarded; ``True`` is
    always passed to the parent implementation.
    """
    restored = super().from_row(db, row, _async=True)
    return restored
|
||||
|
||||
async def on_done(self, callback):
|
||||
if not self._done:
|
||||
self.done_callbacks.append(callback)
|
||||
|
|
|
|||
|
|
@ -398,6 +398,57 @@ def test_llm_default_prompt(
|
|||
)
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
@pytest.mark.parametrize("async_", (False, True))
def test_llm_prompt_continue(httpx_mock, user_path, async_):
    """Run a prompt, then continue it with ``-c``, in sync and async modes.

    Two mocked chat-completion responses are registered, one per CLI
    invocation. After each invocation the test checks the exit code, the
    printed output and the number of rows in the ``responses`` log table.
    """
    # Register the two mocked replies in the order the CLI will consume them.
    for reply in ("Bob, Alice, Eve", "Terry"):
        httpx_mock.add_response(
            method="POST",
            url="https://api.openai.com/v1/chat/completions",
            json={
                "model": "gpt-4o-mini",
                "usage": {},
                "choices": [{"message": {"content": reply}}],
            },
            headers={"Content-Type": "application/json"},
        )

    # Start from an empty responses table so the row counts below are exact.
    log_db = sqlite_utils.Database(str(user_path / "logs.db"))
    log_db["responses"].delete_where()

    async_flag = ["--async"] if async_ else []
    runner = CliRunner()

    # First prompt
    first = runner.invoke(
        cli,
        ["three names \nfor a pet pelican", "--no-stream"] + async_flag,
        catch_exceptions=False,
    )
    assert first.exit_code == 0, first.output
    assert first.output == "Bob, Alice, Eve\n"

    # The first exchange should have been logged.
    assert len(list(log_db["responses"].rows)) == 1

    # Follow-up that continues the most recent conversation.
    second = runner.invoke(
        cli,
        ["one more", "-c", "--no-stream"] + async_flag,
        catch_exceptions=False,
    )
    assert second.exit_code == 0, second.output
    assert second.output == "Terry\n"

    # Both exchanges are now present in the log.
    assert len(list(log_db["responses"].rows)) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args,expect_just_code",
|
||||
(
|
||||
|
|
|
|||
Loading…
Reference in a new issue