mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
Fixed remaining mypy problems, refs #1023
Refs https://github.com/simonw/llm/pull/996#issuecomment-2878191352
This commit is contained in:
parent
96f910bb30
commit
c81f0560e0
3 changed files with 62 additions and 52 deletions
|
|
@ -27,7 +27,7 @@ from .embeddings import Collection
|
|||
from .templates import Template
|
||||
from .plugins import pm, load_plugins
|
||||
import click
|
||||
from typing import Any, Dict, List, Optional, Callable, Union
|
||||
from typing import Any, Dict, List, Optional, Callable, Union, cast
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
|
|
@ -134,20 +134,30 @@ def get_fragment_loaders() -> Dict[
|
|||
return _get_loaders(pm.hook.register_fragment_loaders)
|
||||
|
||||
|
||||
def get_tools() -> Dict[str, List[Tool]]:
|
||||
def get_tools() -> Dict[str, Tool]:
|
||||
"""Get tools registered by plugins."""
|
||||
load_plugins()
|
||||
tools = {}
|
||||
tools: Dict[str, Tool] = {}
|
||||
|
||||
def register(tool_or_function: Callable[..., Any], name: Optional[str] = None):
|
||||
suffix = 0
|
||||
def register(
|
||||
tool_or_function: Union[Tool, Callable[..., Any]],
|
||||
name: Optional[str] = None,
|
||||
) -> None:
|
||||
# If they handed us a bare function, wrap it in a Tool
|
||||
if not isinstance(tool_or_function, Tool):
|
||||
tool_or_function = Tool.function(tool_or_function)
|
||||
prefix_to_try = tool_or_function.name
|
||||
while prefix_to_try in tools:
|
||||
|
||||
tool = cast(Tool, tool_or_function)
|
||||
prefix = tool.name
|
||||
suffix = 0
|
||||
candidate = prefix
|
||||
|
||||
# avoid name collisions
|
||||
while candidate in tools:
|
||||
suffix += 1
|
||||
prefix_to_try = f"{prefix}_{suffix}"
|
||||
tools[prefix_to_try] = tool_or_function
|
||||
candidate = f"{prefix}_{suffix}"
|
||||
|
||||
tools[candidate] = tool
|
||||
|
||||
pm.hook.register_tools(register=register)
|
||||
return tools
|
||||
|
|
|
|||
10
llm/cli.py
10
llm/cli.py
|
|
@ -73,7 +73,7 @@ import sqlite_utils
|
|||
from sqlite_utils.utils import rows_from_file, Format
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import cast, Optional, Iterable, List, Union, Tuple, Any
|
||||
from typing import cast, Dict, Optional, Iterable, List, Union, Tuple, Any
|
||||
import warnings
|
||||
import yaml
|
||||
|
||||
|
|
@ -800,7 +800,7 @@ def prompt(
|
|||
|
||||
kwargs["before_call"] = approve_tool_call
|
||||
# Look up all those tools
|
||||
registered_tools: dict = get_tools()
|
||||
registered_tools = get_tools()
|
||||
bad_tools = [tool for tool in tools if tool not in registered_tools]
|
||||
if bad_tools:
|
||||
raise click.ClickException(
|
||||
|
|
@ -3580,14 +3580,14 @@ def _tools_from_code(code_or_path: str) -> List[Tool]:
|
|||
code_or_path = pathlib.Path(code_or_path).read_text()
|
||||
except FileNotFoundError:
|
||||
raise click.ClickException("File not found: {}".format(code_or_path))
|
||||
globals = {}
|
||||
namespace: Dict[str, Any] = {}
|
||||
tools = []
|
||||
try:
|
||||
exec(code_or_path, globals)
|
||||
exec(code_or_path, namespace)
|
||||
except SyntaxError as ex:
|
||||
raise click.ClickException("Error in --functions definition: {}".format(ex))
|
||||
# Register all callables in the locals dict:
|
||||
for name, value in globals.items():
|
||||
for name, value in namespace.items():
|
||||
if callable(value) and not name.startswith("_"):
|
||||
tools.append(Tool.function(value))
|
||||
return tools
|
||||
|
|
|
|||
|
|
@ -490,42 +490,6 @@ class _BaseResponse:
|
|||
def add_tool_call(self, tool_call: ToolCall):
|
||||
self._tool_calls.append(tool_call)
|
||||
|
||||
def execute_tool_calls(
|
||||
self,
|
||||
*,
|
||||
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
|
||||
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
|
||||
) -> List[ToolResult]:
|
||||
tool_results = []
|
||||
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
|
||||
# TODO: make this work async
|
||||
for tool_call in self.tool_calls():
|
||||
tool = tools_by_name.get(tool_call.name)
|
||||
if tool is None:
|
||||
raise CancelToolCall("Unknown tool: {}".format(tool_call.name))
|
||||
if before_call:
|
||||
# This may raise CancelToolCall:
|
||||
before_call(tool, tool_call)
|
||||
if not tool.implementation:
|
||||
raise CancelToolCall(
|
||||
"No implementation available for tool: {}".format(tool_call.name)
|
||||
)
|
||||
try:
|
||||
result = tool.implementation(**tool_call.arguments)
|
||||
if not isinstance(result, str):
|
||||
result = json.dumps(result, default=repr)
|
||||
except Exception as ex:
|
||||
result = f"Error: {ex}"
|
||||
tool_result = ToolResult(
|
||||
name=tool_call.name,
|
||||
output=result,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
)
|
||||
if after_call:
|
||||
after_call(tool, tool_call, tool_result)
|
||||
tool_results.append(tool_result)
|
||||
return tool_results
|
||||
|
||||
def set_usage(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -815,11 +779,47 @@ class Response(_BaseResponse):
|
|||
def text_or_raise(self) -> str:
|
||||
return self.text()
|
||||
|
||||
def execute_tool_calls(
|
||||
self,
|
||||
*,
|
||||
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
|
||||
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
|
||||
) -> List[ToolResult]:
|
||||
tool_results = []
|
||||
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
|
||||
# TODO: make this work async
|
||||
for tool_call in self.tool_calls():
|
||||
tool = tools_by_name.get(tool_call.name)
|
||||
if tool is None:
|
||||
raise CancelToolCall("Unknown tool: {}".format(tool_call.name))
|
||||
if before_call:
|
||||
# This may raise CancelToolCall:
|
||||
before_call(tool, tool_call)
|
||||
if not tool.implementation:
|
||||
raise CancelToolCall(
|
||||
"No implementation available for tool: {}".format(tool_call.name)
|
||||
)
|
||||
try:
|
||||
result = tool.implementation(**tool_call.arguments)
|
||||
if not isinstance(result, str):
|
||||
result = json.dumps(result, default=repr)
|
||||
except Exception as ex:
|
||||
result = f"Error: {ex}"
|
||||
tool_result = ToolResult(
|
||||
name=tool_call.name,
|
||||
output=result,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
)
|
||||
if after_call:
|
||||
after_call(tool, tool_call, tool_result)
|
||||
tool_results.append(tool_result)
|
||||
return tool_results
|
||||
|
||||
def tool_calls(self) -> List[ToolCall]:
|
||||
self._force()
|
||||
return self._tool_calls
|
||||
|
||||
def tool_calls_or_raise(self) -> str:
|
||||
def tool_calls_or_raise(self) -> List[ToolCall]:
|
||||
return self.tool_calls()
|
||||
|
||||
def json(self) -> Optional[Dict[str, Any]]:
|
||||
|
|
@ -1122,7 +1122,7 @@ class _BaseChainResponse:
|
|||
conversation=self.conversation,
|
||||
)
|
||||
else:
|
||||
response = None
|
||||
break
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
for response in self.responses():
|
||||
|
|
|
|||
Loading…
Reference in a new issue