mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-03 19:34:44 +00:00
register_fragment_loaders() hook (#886)
* Docs and shape of register_fragment_loaders hook, refs #863 * Update docs for fragment loaders returning a list of FragmentString * Support multiple fragments with same content, closes #888 * Call the pm.hook.register_fragment_loaders hook * Test for register_fragment_loaders hook * Rename FragmentString to Fragment Closes #863
This commit is contained in:
parent
3de33be74f
commit
a571a4e948
8 changed files with 161 additions and 31 deletions
|
|
@ -106,4 +106,49 @@ def my_template_loader(template_path: str) -> llm.Template:
|
|||
```
|
||||
Consult the latest code in [llm/templates.py](https://github.com/simonw/llm/blob/main/llm/templates.py) for details of that `llm.Template` class.
|
||||
|
||||
The loader function should raise a `ValueError` if the template cannot be found or loaded correctly, providing a clear error message.
|
||||
The loader function should raise a `ValueError` if the template cannot be found or loaded correctly, providing a clear error message.
|
||||
|
||||
(plugin-hooks-register-fragment-loaders)=
|
||||
## register_fragment_loaders(register)
|
||||
|
||||
Plugins can register new fragment loaders using the `register_fragment_loaders` hook. These can then be used with the `llm -f prefix:argument` syntax.
|
||||
|
||||
The `prefix` specifies the loader. The `argument` will be passed to that registered callback.
|
||||
|
||||
The callback works in a very similar way to template loaders, but returns either a single `llm.Fragment` or a list of `llm.Fragment` objects.
|
||||
|
||||
The `llm.Fragment` constructor takes a required string argument (the content of the fragment) and an optional second `source` argument, which is a string that may be displayed as debug information. For files this is a path and for URLs it is a URL. Your plugin can use anything you like for the `source` value.
|
||||
|
||||
```python
|
||||
import llm
|
||||
|
||||
@llm.hookimpl
def register_fragment_loaders(register):
    """Register ``my_fragment_loader`` under the ``my-fragments`` prefix."""
    register("my-fragments", my_fragment_loader)
|
||||
|
||||
|
||||
def my_fragment_loader(argument: str) -> llm.Fragment:
    """Load the single fragment identified by ``argument``.

    Returns an ``llm.Fragment`` whose content is built from ``argument`` and
    whose ``source`` is the ``my-fragments:<argument>`` reference.

    Raises:
        ValueError: if the fragment cannot be loaded. The original exception
            is chained as the cause so the traceback is not lost.
    """
    try:
        fragment = "Fragment content for {}".format(argument)
        source = "my-fragments:{}".format(argument)
        return llm.Fragment(fragment, source)
    except Exception as ex:
        # Raise a ValueError with a clear message if the fragment cannot be loaded
        raise ValueError(
            f"Fragment 'my-fragments:{argument}' could not be loaded: {ex}"
        ) from ex
|
||||
|
||||
# Or for the case where you want to return multiple fragments:
|
||||
def my_fragment_loader(argument: str) -> list[llm.Fragment]:
    """Return several fragments for a single ``my-fragments:argument`` reference."""
    # The f prefix matters: without it the source would contain the literal
    # text "{argument}" rather than the value that was passed in.
    source = f"my-fragments:{argument}"
    return [
        llm.Fragment("Fragment 1 content", source),
        llm.Fragment("Fragment 2 content", source),
    ]
|
||||
```
|
||||
A plugin like this one can be called like so:
|
||||
```bash
|
||||
llm -f my-fragments:argument
|
||||
```
|
||||
If multiple fragments are returned they will be used as if the user passed multiple `-f X` arguments to the command.
|
||||
|
||||
Multiple fragments are useful for things like plugins that return every file in a directory. By giving each file its own fragment we can avoid having multiple copies of the full collection stored if only a single file has changed.
|
||||
|
|
|
|||
|
|
@ -19,12 +19,12 @@ from .models import (
|
|||
Prompt,
|
||||
Response,
|
||||
)
|
||||
from .utils import schema_dsl
|
||||
from .utils import schema_dsl, Fragment
|
||||
from .embeddings import Collection
|
||||
from .templates import Template
|
||||
from .plugins import pm, load_plugins
|
||||
import click
|
||||
from typing import Dict, List, Optional, Callable
|
||||
from typing import Dict, List, Optional, Callable, Union
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
|
|
@ -37,6 +37,7 @@ __all__ = [
|
|||
"Attachment",
|
||||
"Collection",
|
||||
"Conversation",
|
||||
"Fragment",
|
||||
"get_async_model",
|
||||
"get_key",
|
||||
"get_model",
|
||||
|
|
@ -98,7 +99,7 @@ def get_models_with_aliases() -> List["ModelWithAliases"]:
|
|||
return model_aliases
|
||||
|
||||
|
||||
def get_template_loaders() -> Dict[str, Callable[[str], Template]]:
|
||||
def _get_loaders(hook_method) -> Dict[str, Callable]:
|
||||
load_plugins()
|
||||
loaders = {}
|
||||
|
||||
|
|
@ -110,10 +111,22 @@ def get_template_loaders() -> Dict[str, Callable[[str], Template]]:
|
|||
prefix_to_try = f"{prefix}_{suffix}"
|
||||
loaders[prefix_to_try] = loader
|
||||
|
||||
pm.hook.register_template_loaders(register=register)
|
||||
hook_method(register=register)
|
||||
return loaders
|
||||
|
||||
|
||||
def get_template_loaders() -> Dict[str, Callable[[str], Template]]:
    """Return the template loaders that plugins have registered, keyed by prefix."""
    template_hook = pm.hook.register_template_loaders
    return _get_loaders(template_hook)
|
||||
|
||||
|
||||
def get_fragment_loaders() -> (
    Dict[str, Callable[[str], Union[Fragment, List[Fragment]]]]
):
    """Return the fragment loaders that plugins have registered, keyed by prefix.

    Each loader takes the argument string and returns either one ``Fragment``
    or a list of ``Fragment`` objects.
    """
    fragment_hook = pm.hook.register_fragment_loaders
    return _get_loaders(fragment_hook)
|
||||
|
||||
|
||||
def get_embedding_models_with_aliases() -> List["EmbeddingModelWithAliases"]:
|
||||
model_aliases = []
|
||||
|
||||
|
|
|
|||
35
llm/cli.py
35
llm/cli.py
|
|
@ -12,6 +12,7 @@ from llm import (
|
|||
AsyncResponse,
|
||||
Collection,
|
||||
Conversation,
|
||||
Fragment,
|
||||
Response,
|
||||
Template,
|
||||
UnknownModelError,
|
||||
|
|
@ -24,6 +25,7 @@ from llm import (
|
|||
get_embedding_model_aliases,
|
||||
get_embedding_model,
|
||||
get_plugins,
|
||||
get_fragment_loaders,
|
||||
get_template_loaders,
|
||||
get_model,
|
||||
get_model_aliases,
|
||||
|
|
@ -42,7 +44,7 @@ from .utils import (
|
|||
ensure_fragment,
|
||||
extract_fenced_code_block,
|
||||
find_unused_key,
|
||||
FragmentString,
|
||||
has_plugin_prefix,
|
||||
make_schema_id,
|
||||
maybe_fenced_code,
|
||||
mimetype_from_path,
|
||||
|
|
@ -88,7 +90,7 @@ def validate_fragment_alias(ctx, param, value):
|
|||
|
||||
def resolve_fragments(
|
||||
db: sqlite_utils.Database, fragments: Iterable[str]
|
||||
) -> List[FragmentString]:
|
||||
) -> List[Fragment]:
|
||||
"""
|
||||
Resolve fragments into a list of (content, source) tuples
|
||||
"""
|
||||
|
|
@ -109,28 +111,41 @@ def resolve_fragments(
|
|||
return row["content"], row["source"]
|
||||
return None, None
|
||||
|
||||
# These can be URLs or paths
|
||||
# These can be URLs or paths or plugin references
|
||||
resolved = []
|
||||
for fragment in fragments:
|
||||
if fragment.startswith("http://") or fragment.startswith("https://"):
|
||||
client = httpx.Client(follow_redirects=True, max_redirects=3)
|
||||
response = client.get(fragment)
|
||||
response.raise_for_status()
|
||||
resolved.append(FragmentString(response.text, fragment))
|
||||
resolved.append(Fragment(response.text, fragment))
|
||||
elif fragment == "-":
|
||||
resolved.append(FragmentString(sys.stdin.read(), "-"))
|
||||
resolved.append(Fragment(sys.stdin.read(), "-"))
|
||||
elif has_plugin_prefix(fragment):
|
||||
prefix, rest = fragment.split(":", 1)
|
||||
loaders = get_fragment_loaders()
|
||||
if prefix not in loaders:
|
||||
raise FragmentNotFound("Unknown fragment prefix: {}".format(prefix))
|
||||
loader = loaders[prefix]
|
||||
try:
|
||||
result = loader(rest)
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
resolved.extend(result)
|
||||
except Exception as ex:
|
||||
raise FragmentNotFound(
|
||||
"Could not load fragment {}: {}".format(fragment, ex)
|
||||
)
|
||||
else:
|
||||
# Try from the DB
|
||||
content, source = _load_by_alias(fragment)
|
||||
if content is not None:
|
||||
resolved.append(FragmentString(content, source))
|
||||
resolved.append(Fragment(content, source))
|
||||
else:
|
||||
# Now try path
|
||||
path = pathlib.Path(fragment)
|
||||
if path.exists():
|
||||
resolved.append(
|
||||
FragmentString(path.read_text(), str(path.resolve()))
|
||||
)
|
||||
resolved.append(Fragment(path.read_text(), str(path.resolve())))
|
||||
else:
|
||||
raise FragmentNotFound(f"Fragment '{fragment}' not found")
|
||||
return resolved
|
||||
|
|
@ -3113,7 +3128,7 @@ def load_template(name: str) -> Template:
|
|||
raise LoadTemplateError("Could not load template {}: {}".format(name, ex))
|
||||
return _parse_yaml_template(name, response.text)
|
||||
|
||||
if ":" in name:
|
||||
if has_plugin_prefix(name):
|
||||
prefix, rest = name.split(":", 1)
|
||||
loaders = get_template_loaders()
|
||||
if prefix not in loaders:
|
||||
|
|
|
|||
|
|
@ -23,3 +23,8 @@ def register_embedding_models(register):
|
|||
@hookspec
def register_template_loaders(register):
    """Register additional template loaders with prefixes"""
|
||||
|
||||
|
||||
@hookspec
def register_fragment_loaders(register):
    """Register additional fragment loaders with prefixes"""
|
||||
|
|
|
|||
|
|
@ -303,6 +303,7 @@ def m015_fragments_tables(db):
|
|||
pk=("response_id", "fragment_id"),
|
||||
)
|
||||
|
||||
|
||||
@migration
|
||||
def m016_fragments_table_pks(db):
|
||||
# The same fragment can be attached to a response multiple times
|
||||
|
|
|
|||
28
llm/utils.py
28
llm/utils.py
|
|
@ -14,21 +14,15 @@ MIME_TYPE_FIXES = {
|
|||
}
|
||||
|
||||
|
||||
class FragmentString(str):
|
||||
def __new__(cls, content, source):
|
||||
# We need to use __new__ since str is immutable
|
||||
instance = super().__new__(cls, content)
|
||||
return instance
|
||||
class Fragment(str):
|
||||
def __new__(cls, content, *args, **kwargs):
|
||||
# For immutable classes like str, __new__ creates the string object
|
||||
return super().__new__(cls, content)
|
||||
|
||||
def __init__(self, content, source):
|
||||
def __init__(self, content, source=""):
|
||||
# Initialize our custom attributes
|
||||
self.source = source
|
||||
|
||||
def __str__(self):
|
||||
return super().__str__()
|
||||
|
||||
def __repr__(self):
|
||||
return super().__repr__()
|
||||
|
||||
def id(self):
|
||||
return hashlib.sha256(self.encode("utf-8")).hexdigest()
|
||||
|
||||
|
|
@ -465,7 +459,7 @@ def ensure_fragment(db, content):
|
|||
"""
|
||||
hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
source = None
|
||||
if isinstance(content, FragmentString):
|
||||
if isinstance(content, Fragment):
|
||||
source = content.source
|
||||
with db.conn:
|
||||
db.execute(sql, {"hash": hash, "content": content, "source": source})
|
||||
|
|
@ -501,3 +495,11 @@ def maybe_fenced_code(content: str) -> str:
|
|||
+ "`" * num_backticks
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
# A plugin prefix is one or more ASCII letters, digits, underscores or hyphens
# at the very start of the value, immediately followed by a colon.
_plugin_prefix_re = re.compile(r"^[a-zA-Z0-9_-]+:")


def has_plugin_prefix(value: str) -> bool:
    """Return True if ``value`` starts with a ``prefix:`` style plugin prefix.

    The prefix may contain ASCII letters, digits, ``_`` and ``-`` and must be
    non-empty. Only the start of the string is inspected, so anything may
    follow the colon. URL schemes such as ``https:`` also match this pattern.
    """
    return bool(_plugin_prefix_re.match(value))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from click.testing import CliRunner
|
||||
from llm.cli import cli
|
||||
from llm.migrations import migrate
|
||||
from llm.utils import FragmentString
|
||||
from llm import Fragment
|
||||
from ulid import ULID
|
||||
import datetime
|
||||
import json
|
||||
|
|
@ -440,7 +440,7 @@ def fragments_fixture(user_path):
|
|||
# Create fragments
|
||||
for i in range(1, 6):
|
||||
content = f"This is fragment {i}" * (100 if i == 5 else 1)
|
||||
fragment = FragmentString(content, "fragment")
|
||||
fragment = Fragment(content, "fragment")
|
||||
db["fragments"].insert(
|
||||
{
|
||||
"id": i,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from click.testing import CliRunner
|
|||
import click
|
||||
import importlib
|
||||
import llm
|
||||
from llm import cli, hookimpl, plugins, get_template_loaders
|
||||
from llm import cli, hookimpl, plugins, get_template_loaders, get_fragment_loaders
|
||||
|
||||
|
||||
def test_register_commands():
|
||||
|
|
@ -88,3 +88,52 @@ def test_register_template_loaders():
|
|||
finally:
|
||||
plugins.pm.unregister(name="TemplateLoadersPlugin")
|
||||
assert get_template_loaders() == {}
|
||||
|
||||
|
||||
def test_register_fragment_loaders(logs_db):
    """Fragment loaders from a plugin are discovered, usable via ``-f`` and logged."""
    # Nothing is registered yet, so no loaders should be visible
    assert get_fragment_loaders() == {}

    def single_fragment(argument):
        return llm.Fragment("single", "single")

    def three_fragments(argument):
        return [
            llm.Fragment(f"one:{argument}", "one"),
            llm.Fragment(f"two:{argument}", "two"),
            llm.Fragment(f"three:{argument}", "three"),
        ]

    class FragmentLoadersPlugin:
        __name__ = "FragmentLoadersPlugin"

        @hookimpl
        def register_fragment_loaders(self, register):
            register("single", single_fragment)
            register("three", three_fragments)

    try:
        plugins.pm.register(FragmentLoadersPlugin(), name="FragmentLoadersPlugin")
        assert get_fragment_loaders() == {
            "single": single_fragment,
            "three": three_fragments,
        }

        # Test the CLI command
        runner = CliRunner()
        result = runner.invoke(
            cli.cli, ["-m", "echo", "-f", "three:x"], catch_exceptions=False
        )
        assert result.exit_code == 0
        expected = (
            "prompt:\n"
            "one:x\n"
            "two:x\n"
            "three:x\n"
        )
        assert expected in result.output
    finally:
        # Always unregister so the plugin does not leak into other tests
        plugins.pm.unregister(name="FragmentLoadersPlugin")
        assert get_fragment_loaders() == {}

    # Let's check the database
    logged = list(logs_db.query("select content, source from fragments"))
    assert logged == [
        {"content": "one:x", "source": "one"},
        {"content": "two:x", "source": "two"},
        {"content": "three:x", "source": "three"},
    ]
|
||||
|
|
|
|||
Loading…
Reference in a new issue