register_fragment_loaders() hook (#886)

* Docs and shape of register_fragment_loaders hook, refs #863
* Update docs for fragment loaders returning a list of FragmentString
* Support multiple fragments with same content, closes #888
* Call the pm.hook.register_fragment_loaders hook
* Test for register_fragment_loaders hook
* Rename FragmentString to Fragment

Closes #863
This commit is contained in:
Simon Willison 2025-04-06 17:03:34 -07:00 committed by GitHub
parent 3de33be74f
commit a571a4e948
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 161 additions and 31 deletions

View file

@ -106,4 +106,49 @@ def my_template_loader(template_path: str) -> llm.Template:
```
Consult the latest code in [llm/templates.py](https://github.com/simonw/llm/blob/main/llm/templates.py) for details of that `llm.Template` class.
The loader function should raise a `ValueError` if the template cannot be found or loaded correctly, providing a clear error message.
The loader function should raise a `ValueError` if the template cannot be found or loaded correctly, providing a clear error message.
(plugin-hooks-register-fragment-loaders)=
## register_fragment_loaders(register)
Plugins can register new fragment loaders using the `register_template_loaders` hook. These can then be used with the `llm -f prefix:argument` syntax.
The `prefix` specifies the loader. The `argument` will be passed to that registered callback..
The callback works in a very similar way to template loaders, but returns either a single `llm.Fragment` or a list of `llm.Fragment` objects.
The `llm.Fragment` constructor takes a required string argument (the content of the fragment) and an optional second `source` argument, which is a string that may be displayed as debug information. For files this is a path and for URLs it is a URL. Your plugin can use anything you like for the `source` value.
```python
import llm
@llm.hookimpl
def register_fragment_loaders(register):
register("my-fragments", my_fragment_loader)
def my_fragment_loader(argument: str) -> llm.Fragment:
try:
fragment = "Fragment content for {}".format(argument)
source = "my-fragments:{}".format(argument)
return llm.Fragment(fragment, source)
except Exception as ex:
# Raise a ValueError with a clear message if the fragment cannot be loaded
raise ValueError(
f"Fragment 'my-fragments:{argument}' could not be loaded: {str(ex)}"
)
# Or for the case where you want to return multiple fragments:
def my_fragment_loader(argument: str) -> list[llm.Fragment]:
return [
llm.Fragment("Fragment 1 content", "my-fragments:{argument}"),
llm.Fragment("Fragment 2 content", "my-fragments:{argument}"),
]
```
A plugin like this one can be called like so:
```bash
llm -f my-fragments:argument
```
If multiple fragments are returned they will be used as if the user passed multiple `-f X` arguments to the command.
Multiple fragments are useful for things like plugins that return every file in a directory. By giving each file its own fragment we can avoid having multiple copies of the full collection stored if only a single file has changed.

View file

@ -19,12 +19,12 @@ from .models import (
Prompt,
Response,
)
from .utils import schema_dsl
from .utils import schema_dsl, Fragment
from .embeddings import Collection
from .templates import Template
from .plugins import pm, load_plugins
import click
from typing import Dict, List, Optional, Callable
from typing import Dict, List, Optional, Callable, Union
import json
import os
import pathlib
@ -37,6 +37,7 @@ __all__ = [
"Attachment",
"Collection",
"Conversation",
"Fragment",
"get_async_model",
"get_key",
"get_model",
@ -98,7 +99,7 @@ def get_models_with_aliases() -> List["ModelWithAliases"]:
return model_aliases
def get_template_loaders() -> Dict[str, Callable[[str], Template]]:
def _get_loaders(hook_method) -> Dict[str, Callable]:
load_plugins()
loaders = {}
@ -110,10 +111,22 @@ def get_template_loaders() -> Dict[str, Callable[[str], Template]]:
prefix_to_try = f"{prefix}_{suffix}"
loaders[prefix_to_try] = loader
pm.hook.register_template_loaders(register=register)
hook_method(register=register)
return loaders
def get_template_loaders() -> Dict[str, Callable[[str], Template]]:
"""Get template loaders registered by plugins."""
return _get_loaders(pm.hook.register_template_loaders)
def get_fragment_loaders() -> (
Dict[str, Callable[[str], Union[Fragment, List[Fragment]]]]
):
"""Get fragment loaders registered by plugins."""
return _get_loaders(pm.hook.register_fragment_loaders)
def get_embedding_models_with_aliases() -> List["EmbeddingModelWithAliases"]:
model_aliases = []

View file

@ -12,6 +12,7 @@ from llm import (
AsyncResponse,
Collection,
Conversation,
Fragment,
Response,
Template,
UnknownModelError,
@ -24,6 +25,7 @@ from llm import (
get_embedding_model_aliases,
get_embedding_model,
get_plugins,
get_fragment_loaders,
get_template_loaders,
get_model,
get_model_aliases,
@ -42,7 +44,7 @@ from .utils import (
ensure_fragment,
extract_fenced_code_block,
find_unused_key,
FragmentString,
has_plugin_prefix,
make_schema_id,
maybe_fenced_code,
mimetype_from_path,
@ -88,7 +90,7 @@ def validate_fragment_alias(ctx, param, value):
def resolve_fragments(
db: sqlite_utils.Database, fragments: Iterable[str]
) -> List[FragmentString]:
) -> List[Fragment]:
"""
Resolve fragments into a list of (content, source) tuples
"""
@ -109,28 +111,41 @@ def resolve_fragments(
return row["content"], row["source"]
return None, None
# These can be URLs or paths
# These can be URLs or paths or plugin references
resolved = []
for fragment in fragments:
if fragment.startswith("http://") or fragment.startswith("https://"):
client = httpx.Client(follow_redirects=True, max_redirects=3)
response = client.get(fragment)
response.raise_for_status()
resolved.append(FragmentString(response.text, fragment))
resolved.append(Fragment(response.text, fragment))
elif fragment == "-":
resolved.append(FragmentString(sys.stdin.read(), "-"))
resolved.append(Fragment(sys.stdin.read(), "-"))
elif has_plugin_prefix(fragment):
prefix, rest = fragment.split(":", 1)
loaders = get_fragment_loaders()
if prefix not in loaders:
raise FragmentNotFound("Unknown fragment prefix: {}".format(prefix))
loader = loaders[prefix]
try:
result = loader(rest)
if not isinstance(result, list):
result = [result]
resolved.extend(result)
except Exception as ex:
raise FragmentNotFound(
"Could not load fragment {}: {}".format(fragment, ex)
)
else:
# Try from the DB
content, source = _load_by_alias(fragment)
if content is not None:
resolved.append(FragmentString(content, source))
resolved.append(Fragment(content, source))
else:
# Now try path
path = pathlib.Path(fragment)
if path.exists():
resolved.append(
FragmentString(path.read_text(), str(path.resolve()))
)
resolved.append(Fragment(path.read_text(), str(path.resolve())))
else:
raise FragmentNotFound(f"Fragment '{fragment}' not found")
return resolved
@ -3113,7 +3128,7 @@ def load_template(name: str) -> Template:
raise LoadTemplateError("Could not load template {}: {}".format(name, ex))
return _parse_yaml_template(name, response.text)
if ":" in name:
if has_plugin_prefix(name):
prefix, rest = name.split(":", 1)
loaders = get_template_loaders()
if prefix not in loaders:

View file

@ -23,3 +23,8 @@ def register_embedding_models(register):
@hookspec
def register_template_loaders(register):
"Register additional template loaders with prefixes"
@hookspec
def register_fragment_loaders(register):
"Register additional fragment loaders with prefixes"

View file

@ -303,6 +303,7 @@ def m015_fragments_tables(db):
pk=("response_id", "fragment_id"),
)
@migration
def m016_fragments_table_pks(db):
# The same fragment can be attached to a response multiple times

View file

@ -14,21 +14,15 @@ MIME_TYPE_FIXES = {
}
class FragmentString(str):
def __new__(cls, content, source):
# We need to use __new__ since str is immutable
instance = super().__new__(cls, content)
return instance
class Fragment(str):
def __new__(cls, content, *args, **kwargs):
# For immutable classes like str, __new__ creates the string object
return super().__new__(cls, content)
def __init__(self, content, source):
def __init__(self, content, source=""):
# Initialize our custom attributes
self.source = source
def __str__(self):
return super().__str__()
def __repr__(self):
return super().__repr__()
def id(self):
return hashlib.sha256(self.encode("utf-8")).hexdigest()
@ -465,7 +459,7 @@ def ensure_fragment(db, content):
"""
hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
source = None
if isinstance(content, FragmentString):
if isinstance(content, Fragment):
source = content.source
with db.conn:
db.execute(sql, {"hash": hash, "content": content, "source": source})
@ -501,3 +495,11 @@ def maybe_fenced_code(content: str) -> str:
+ "`" * num_backticks
)
return content
_plugin_prefix_re = re.compile(r"^[a-zA-Z0-9_-]+:")
def has_plugin_prefix(value: str) -> bool:
"Check if value starts with alphanumeric prefix followed by a colon"
return bool(_plugin_prefix_re.match(value))

View file

@ -1,7 +1,7 @@
from click.testing import CliRunner
from llm.cli import cli
from llm.migrations import migrate
from llm.utils import FragmentString
from llm import Fragment
from ulid import ULID
import datetime
import json
@ -440,7 +440,7 @@ def fragments_fixture(user_path):
# Create fragments
for i in range(1, 6):
content = f"This is fragment {i}" * (100 if i == 5 else 1)
fragment = FragmentString(content, "fragment")
fragment = Fragment(content, "fragment")
db["fragments"].insert(
{
"id": i,

View file

@ -2,7 +2,7 @@ from click.testing import CliRunner
import click
import importlib
import llm
from llm import cli, hookimpl, plugins, get_template_loaders
from llm import cli, hookimpl, plugins, get_template_loaders, get_fragment_loaders
def test_register_commands():
@ -88,3 +88,52 @@ def test_register_template_loaders():
finally:
plugins.pm.unregister(name="TemplateLoadersPlugin")
assert get_template_loaders() == {}
def test_register_fragment_loaders(logs_db):
assert get_fragment_loaders() == {}
def single_fragment(argument):
return llm.Fragment("single", "single")
def three_fragments(argument):
return [
llm.Fragment(f"one:{argument}", "one"),
llm.Fragment(f"two:{argument}", "two"),
llm.Fragment(f"three:{argument}", "three"),
]
class FragmentLoadersPlugin:
__name__ = "FragmentLoadersPlugin"
@hookimpl
def register_fragment_loaders(self, register):
register("single", single_fragment)
register("three", three_fragments)
try:
plugins.pm.register(FragmentLoadersPlugin(), name="FragmentLoadersPlugin")
loaders = get_fragment_loaders()
assert loaders == {
"single": single_fragment,
"three": three_fragments,
}
# Test the CLI command
runner = CliRunner()
result = runner.invoke(
cli.cli, ["-m", "echo", "-f", "three:x"], catch_exceptions=False
)
assert result.exit_code == 0
expected = "prompt:\n" "one:x\n" "two:x\n" "three:x\n"
assert expected in result.output
finally:
plugins.pm.unregister(name="FragmentLoadersPlugin")
assert get_fragment_loaders() == {}
# Let's check the database
assert list(logs_db.query("select content, source from fragments")) == [
{"content": "one:x", "source": "one"},
{"content": "two:x", "source": "two"},
{"content": "three:x", "source": "three"},
]