mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-17 05:00:25 +00:00
94 lines
3 KiB
Python
94 lines
3 KiB
Python
import llm
|
|
from llm.migrations import migrate
|
|
import os
|
|
import pytest
|
|
import sqlite_utils
|
|
|
|
|
|
# Key handed to model.chain() in the tests below. Falls back to the
# placeholder "badkey" when PYTEST_OPENAI_API_KEY is unset or empty.
API_KEY = os.getenv("PYTEST_OPENAI_API_KEY") or "badkey"
|
|
|
|
|
|
@pytest.mark.vcr
def test_tool_use_basic(vcr):
    """Run a one-tool chain and check both the responses and the logged rows.

    The model is given a ``multiply`` tool, the chain is consumed to
    completion, and the test then verifies the in-memory response objects
    and the rows written by ``log_to_db``.
    """

    # Name, signature and docstring of this tool are all asserted on further
    # down (tool name, JSON arguments, stored description) - keep them as-is.
    def multiply(a: int, b: int) -> int:
        """Multiply two numbers."""
        return a * b

    gpt = llm.get_model("gpt-4o-mini")
    chain_response = gpt.chain("What is 1231 * 2331?", tools=[multiply], key=API_KEY)

    # Iterating the chain yields text chunks; joining them gives the answer.
    answer = "".join(chain_response)
    assert answer == "The result of \\( 1231 \\times 2331 \\) is \\( 2,869,461 \\)."

    # Exactly two responses: the one carrying the tool call, and the
    # follow-up that receives the tool result.
    initial, follow_up = chain_response._responses

    assert initial.prompt.prompt == "What is 1231 * 2331?"
    assert initial.prompt.tools[0].name == "multiply"

    fed_back = follow_up.prompt.tool_results
    assert len(fed_back) == 1
    assert fed_back[0].name == "multiply"
    assert fed_back[0].output == "2869461"

    # Test writing to the database
    db = sqlite_utils.Database(memory=True)
    migrate(db)
    chain_response.log_to_db(db)

    expected_tables = {"tools", "tool_responses", "tool_calls", "tool_results"}
    assert expected_tables <= set(db.table_names())

    response_rows = list(db["responses"].rows)
    assert len(response_rows) == 2
    first_row, second_row = response_rows

    tool_rows = list(db["tools"].rows)
    assert len(tool_rows) == 1
    stored_tool = tool_rows[0]
    assert stored_tool["name"] == "multiply"
    assert stored_tool["description"] == "Multiply two numbers."

    call_rows = list(db["tool_calls"].rows)
    result_rows = list(db["tool_results"].rows)

    # The call is attached to the first response...
    assert len(call_rows) == 1
    logged_call = call_rows[0]
    assert logged_call["response_id"] == first_row["id"]
    assert logged_call["name"] == "multiply"
    assert logged_call["arguments"] == '{"a": 1231, "b": 2331}'

    # ...and its result to the second, linked through tool_call_id.
    assert len(result_rows) == 1
    logged_result = result_rows[0]
    assert logged_result["response_id"] == second_row["id"]
    assert logged_result["output"] == "2869461"
    assert logged_result["tool_call_id"] == logged_call["tool_call_id"]
|
|
|
|
|
|
@pytest.mark.vcr
def test_tool_use_chain_of_two_calls(vcr):
    """Chain two dependent tool calls and check each step of the exchange.

    The expected sequence is three responses: one calling
    ``lookup_population``, one calling ``can_have_dragons`` with that
    population, and a final one that makes no tool calls.
    """

    # NOTE(review): tool names, signatures and docstrings are presumably
    # passed to the model as the tool schema - left exactly as written.
    def lookup_population(country: str) -> int:
        "Returns the current population of the specified fictional country"
        return 123124

    def can_have_dragons(population: int) -> bool:
        "Returns True if the specified population can have dragons, False otherwise"
        return population > 10000

    question = "Can the country of Crumpet have dragons? Answer with only YES or NO"
    gpt = llm.get_model("gpt-4o-mini")
    chain_response = gpt.chain(
        question,
        tools=[lookup_population, can_have_dragons],
        stream=False,
        key=API_KEY,
    )

    assert chain_response.text() == "YES"

    steps = chain_response._responses
    assert len(steps) == 3
    population_step, dragons_step, final_step = steps

    # Step 1: no prior tool results; the model asks for the population.
    assert population_step.tool_calls()[0].arguments == {"country": "Crumpet"}
    assert population_step.prompt.tool_results == []

    # Step 2: the population is fed back and then passed to the second tool.
    assert dragons_step.prompt.tool_results[0].output == "123124"
    assert dragons_step.tool_calls()[0].arguments == {"population": 123124}

    # Step 3: the boolean result is fed back and no further tools are called.
    assert final_step.prompt.tool_results[0].output == "true"
    assert final_step.tool_calls() == []
|