mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
Log tool stuff to the database, refs #1003
This commit is contained in:
parent
4abd6e0faf
commit
32cab986ea
4 changed files with 150 additions and 22 deletions
40
llm/cli.py
40
llm/cli.py
|
|
@ -743,9 +743,7 @@ def prompt(
|
|||
extra_tools = _tools_from_code(python_tools)
|
||||
|
||||
if tools or python_tools:
|
||||
prompt_method = lambda *args, **kwargs: conversation.chain(
|
||||
*args, **kwargs
|
||||
).details()
|
||||
prompt_method = conversation.chain
|
||||
# prompt_method = conversation.chain
|
||||
# Look up all those tools
|
||||
registered_tools: dict = get_tools()
|
||||
|
|
@ -825,25 +823,29 @@ def prompt(
|
|||
raise
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
if not isinstance(response, (Response, AsyncResponse)):
|
||||
# Terminate early, logging for tool streaming mechanism not implemented yet
|
||||
return
|
||||
if isinstance(response, ChainResponse):
|
||||
responses = response._responses
|
||||
else:
|
||||
responses = [response]
|
||||
|
||||
if isinstance(response, AsyncResponse):
|
||||
response = asyncio.run(response.to_sync_response())
|
||||
for response in responses:
|
||||
if isinstance(response, AsyncResponse):
|
||||
response = asyncio.run(response.to_sync_response())
|
||||
|
||||
if usage:
|
||||
# Show token usage to stderr in yellow
|
||||
click.echo(
|
||||
click.style(
|
||||
"Token usage: {}".format(response.token_usage()), fg="yellow", bold=True
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
if usage:
|
||||
# Show token usage to stderr in yellow
|
||||
click.echo(
|
||||
click.style(
|
||||
"Token usage: {}".format(response.token_usage()),
|
||||
fg="yellow",
|
||||
bold=True,
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
|
||||
# Log to the database
|
||||
if (logs_on() or log) and not no_log:
|
||||
response.log_to_db(db)
|
||||
# Log to the database
|
||||
if (logs_on() or log) and not no_log:
|
||||
response.log_to_db(db)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
|
|
|||
|
|
@ -310,3 +310,61 @@ def m016_fragments_table_pks(db):
|
|||
# https://github.com/simonw/llm/issues/863#issuecomment-2781720064
|
||||
db["prompt_fragments"].transform(pk=("response_id", "fragment_id", "order"))
|
||||
db["system_fragments"].transform(pk=("response_id", "fragment_id", "order"))
|
||||
|
||||
|
||||
@migration
|
||||
def m017_tools_tables(db):
|
||||
db["tools"].create(
|
||||
{
|
||||
"id": int,
|
||||
"hash": str,
|
||||
"name": str,
|
||||
"description": str,
|
||||
"input_schema": str,
|
||||
},
|
||||
pk="id",
|
||||
)
|
||||
db["tools"].create_index(["hash"], unique=True)
|
||||
# Many-to-many relationship between tools and responses
|
||||
db["tool_responses"].create(
|
||||
{
|
||||
"tool_id": int,
|
||||
"response_id": str,
|
||||
},
|
||||
foreign_keys=(
|
||||
("tool_id", "tools", "id"),
|
||||
("response_id", "responses", "id"),
|
||||
),
|
||||
pk=("tool_id", "response_id"),
|
||||
)
|
||||
# tool_calls and tool_results are one-to-many against responses
|
||||
db["tool_calls"].create(
|
||||
{
|
||||
"id": int,
|
||||
"response_id": str,
|
||||
"tool_id": int,
|
||||
"name": str,
|
||||
"arguments": str,
|
||||
"tool_call_id": str,
|
||||
},
|
||||
pk="id",
|
||||
foreign_keys=(
|
||||
("response_id", "responses", "id"),
|
||||
("tool_id", "tools", "id"),
|
||||
),
|
||||
)
|
||||
db["tool_results"].create(
|
||||
{
|
||||
"id": int,
|
||||
"response_id": str,
|
||||
"tool_id": int,
|
||||
"name": str,
|
||||
"output": str,
|
||||
"tool_call_id": str,
|
||||
},
|
||||
pk="id",
|
||||
foreign_keys=(
|
||||
("response_id", "responses", "id"),
|
||||
("tool_id", "tools", "id"),
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from typing import (
|
|||
)
|
||||
from .utils import (
|
||||
ensure_fragment,
|
||||
ensure_tool,
|
||||
make_schema_id,
|
||||
mimetype_from_path,
|
||||
mimetype_from_string,
|
||||
|
|
@ -127,6 +128,15 @@ class Tool:
|
|||
return schema_dict
|
||||
return schema
|
||||
|
||||
def hash(self):
|
||||
"""Hash for tool based on its name, description and input schema (preserving key order)"""
|
||||
to_hash = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"input_schema": self.input_schema,
|
||||
}
|
||||
return hashlib.sha256(json.dumps(to_hash).encode("utf-8")).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def function(cls, function, name=None):
|
||||
"""
|
||||
|
|
@ -597,6 +607,10 @@ class _BaseResponse:
|
|||
response_text = self.text_or_raise()
|
||||
replacements[f"r:{response_id}"] = response_text
|
||||
json_data = self.json()
|
||||
|
||||
# Temporary workraound TODO remove this
|
||||
json_data = {}
|
||||
|
||||
response = {
|
||||
"id": response_id,
|
||||
"model": self.model.model_id,
|
||||
|
|
@ -643,6 +657,38 @@ class _BaseResponse:
|
|||
},
|
||||
)
|
||||
|
||||
# Persist any tools, tool calls and tool results
|
||||
tool_ids_by_name = {}
|
||||
for tool in self.prompt.tools:
|
||||
tool_id = ensure_tool(db, tool)
|
||||
tool_ids_by_name[tool.name] = tool_id
|
||||
db["tool_responses"].insert(
|
||||
{
|
||||
"tool_id": tool_id,
|
||||
"response_id": response_id,
|
||||
}
|
||||
)
|
||||
for tool_call in self.tool_calls(): # TODO Should be _or_raise()
|
||||
db["tool_calls"].insert(
|
||||
{
|
||||
"response_id": response_id,
|
||||
"tool_id": tool_ids_by_name.get(tool_call.name) or None,
|
||||
"name": tool_call.name,
|
||||
"arguments": json.dumps(tool_call.arguments),
|
||||
"tool_call_id": tool_call.tool_call_id,
|
||||
}
|
||||
)
|
||||
for tool_result in self.prompt.tool_results:
|
||||
db["tool_results"].insert(
|
||||
{
|
||||
"response_id": response_id,
|
||||
"tool_id": tool_ids_by_name.get(tool_result.name) or None,
|
||||
"name": tool_result.name,
|
||||
"output": tool_result.output,
|
||||
"tool_call_id": tool_result.tool_call_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Response(_BaseResponse):
|
||||
model: "Model"
|
||||
|
|
@ -942,6 +988,7 @@ class _BaseChainResponse:
|
|||
while response:
|
||||
count += 1
|
||||
yield response
|
||||
self._responses.append(response)
|
||||
if count > self.chain_limit:
|
||||
raise ValueError(f"Chain limit of {self.chain_limit} exceeded. ")
|
||||
tool_calls = response.tool_calls()
|
||||
|
|
|
|||
27
llm/utils.py
27
llm/utils.py
|
|
@ -457,14 +457,35 @@ def ensure_fragment(db, content):
|
|||
values (:hash, :content, datetime('now'), :source)
|
||||
on conflict(hash) do nothing
|
||||
"""
|
||||
hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
hash_id = hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
source = None
|
||||
if isinstance(content, Fragment):
|
||||
source = content.source
|
||||
with db.conn:
|
||||
db.execute(sql, {"hash": hash, "content": content, "source": source})
|
||||
db.execute(sql, {"hash": hash_id, "content": content, "source": source})
|
||||
return list(
|
||||
db.query("select id from fragments where hash = :hash", {"hash": hash})
|
||||
db.query("select id from fragments where hash = :hash", {"hash": hash_id})
|
||||
)[0]["id"]
|
||||
|
||||
|
||||
def ensure_tool(db, tool):
|
||||
sql = """
|
||||
insert into tools (hash, name, description, input_schema)
|
||||
values (:hash, :name, :description, :input_schema)
|
||||
on conflict(hash) do nothing
|
||||
"""
|
||||
with db.conn:
|
||||
db.execute(
|
||||
sql,
|
||||
{
|
||||
"hash": tool.hash(),
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": json.dumps(tool.input_schema),
|
||||
},
|
||||
)
|
||||
return list(
|
||||
db.query("select id from tools where hash = :hash", {"hash": tool.hash()})
|
||||
)[0]["id"]
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue