mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
chain_limit/before_call/after_call for conversations
* chain_limit/before_call/after_call for conversations, closes #1088 * Docs for before_call/after_call including for model.conversation
This commit is contained in:
parent
b5d1c5ee90
commit
ed64fc3362
4 changed files with 155 additions and 51 deletions
|
|
@ -148,6 +148,46 @@ for response in chain.responses():
|
|||
print(chunk, end="", flush=True)
|
||||
```
|
||||
|
||||
(python-api-tools-debug-hooks)=
|
||||
|
||||
#### Tool debugging hooks
|
||||
|
||||
Pass a function to the `before_call=` parameter of `model.chain()` to have that called before every tool call is executed. You can raise `llm.CancelToolCall()` to cancel that tool call.
|
||||
|
||||
The method signature is `def before_call(tool: llm.Tool, tool_call: llm.ToolCall)`. Here's an example:
|
||||
```python
|
||||
import llm
|
||||
|
||||
def upper(text: str) -> str:
|
||||
"Convert text to uppercase."
|
||||
return text.upper()
|
||||
|
||||
def before_call(tool: llm.Tool, tool_call: llm.ToolCall):
|
||||
print(f"About to call tool {tool.name} with arguments {tool_call.arguments}")
|
||||
if tool.name == "upper" and "bad" in repr(tool_call.arguments):
|
||||
raise llm.CancelToolCall("Not allowed to call upper on text containing 'bad'")
|
||||
|
||||
model = llm.get_model("gpt-4.1-mini")
|
||||
response = model.chain(
|
||||
"Convert panda to upper and badger to upper",
|
||||
tools=[upper],
|
||||
before_call=before_call,
|
||||
)
|
||||
print(response.text())
|
||||
```
|
||||
The `after_call=` parameter can be used to run a logging function after each tool call has been executed. The method signature is `def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult)`. This continues the previous example:
|
||||
```python
|
||||
def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult):
|
||||
print(f"Tool {tool.name} called with arguments {tool_call.arguments} returned {tool_result.output}")
|
||||
|
||||
response = model.chain(
|
||||
"Convert panda to upper and badger to upper",
|
||||
tools=[upper],
|
||||
after_call=after_call,
|
||||
)
|
||||
print(response.text())
|
||||
```
|
||||
|
||||
(python-api-tools-attachments)=
|
||||
|
||||
#### Tools can return attachments
|
||||
|
|
@ -575,6 +615,8 @@ print(conversation.chain(
|
|||
"Same with pangolin"
|
||||
).text())
|
||||
```
|
||||
The `before_call=` and `after_call=` parameters {ref}`described above <python-api-tools-debug-hooks>` can be passed directly to the `model.conversation()` method to set those options for all chained prompts in that conversation.
|
||||
|
||||
|
||||
(python-api-listing-models)=
|
||||
|
||||
|
|
|
|||
|
|
@ -1080,6 +1080,11 @@ def chat(
|
|||
# Ensure it can see the API key
|
||||
conversation.model = model
|
||||
|
||||
if tools_debug:
|
||||
conversation.after_call = _debug_tool_call
|
||||
if tools_approve:
|
||||
conversation.before_call = _approve_tool_call
|
||||
|
||||
# Validate options
|
||||
validated_options = get_model_options(model.model_id)
|
||||
if options:
|
||||
|
|
@ -1100,10 +1105,6 @@ def chat(
|
|||
|
||||
if tool_functions:
|
||||
kwargs["chain_limit"] = chain_limit
|
||||
if tools_debug:
|
||||
kwargs["after_call"] = _debug_tool_call
|
||||
if tools_approve:
|
||||
kwargs["before_call"] = _approve_tool_call
|
||||
kwargs["tools"] = tool_functions
|
||||
|
||||
should_stream = model.can_stream and not no_stream
|
||||
|
|
|
|||
113
llm/models.py
113
llm/models.py
|
|
@ -259,9 +259,6 @@ class Toolbox:
|
|||
return methods
|
||||
|
||||
|
||||
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
name: str
|
||||
|
|
@ -286,6 +283,13 @@ class ToolOutput:
|
|||
attachments: List[Attachment] = field(default_factory=list)
|
||||
|
||||
|
||||
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
|
||||
BeforeCallSync = Callable[[Tool, ToolCall], None]
|
||||
AfterCallSync = Callable[[Tool, ToolCall, ToolResult], None]
|
||||
BeforeCallAsync = Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
AfterCallAsync = Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
|
||||
|
||||
class CancelToolCall(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -368,6 +372,7 @@ class _BaseConversation:
|
|||
name: Optional[str] = None
|
||||
responses: List["_BaseResponse"] = field(default_factory=list)
|
||||
tools: Optional[List[Tool]] = None
|
||||
chain_limit: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
|
|
@ -377,6 +382,9 @@ class _BaseConversation:
|
|||
|
||||
@dataclass
|
||||
class Conversation(_BaseConversation):
|
||||
before_call: Optional[BeforeCallSync] = None
|
||||
after_call: Optional[AfterCallSync] = None
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
prompt: Optional[str] = None,
|
||||
|
|
@ -424,8 +432,8 @@ class Conversation(_BaseConversation):
|
|||
tools: Optional[List[Tool]] = None,
|
||||
tool_results: Optional[List[ToolResult]] = None,
|
||||
chain_limit: Optional[int] = None,
|
||||
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
|
||||
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
|
||||
before_call: Optional[BeforeCallSync] = None,
|
||||
after_call: Optional[AfterCallSync] = None,
|
||||
key: Optional[str] = None,
|
||||
options: Optional[dict] = None,
|
||||
) -> "ChainResponse":
|
||||
|
|
@ -447,9 +455,9 @@ class Conversation(_BaseConversation):
|
|||
stream=stream,
|
||||
conversation=self,
|
||||
key=key,
|
||||
before_call=before_call,
|
||||
after_call=after_call,
|
||||
chain_limit=chain_limit,
|
||||
before_call=before_call or self.before_call,
|
||||
after_call=after_call or self.after_call,
|
||||
chain_limit=chain_limit if chain_limit is not None else self.chain_limit,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -470,6 +478,9 @@ class Conversation(_BaseConversation):
|
|||
|
||||
@dataclass
|
||||
class AsyncConversation(_BaseConversation):
|
||||
before_call: Optional[BeforeCallAsync] = None
|
||||
after_call: Optional[AfterCallAsync] = None
|
||||
|
||||
def chain(
|
||||
self,
|
||||
prompt: Optional[str] = None,
|
||||
|
|
@ -483,12 +494,8 @@ class AsyncConversation(_BaseConversation):
|
|||
tools: Optional[List[Tool]] = None,
|
||||
tool_results: Optional[List[ToolResult]] = None,
|
||||
chain_limit: Optional[int] = None,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
before_call: Optional[BeforeCallAsync] = None,
|
||||
after_call: Optional[AfterCallAsync] = None,
|
||||
key: Optional[str] = None,
|
||||
options: Optional[dict] = None,
|
||||
) -> "AsyncChainResponse":
|
||||
|
|
@ -510,9 +517,9 @@ class AsyncConversation(_BaseConversation):
|
|||
stream=stream,
|
||||
conversation=self,
|
||||
key=key,
|
||||
before_call=before_call,
|
||||
after_call=after_call,
|
||||
chain_limit=chain_limit,
|
||||
before_call=before_call or self.before_call,
|
||||
after_call=after_call or self.after_call,
|
||||
chain_limit=chain_limit if chain_limit is not None else self.chain_limit,
|
||||
)
|
||||
|
||||
def prompt(
|
||||
|
|
@ -975,12 +982,8 @@ class Response(_BaseResponse):
|
|||
def execute_tool_calls(
|
||||
self,
|
||||
*,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
before_call: Optional[BeforeCallSync] = None,
|
||||
after_call: Optional[AfterCallSync] = None,
|
||||
) -> List[ToolResult]:
|
||||
tool_results = []
|
||||
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
|
||||
|
|
@ -1147,12 +1150,8 @@ class AsyncResponse(_BaseResponse):
|
|||
async def execute_tool_calls(
|
||||
self,
|
||||
*,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
before_call: Optional[BeforeCallAsync] = None,
|
||||
after_call: Optional[AfterCallAsync] = None,
|
||||
) -> List[ToolResult]:
|
||||
tool_calls_list = await self.tool_calls()
|
||||
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
|
||||
|
|
@ -1437,12 +1436,8 @@ class _BaseChainResponse:
|
|||
conversation: _BaseConversation,
|
||||
key: Optional[str] = None,
|
||||
chain_limit: Optional[int] = 10,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
before_call: Optional[Union[BeforeCallSync, BeforeCallAsync]] = None,
|
||||
after_call: Optional[Union[AfterCallSync, AfterCallAsync]] = None,
|
||||
):
|
||||
self.prompt = prompt
|
||||
self.model = model
|
||||
|
|
@ -1467,6 +1462,8 @@ class _BaseChainResponse:
|
|||
|
||||
class ChainResponse(_BaseChainResponse):
|
||||
_responses: List["Response"]
|
||||
before_call: Optional[BeforeCallSync] = None
|
||||
after_call: Optional[AfterCallSync] = None
|
||||
|
||||
def responses(self) -> Iterator[Response]:
|
||||
prompt = self.prompt
|
||||
|
|
@ -1521,6 +1518,8 @@ class ChainResponse(_BaseChainResponse):
|
|||
|
||||
class AsyncChainResponse(_BaseChainResponse):
|
||||
_responses: List["AsyncResponse"]
|
||||
before_call: Optional[BeforeCallAsync] = None
|
||||
after_call: Optional[AfterCallAsync] = None
|
||||
|
||||
async def responses(self) -> AsyncIterator[AsyncResponse]:
|
||||
prompt = self.prompt
|
||||
|
|
@ -1656,8 +1655,20 @@ class _BaseModel(ABC, _get_key_mixin):
|
|||
|
||||
|
||||
class _Model(_BaseModel):
|
||||
def conversation(self, tools: Optional[List[Tool]] = None) -> Conversation:
|
||||
return Conversation(model=self, tools=tools)
|
||||
def conversation(
|
||||
self,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
before_call: Optional[BeforeCallSync] = None,
|
||||
after_call: Optional[AfterCallSync] = None,
|
||||
chain_limit: Optional[int] = None,
|
||||
) -> Conversation:
|
||||
return Conversation(
|
||||
model=self,
|
||||
tools=tools,
|
||||
before_call=before_call,
|
||||
after_call=after_call,
|
||||
chain_limit=chain_limit,
|
||||
)
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
|
|
@ -1705,8 +1716,8 @@ class _Model(_BaseModel):
|
|||
schema: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_results: Optional[List[ToolResult]] = None,
|
||||
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
|
||||
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
|
||||
before_call: Optional[BeforeCallSync] = None,
|
||||
after_call: Optional[AfterCallSync] = None,
|
||||
key: Optional[str] = None,
|
||||
options: Optional[dict] = None,
|
||||
) -> ChainResponse:
|
||||
|
|
@ -1753,8 +1764,20 @@ class KeyModel(_Model):
|
|||
|
||||
|
||||
class _AsyncModel(_BaseModel):
|
||||
def conversation(self, tools: Optional[List[Tool]] = None) -> AsyncConversation:
|
||||
return AsyncConversation(model=self, tools=tools)
|
||||
def conversation(
|
||||
self,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
before_call: Optional[BeforeCallAsync] = None,
|
||||
after_call: Optional[AfterCallAsync] = None,
|
||||
chain_limit: Optional[int] = None,
|
||||
) -> AsyncConversation:
|
||||
return AsyncConversation(
|
||||
model=self,
|
||||
tools=tools,
|
||||
before_call=before_call,
|
||||
after_call=after_call,
|
||||
chain_limit=chain_limit,
|
||||
)
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
|
|
@ -1802,12 +1825,8 @@ class _AsyncModel(_BaseModel):
|
|||
schema: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_results: Optional[List[ToolResult]] = None,
|
||||
before_call: Optional[
|
||||
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
after_call: Optional[
|
||||
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
|
||||
] = None,
|
||||
before_call: Optional[BeforeCallAsync] = None,
|
||||
after_call: Optional[AfterCallAsync] = None,
|
||||
key: Optional[str] = None,
|
||||
options: Optional[dict] = None,
|
||||
) -> AsyncChainResponse:
|
||||
|
|
|
|||
|
|
@ -343,3 +343,45 @@ async def test_async_tool_returning_attachment():
|
|||
output = await chain_response.text()
|
||||
assert '"type": "image/png"' in output
|
||||
assert '"output": "Output"' in output
|
||||
|
||||
|
||||
def test_tool_conversation_settings():
|
||||
model = llm.get_model("echo")
|
||||
before_collected = []
|
||||
after_collected = []
|
||||
|
||||
def before(*args):
|
||||
before_collected.append(args)
|
||||
|
||||
def after(*args):
|
||||
after_collected.append(args)
|
||||
|
||||
conversation = model.conversation(
|
||||
tools=[llm_time], before_call=before, after_call=after
|
||||
)
|
||||
# Run two things
|
||||
conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
|
||||
conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
|
||||
assert len(before_collected) == 2
|
||||
assert len(after_collected) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_conversation_settings_async():
|
||||
model = llm.get_async_model("echo")
|
||||
before_collected = []
|
||||
after_collected = []
|
||||
|
||||
async def before(*args):
|
||||
before_collected.append(args)
|
||||
|
||||
async def after(*args):
|
||||
after_collected.append(args)
|
||||
|
||||
conversation = model.conversation(
|
||||
tools=[llm_time], before_call=before, after_call=after
|
||||
)
|
||||
await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
|
||||
await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
|
||||
assert len(before_collected) == 2
|
||||
assert len(after_collected) == 2
|
||||
|
|
|
|||
Loading…
Reference in a new issue