From c81f0560e0a55adbfc4a2a44b2502283e33a05af Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 13 May 2025 16:47:59 -0700 Subject: [PATCH] Fixed remaining mypy problems, refs #1023 Refs https://github.com/simonw/llm/pull/996#issuecomment-2878191352 --- llm/__init__.py | 28 ++++++++++++------ llm/cli.py | 10 +++---- llm/models.py | 76 ++++++++++++++++++++++++------------------------- 3 files changed, 62 insertions(+), 52 deletions(-) diff --git a/llm/__init__.py b/llm/__init__.py index 5f28274..08849f8 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -27,7 +27,7 @@ from .embeddings import Collection from .templates import Template from .plugins import pm, load_plugins import click -from typing import Any, Dict, List, Optional, Callable, Union +from typing import Any, Dict, List, Optional, Callable, Union, cast import json import os import pathlib @@ -134,20 +134,30 @@ def get_fragment_loaders() -> Dict[ return _get_loaders(pm.hook.register_fragment_loaders) -def get_tools() -> Dict[str, List[Tool]]: +def get_tools() -> Dict[str, Tool]: """Get tools registered by plugins.""" load_plugins() - tools = {} + tools: Dict[str, Tool] = {} - def register(tool_or_function: Callable[..., Any], name: Optional[str] = None): - suffix = 0 + def register( + tool_or_function: Union[Tool, Callable[..., Any]], + name: Optional[str] = None, + ) -> None: + # If they handed us a bare function, wrap it in a Tool if not isinstance(tool_or_function, Tool): tool_or_function = Tool.function(tool_or_function) - prefix_to_try = tool_or_function.name - while prefix_to_try in tools: + + tool = cast(Tool, tool_or_function) + prefix = tool.name + suffix = 0 + candidate = prefix + + # avoid name collisions + while candidate in tools: suffix += 1 - prefix_to_try = f"{prefix}_{suffix}" - tools[prefix_to_try] = tool_or_function + candidate = f"{prefix}_{suffix}" + + tools[candidate] = tool pm.hook.register_tools(register=register) return tools diff --git a/llm/cli.py b/llm/cli.py index 856439a..4a63201 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -73,7 +73,7 @@ import sqlite_utils from sqlite_utils.utils import rows_from_file, Format import sys import textwrap -from typing import cast, Optional, Iterable, List, Union, Tuple, Any +from typing import cast, Dict, Optional, Iterable, List, Union, Tuple, Any import warnings import yaml @@ -800,7 +800,7 @@ def prompt( kwargs["before_call"] = approve_tool_call # Look up all those tools - registered_tools: dict = get_tools() + registered_tools = get_tools() bad_tools = [tool for tool in tools if tool not in registered_tools] if bad_tools: raise click.ClickException( @@ -3580,14 +3580,14 @@ def _tools_from_code(code_or_path: str) -> List[Tool]: code_or_path = pathlib.Path(code_or_path).read_text() except FileNotFoundError: raise click.ClickException("File not found: {}".format(code_or_path)) - globals = {} + namespace: Dict[str, Any] = {} tools = [] try: - exec(code_or_path, globals) + exec(code_or_path, namespace) except SyntaxError as ex: raise click.ClickException("Error in --functions definition: {}".format(ex)) # Register all callables in the locals dict: - for name, value in globals.items(): + for name, value in namespace.items(): if callable(value) and not name.startswith("_"): tools.append(Tool.function(value)) return tools diff --git a/llm/models.py b/llm/models.py index ea496ed..34fa775 100644 --- a/llm/models.py +++ b/llm/models.py @@ -490,42 +490,6 @@ class _BaseResponse: def add_tool_call(self, tool_call: ToolCall): self._tool_calls.append(tool_call) - def execute_tool_calls( - self, - *, - before_call: Optional[Callable[[Tool, ToolCall], None]] = None, - after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None, - ) -> List[ToolResult]: - tool_results = [] - tools_by_name = {tool.name: tool for tool in self.prompt.tools} - # TODO: make this work async - for tool_call in self.tool_calls(): - tool = tools_by_name.get(tool_call.name) - if tool is None: - raise CancelToolCall("Unknown tool: {}".format(tool_call.name)) - if before_call: - # This may raise CancelToolCall: - before_call(tool, tool_call) - if not tool.implementation: - raise CancelToolCall( - "No implementation available for tool: {}".format(tool_call.name) - ) - try: - result = tool.implementation(**tool_call.arguments) - if not isinstance(result, str): - result = json.dumps(result, default=repr) - except Exception as ex: - result = f"Error: {ex}" - tool_result = ToolResult( - name=tool_call.name, - output=result, - tool_call_id=tool_call.tool_call_id, - ) - if after_call: - after_call(tool, tool_call, tool_result) - tool_results.append(tool_result) - return tool_results - def set_usage( self, *, @@ -815,11 +779,47 @@ class Response(_BaseResponse): def text_or_raise(self) -> str: return self.text() + def execute_tool_calls( + self, + *, + before_call: Optional[Callable[[Tool, ToolCall], None]] = None, + after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None, + ) -> List[ToolResult]: + tool_results = [] + tools_by_name = {tool.name: tool for tool in self.prompt.tools} + # TODO: make this work async + for tool_call in self.tool_calls(): + tool = tools_by_name.get(tool_call.name) + if tool is None: + raise CancelToolCall("Unknown tool: {}".format(tool_call.name)) + if before_call: + # This may raise CancelToolCall: + before_call(tool, tool_call) + if not tool.implementation: + raise CancelToolCall( + "No implementation available for tool: {}".format(tool_call.name) + ) + try: + result = tool.implementation(**tool_call.arguments) + if not isinstance(result, str): + result = json.dumps(result, default=repr) + except Exception as ex: + result = f"Error: {ex}" + tool_result = ToolResult( + name=tool_call.name, + output=result, + tool_call_id=tool_call.tool_call_id, + ) + if after_call: + after_call(tool, tool_call, tool_result) + tool_results.append(tool_result) + return tool_results + def tool_calls(self) -> List[ToolCall]: self._force() return self._tool_calls - def tool_calls_or_raise(self) -> str: + def tool_calls_or_raise(self) -> List[ToolCall]: return self.tool_calls() def json(self) -> Optional[Dict[str, Any]]: @@ -1122,7 +1122,7 @@ class _BaseChainResponse: conversation=self.conversation, ) else: - response = None + break def __iter__(self) -> Iterator[str]: for response in self.responses():