mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
parent
5162cabbe1
commit
a50de8b57a
3 changed files with 48 additions and 14 deletions
24
llm/cli.py
24
llm/cli.py
|
|
@ -875,29 +875,25 @@ def prompt(
|
|||
raise
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
if isinstance(response, ChainResponse):
|
||||
responses = response._responses
|
||||
else:
|
||||
responses = [response]
|
||||
|
||||
for response in responses:
|
||||
if isinstance(response, AsyncResponse):
|
||||
response = asyncio.run(response.to_sync_response())
|
||||
|
||||
if usage:
|
||||
if usage:
|
||||
if isinstance(response, ChainResponse):
|
||||
responses = response._responses
|
||||
else:
|
||||
responses = [response]
|
||||
for response_object in responses:
|
||||
# Show token usage to stderr in yellow
|
||||
click.echo(
|
||||
click.style(
|
||||
"Token usage: {}".format(response.token_usage()),
|
||||
"Token usage: {}".format(response_object.token_usage()),
|
||||
fg="yellow",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
|
||||
# Log to the database
|
||||
if (logs_on() or log) and not no_log:
|
||||
response.log_to_db(db)
|
||||
# Log to the database
|
||||
if (logs_on() or log) and not no_log:
|
||||
response.log_to_db(db)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
|
|
|||
|
|
@ -1134,6 +1134,12 @@ class _BaseChainResponse:
|
|||
def text(self) -> str:
|
||||
return "".join(self)
|
||||
|
||||
def log_to_db(self, db):
|
||||
for response in self._responses:
|
||||
if isinstance(response, AsyncResponse):
|
||||
response = asyncio.run(response.to_sync_response())
|
||||
response.log_to_db(db)
|
||||
|
||||
|
||||
class ChainResponse(_BaseChainResponse):
|
||||
"Know how to chain multiple responses e.g. for tool calls"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import llm
|
||||
from llm.migrations import migrate
|
||||
import os
|
||||
import pytest
|
||||
import sqlite_utils
|
||||
|
||||
|
||||
API_KEY = os.environ.get("PYTEST_OPENAI_API_KEY", None) or "badkey"
|
||||
|
|
@ -28,3 +30,33 @@ def test_tool_use_basic(vcr):
|
|||
assert len(second.prompt.tool_results) == 1
|
||||
assert second.prompt.tool_results[0].name == "multiply"
|
||||
assert second.prompt.tool_results[0].output == "2869461"
|
||||
|
||||
# Test writing to the database
|
||||
db = sqlite_utils.Database(memory=True)
|
||||
migrate(db)
|
||||
chain_response.log_to_db(db)
|
||||
assert set(db.table_names()).issuperset(
|
||||
{"tools", "tool_responses", "tool_calls", "tool_results"}
|
||||
)
|
||||
|
||||
responses = list(db["responses"].rows)
|
||||
assert len(responses) == 2
|
||||
first_response, second_response = responses
|
||||
|
||||
tools = list(db["tools"].rows)
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "multiply"
|
||||
assert tools[0]["description"] == "Multiply two numbers."
|
||||
|
||||
tool_results = list(db["tool_results"].rows)
|
||||
tool_calls = list(db["tool_calls"].rows)
|
||||
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0]["response_id"] == first_response["id"]
|
||||
assert tool_calls[0]["name"] == "multiply"
|
||||
assert tool_calls[0]["arguments"] == '{"a": 1231, "b": 2331}'
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0]["response_id"] == second_response["id"]
|
||||
assert tool_results[0]["output"] == "2869461"
|
||||
assert tool_results[0]["tool_call_id"] == tool_calls[0]["tool_call_id"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue