Fixed remaining mypy problems, refs #1023

Refs https://github.com/simonw/llm/pull/996#issuecomment-2878191352
This commit is contained in:
Simon Willison 2025-05-13 16:47:59 -07:00
parent 96f910bb30
commit c81f0560e0
3 changed files with 62 additions and 52 deletions

View file

@ -27,7 +27,7 @@ from .embeddings import Collection
from .templates import Template
from .plugins import pm, load_plugins
import click
from typing import Any, Dict, List, Optional, Callable, Union
from typing import Any, Dict, List, Optional, Callable, Union, cast
import json
import os
import pathlib
@ -134,20 +134,30 @@ def get_fragment_loaders() -> Dict[
return _get_loaders(pm.hook.register_fragment_loaders)
def get_tools() -> Dict[str, List[Tool]]:
def get_tools() -> Dict[str, Tool]:
"""Get tools registered by plugins."""
load_plugins()
tools = {}
tools: Dict[str, Tool] = {}
def register(tool_or_function: Callable[..., Any], name: Optional[str] = None):
suffix = 0
def register(
tool_or_function: Union[Tool, Callable[..., Any]],
name: Optional[str] = None,
) -> None:
# If they handed us a bare function, wrap it in a Tool
if not isinstance(tool_or_function, Tool):
tool_or_function = Tool.function(tool_or_function)
prefix_to_try = tool_or_function.name
while prefix_to_try in tools:
tool = cast(Tool, tool_or_function)
prefix = tool.name
suffix = 0
candidate = prefix
# avoid name collisions
while candidate in tools:
suffix += 1
prefix_to_try = f"{prefix}_{suffix}"
tools[prefix_to_try] = tool_or_function
candidate = f"{prefix}_{suffix}"
tools[candidate] = tool
pm.hook.register_tools(register=register)
return tools

View file

@ -73,7 +73,7 @@ import sqlite_utils
from sqlite_utils.utils import rows_from_file, Format
import sys
import textwrap
from typing import cast, Optional, Iterable, List, Union, Tuple, Any
from typing import cast, Dict, Optional, Iterable, List, Union, Tuple, Any
import warnings
import yaml
@ -800,7 +800,7 @@ def prompt(
kwargs["before_call"] = approve_tool_call
# Look up all those tools
registered_tools: dict = get_tools()
registered_tools = get_tools()
bad_tools = [tool for tool in tools if tool not in registered_tools]
if bad_tools:
raise click.ClickException(
@ -3580,14 +3580,14 @@ def _tools_from_code(code_or_path: str) -> List[Tool]:
code_or_path = pathlib.Path(code_or_path).read_text()
except FileNotFoundError:
raise click.ClickException("File not found: {}".format(code_or_path))
globals = {}
namespace: Dict[str, Any] = {}
tools = []
try:
exec(code_or_path, globals)
exec(code_or_path, namespace)
except SyntaxError as ex:
raise click.ClickException("Error in --functions definition: {}".format(ex))
# Register all callables in the locals dict:
for name, value in globals.items():
for name, value in namespace.items():
if callable(value) and not name.startswith("_"):
tools.append(Tool.function(value))
return tools

View file

@ -490,42 +490,6 @@ class _BaseResponse:
def add_tool_call(self, tool_call: ToolCall):
self._tool_calls.append(tool_call)
def execute_tool_calls(
self,
*,
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
) -> List[ToolResult]:
tool_results = []
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
# TODO: make this work async
for tool_call in self.tool_calls():
tool = tools_by_name.get(tool_call.name)
if tool is None:
raise CancelToolCall("Unknown tool: {}".format(tool_call.name))
if before_call:
# This may raise CancelToolCall:
before_call(tool, tool_call)
if not tool.implementation:
raise CancelToolCall(
"No implementation available for tool: {}".format(tool_call.name)
)
try:
result = tool.implementation(**tool_call.arguments)
if not isinstance(result, str):
result = json.dumps(result, default=repr)
except Exception as ex:
result = f"Error: {ex}"
tool_result = ToolResult(
name=tool_call.name,
output=result,
tool_call_id=tool_call.tool_call_id,
)
if after_call:
after_call(tool, tool_call, tool_result)
tool_results.append(tool_result)
return tool_results
def set_usage(
self,
*,
@ -815,11 +779,47 @@ class Response(_BaseResponse):
def text_or_raise(self) -> str:
return self.text()
def execute_tool_calls(
self,
*,
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
) -> List[ToolResult]:
tool_results = []
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
# TODO: make this work async
for tool_call in self.tool_calls():
tool = tools_by_name.get(tool_call.name)
if tool is None:
raise CancelToolCall("Unknown tool: {}".format(tool_call.name))
if before_call:
# This may raise CancelToolCall:
before_call(tool, tool_call)
if not tool.implementation:
raise CancelToolCall(
"No implementation available for tool: {}".format(tool_call.name)
)
try:
result = tool.implementation(**tool_call.arguments)
if not isinstance(result, str):
result = json.dumps(result, default=repr)
except Exception as ex:
result = f"Error: {ex}"
tool_result = ToolResult(
name=tool_call.name,
output=result,
tool_call_id=tool_call.tool_call_id,
)
if after_call:
after_call(tool, tool_call, tool_result)
tool_results.append(tool_result)
return tool_results
def tool_calls(self) -> List[ToolCall]:
self._force()
return self._tool_calls
def tool_calls_or_raise(self) -> str:
def tool_calls_or_raise(self) -> List[ToolCall]:
return self.tool_calls()
def json(self) -> Optional[Dict[str, Any]]:
@ -1122,7 +1122,7 @@ class _BaseChainResponse:
conversation=self.conversation,
)
else:
response = None
break
def __iter__(self) -> Iterator[str]:
for response in self.responses():