First working prototype of llm -T toolname

Refs https://github.com/simonw/llm/issues/990#issuecomment-2870630946
This commit is contained in:
Simon Willison 2025-05-11 20:09:31 -07:00
parent 290c0f13f3
commit 4ae6d45f3a
2 changed files with 37 additions and 3 deletions

View file

@ -25,6 +25,7 @@ from llm import (
get_embedding_model_aliases,
get_embedding_model,
get_plugins,
get_tools,
get_fragment_loaders,
get_template_loaders,
get_model,
@ -36,7 +37,7 @@ from llm import (
set_default_embedding_model,
remove_alias,
)
from llm.models import _BaseConversation
from llm.models import _BaseConversation, ChainResponse
from .migrations import migrate
from .plugins import pm, load_plugins
@ -331,6 +332,13 @@ def cli():
callback=attachment_types_callback,
help="\b\nAttachment with explicit mimetype,\n--at image.jpg image/jpeg",
)
@click.option(
"tools",
"-T",
"--tool",
multiple=True,
help="Name of a tool to make available to the model",
)
@click.option(
"options",
"-o",
@ -403,6 +411,7 @@ def prompt(
queries,
attachments,
attachment_types,
tools,
options,
schema_input,
schema_multi,
@ -659,6 +668,9 @@ def prompt(
except UnknownModelError as ex:
raise click.ClickException(ex)
if conversation is None and tools:
conversation = model.conversation()
if conversation:
# To ensure it can see the key
conversation.model = model
@ -718,6 +730,21 @@ def prompt(
if conversation:
prompt_method = conversation.prompt
if tools:
prompt_method = lambda *args, **kwargs: conversation.chain(
*args, **kwargs
).details()
# prompt_method = conversation.chain
# Look up all those tools
registered_tools: dict = get_tools()
bad_tools = [tool for tool in tools if tool not in registered_tools]
if bad_tools:
raise click.ClickException(
"Tool(s) {} not found. Available tools: {}".format(
", ".join(bad_tools), ", ".join(registered_tools.keys())
)
)
kwargs["tools"] = [registered_tools[tool] for tool in tools]
try:
if async_:
@ -786,6 +813,10 @@ def prompt(
raise
raise click.ClickException(str(ex))
if not isinstance(response, (Response, AsyncResponse)):
# Terminate early, logging for tool streaming mechanism not implemented yet
return
if isinstance(response, AsyncResponse):
response = asyncio.run(response.to_sync_response())

View file

@ -313,6 +313,7 @@ class Conversation(_BaseConversation):
tools: Optional[List[Tool]] = None,
tool_results: Optional[List[ToolResult]] = None,
details: bool = False,
key: Optional[str] = None,
**options,
):
self.model._validate_attachments(attachments)
@ -332,7 +333,7 @@ class Conversation(_BaseConversation):
model=self.model,
stream=stream,
conversation=self,
key=options.pop("key", None),
key=key,
details=details,
)
@ -954,7 +955,9 @@ class _BaseChainResponse:
if details:
yield "\nTool call requested: {}({})\n".format(
tool_call.name,
", ".join(f"{k}={v}" for k, v in tool_call.arguments.items()),
", ".join(
f"{k}={repr(v)}" for k, v in tool_call.arguments.items()
),
)
implementation = tools_by_name.get(tool_calls[0].name)
if not implementation: