From f740a5cbbde24e5eb5ad464964d8fe72bd5f28eb Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 5 Apr 2025 17:22:37 -0700 Subject: [PATCH] Fragments (#859) * WIP fragments: schema plus reading but not yet writing, refs #617 * Unique index on fragments.alias, refs #617 * Fragments are now persisted, added basic CLI commands * Fragment aliases work now, refs #617 * Improved help for -f/--fragment * Support fragment hash as well * Documentation for fragments * Better non-JSON display of llm fragments list * llm fragments -q search option * _truncate_string is now truncate_string * Use condense_json to avoid duplicate data in JSON in DB, refs #617 * Follow up to 3 redirects for fragments * Python API docs for fragments= and system_fragments= * Fragment aliases cannot contain a : - this is to ensure we can add custom fragment loaders later on, refs https://github.com/simonw/llm/pull/859#issuecomment-2761534692 * Use template fragments when running prompts * llm fragments show command plus llm fragments group tests * Tests for fragments family of commands * Test for --save with fragments * Add fragments tables to docs/logging.md * Slightly better llm fragments --help * Handle fragments in past conversations correctly * Hint at llm prompt --help in llm --help, closes #868 * llm logs -f filter plus show fragments in llm logs --json * Include prompt and system fragments in llm logs -s * llm logs markdown fragment output and tests, refs #617 --- docs/help.md | 83 ++++++++ docs/logging.md | 30 ++- docs/python-api.md | 26 +++ docs/usage.md | 50 ++++- llm/cli.py | 373 ++++++++++++++++++++++++++++++++++-- llm/migrations.py | 47 +++++ llm/models.py | 130 +++++++++++-- llm/templates.py | 5 +- llm/utils.py | 33 ++++ setup.py | 1 + tests/test_fragments_cli.py | 51 +++++ tests/test_llm.py | 12 +- tests/test_llm_logs.py | 357 +++++++++++++++++++++++++++++++++- tests/test_templates.py | 9 + 14 files changed, 1165 insertions(+), 42 deletions(-) create mode 100644 tests/test_fragments_cli.py diff --git a/docs/help.md b/docs/help.md index c426e32..fbd1af7 100644 --- a/docs/help.md +++ b/docs/help.md @@ -75,6 +75,7 @@ Commands: embed Embed text and store or return the result embed-models Manage available embedding models embed-multi Store embeddings for multiple strings at once in the... + fragments Manage fragments that are stored in the database install Install packages from PyPI into the same environment as LLM keys Manage stored API keys for different models logs Tools for exploring logged prompts and responses @@ -126,6 +127,9 @@ Options: -o, --option ... key/value options for the model --schema TEXT JSON schema, filepath or ID --schema-multi TEXT JSON schema to use for multiple results + -f, --fragment TEXT Fragment (alias, URL, hash or file path) to + add to the prompt + --sf, --system-fragment TEXT Fragment to add to system prompt -t, --template TEXT Template to use -p, --param ... Parameters for template --no-stream Do not stream output @@ -308,6 +312,7 @@ Options: -d, --database FILE Path to log database -m, --model TEXT Filter by model or model alias -q, --query TEXT Search for logs matching this string + -f, --fragment TEXT Filter for prompts using these fragments --schema TEXT JSON schema, filepath or ID --schema-multi TEXT JSON schema used for multiple results --data Output newline-delimited JSON data for schema @@ -655,6 +660,84 @@ Options: --help Show this message and exit. ``` +(help-fragments)= +### llm fragments --help +``` +Usage: llm fragments [OPTIONS] COMMAND [ARGS]... + + Manage fragments that are stored in the database + + Fragments are reusable snippets of text that are shared across multiple + prompts. + +Options: + --help Show this message and exit. + +Commands: + list* List current fragments + remove Remove a fragment alias + set Set an alias for a fragment + show Display the fragment stored under an alias or hash +``` + +(help-fragments-list)= +#### llm fragments list --help +``` +Usage: llm fragments list [OPTIONS] + + List current fragments + +Options: + -q, --query TEXT Search for fragments matching these strings + --json Output as JSON + --help Show this message and exit. +``` + +(help-fragments-set)= +#### llm fragments set --help +``` +Usage: llm fragments set [OPTIONS] ALIAS FRAGMENT + + Set an alias for a fragment + + Accepts an alias and a file path, URL, hash or '-' for stdin + + Example usage: + + llm fragments set mydocs ./docs.md + +Options: + --help Show this message and exit. +``` + +(help-fragments-show)= +#### llm fragments show --help +``` +Usage: llm fragments show [OPTIONS] ALIAS_OR_HASH + + Display the fragment stored under an alias or hash + + llm fragments show mydocs + +Options: + --help Show this message and exit. +``` + +(help-fragments-remove)= +#### llm fragments remove --help +``` +Usage: llm fragments remove [OPTIONS] ALIAS + + Remove a fragment alias + + Example usage: + + llm fragments remove docs + +Options: + --help Show this message and exit. +``` + (help-plugins)= ### llm plugins --help ``` diff --git a/docs/logging.md b/docs/logging.md index d037fa5..66c878e 100644 --- a/docs/logging.md +++ b/docs/logging.md @@ -236,7 +236,10 @@ def cleanup_sql(sql): return first_line + '(\n ' + ',\n '.join(columns) + '\n);' cog.out("```sql\n") -for table in ("conversations", "schemas", "responses", "responses_fts", "attachments", "prompt_attachments"): +for table in ( + "conversations", "schemas", "responses", "responses_fts", "attachments", "prompt_attachments", + "fragments", "fragment_aliases", "prompt_fragments", "system_fragments" +): schema = db[table].schema cog.out(format(cleanup_sql(schema))) cog.out("\n") @@ -288,6 +291,31 @@ CREATE TABLE [prompt_attachments] ( PRIMARY KEY ([response_id], [attachment_id]) ); +CREATE TABLE [fragments] ( + [id] INTEGER PRIMARY KEY, + [hash] TEXT, + [content] TEXT, + [datetime_utc] TEXT, + [source] TEXT +); +CREATE TABLE [fragment_aliases] ( + [alias] TEXT PRIMARY KEY, + [fragment_id] INTEGER REFERENCES [fragments]([id]) +); +CREATE TABLE [prompt_fragments] ( + [response_id] TEXT REFERENCES [responses]([id]), + [fragment_id] INTEGER REFERENCES [fragments]([id]), + [order] INTEGER, + PRIMARY KEY ([response_id], + [fragment_id]) +); +CREATE TABLE [system_fragments] ( + [response_id] TEXT REFERENCES [responses]([id]), + [fragment_id] INTEGER REFERENCES [fragments]([id]), + [order] INTEGER, + PRIMARY KEY ([response_id], + [fragment_id]) +); ``` `responses_fts` configures [SQLite full-text search](https://www.sqlite.org/fts5.html) against the `prompt` and `response` columns in the `responses` table. diff --git a/docs/python-api.md b/docs/python-api.md index d055141..aab45e3 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -142,6 +142,32 @@ print(model.prompt( schema=llm.schema_dsl("name, age int, bio", multi=True) )) ``` + +(python-api-fragments)= + +### Fragments + +The {ref}`fragment system ` from the CLI tool can also be accessed from the Python API, by passing `fragments=` and/or `system_fragments=` lists of strings to the `prompt()` method: + +```python +response = model.prompt( + "What do these documents say about dogs?", + fragments=[ + open("dogs1.txt").read(), + open("dogs2.txt").read(), + ], + system_fragments=[ + "You answer questions like Snoopy", + ] +) +``` +This mechanism has limited utility in Python, as you can also assemble the contents of these strings together into the `prompt=` and `system=` strings directly. + +Fragments become more interesting if you are working with LLM's mechanisms for storing prompts to a SQLite database, which are not yet part of the stable, documented Python API. + +Some model plugins may include features that take advantage of fragments, for example [llm-anthropic](https://github.com/simonw/llm-anthropic) aims to use them as part of a mechanism that taps into Claude's prompt caching system. + + (python-api-model-options)= ### Model options diff --git a/docs/usage.md b/docs/usage.md index 9d5cb4b..39b1a65 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -115,7 +115,6 @@ cat llm/utils.py | llm -t pytest ``` See {ref}`prompt templates ` for more. - (usage-extract-fenced-code)= ### Extracting fenced code blocks @@ -192,6 +191,55 @@ Be warned that different models may support different dialects of the JSON schem See {ref}`schemas-logs` for tips on using the `llm logs --schema X` command to access JSON objects you have previously logged using this option. +(usage-fragments)= +### Fragments + +You can use the `-f/--fragment` option to reference fragments of context that you would like to load into your prompt. Fragments can be specified as URLs, file paths or as aliases to previously saved fragments. + +Fragments are designed for running longer prompts. LLM {ref}`stores prompts in a database `, and the same prompt repeated many times can end up stored as multiple copies, wasting disk space. A fragment will be stored just once and referenced by all of the prompts that use it. + +The `-f` option can accept a path to a file on disk, a URL or the hash or alias of a previous fragment. + +For example, to ask a question about the `robots.txt` file on `llm.datasette.io`: +```bash +llm -f https://llm.datasette.io/robots.txt 'explain this' +``` +For a poem inspired by some Python code on disk: +```bash +llm -f cli.py 'a short snappy poem inspired by this code' +``` +You can use as many `-f` options as you like - the fragments will be concatenated together in the order you provided, with any additional prompt added at the end. + +Fragments can also be used for the system prompt using the `--sf/--system-fragment` option. If you have a file called `explain_code.txt` containing this: + +``` +Explain this code in detail. Include copies of the code quoted in the explanation. +``` +You can run it as the system prompt like this: +```bash +llm -f cli.py --sf explain_code.txt +``` + +You can use the `llm fragments set` command to load a fragment and give it an alias for use in future queries: +```bash +llm fragments set cli cli.py +# Then +llm -f cli 'explain this code' +``` +Use `llm fragments` to list all fragments that have been stored: +```bash +llm fragments +``` +You can search by passing one or more `-q X` search strings. This will return results matching all of those strings, across the source, hash, aliases and content: +```bash +llm fragments -q pytest -q asyncio +``` + +The `llm fragments remove` command removes an alias. It does not delete the fragment record itself as those are linked to previous prompts and responses and cannot be deleted independently of them. +```bash +llm fragments remove cli +``` + (usage-conversation)= ### Continuing a conversation diff --git a/llm/cli.py b/llm/cli.py index 3f90b09..a168eeb 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -39,8 +39,10 @@ from llm.models import _BaseConversation from .migrations import migrate from .plugins import pm, load_plugins from .utils import ( + ensure_fragment, extract_fenced_code_block, find_unused_key, + FragmentString, make_schema_id, mimetype_from_path, mimetype_from_string, @@ -56,6 +58,7 @@ import base64 import httpx import pathlib import pydantic +import re import readline from runpy import run_module import shutil @@ -63,7 +66,7 @@ import sqlite_utils from sqlite_utils.utils import rows_from_file, Format import sys import textwrap -from typing import cast, Optional, Iterable, Union, Tuple, Any +from typing import cast, Optional, Iterable, List, Union, Tuple, Any import warnings import yaml @@ -72,6 +75,66 @@ warnings.simplefilter("ignore", ResourceWarning) DEFAULT_TEMPLATE = "prompt: " +class FragmentNotFound(Exception): + pass + + +def validate_fragment_alias(ctx, param, value): + if not re.match(r"^[a-zA-Z0-9_-]+$", value): + raise click.BadParameter("Fragment alias must be alphanumeric") + return value + + +def resolve_fragments( + db: sqlite_utils.Database, fragments: Iterable[str] +) -> List[FragmentString]: + """ + Resolve fragments into a list of (content, source) tuples + """ + + def _load_by_alias(fragment): + rows = list( + db.query( + """ + select content, source from fragments + left join fragment_aliases on fragments.id = fragment_aliases.fragment_id + where alias = :alias or hash = :alias limit 1 + """, + {"alias": fragment}, + ) + ) + if rows: + row = rows[0] + return row["content"], row["source"] + return None, None + + # These can be URLs or paths + resolved = [] + for fragment in fragments: + if fragment.startswith("http://") or fragment.startswith("https://"): + client = httpx.Client(follow_redirects=True, max_redirects=3) + response = client.get(fragment) + response.raise_for_status() + resolved.append(FragmentString(response.text, fragment)) + elif fragment == "-": + resolved.append(FragmentString(sys.stdin.read(), "-")) + else: + # Try from the DB + content, source = _load_by_alias(fragment) + if content is not None: + resolved.append(FragmentString(content, source)) + else: + # Now try path + path = pathlib.Path(fragment) + if path.exists(): + resolved.append( + FragmentString(path.read_text(), str(path.resolve())) + ) + else: + raise FragmentNotFound(f"Fragment '{fragment}' not found") + return resolved + + class AttachmentType(click.ParamType): name = "attachment" @@ -227,6 +290,20 @@ def cli(): "--schema-multi", help="JSON schema to use for multiple results", ) +@click.option( + "fragments", + "-f", + "--fragment", + multiple=True, + help="Fragment (alias, URL, hash or file path) to add to the prompt", +) +@click.option( + "system_fragments", + "--sf", + "--system-fragment", + multiple=True, + help="Fragment to add to system prompt", +) @click.option("-t", "--template", help="Template to use") @click.option( "-p", @@ -275,6 +352,8 @@ def prompt( options, schema_input, schema_multi, + fragments, + system_fragments, template, param, no_stream, @@ -368,6 +447,7 @@ def prompt( and not attachments and not attachment_types and not schema + and not fragments ): # Hang waiting for input to stdin (unless --save) prompt = sys.stdin.read() @@ -408,6 +488,10 @@ def prompt( to_save["extract_last"] = True if schema: to_save["schema_object"] = schema + if fragments: + to_save["fragments"] = list(fragments) + if system_fragments: + to_save["system_fragments"] = list(system_fragments) if options: # Need to validate and convert their types first model = get_model(model_id or get_default_model()) @@ -441,6 +525,11 @@ def prompt( raise click.ClickException(str(ex)) extract = template_obj.extract extract_last = template_obj.extract_last + # Combine with template fragments/system_fragments + if template_obj.fragments: + fragments = [*template_obj.fragments, *fragments] + if template_obj.system_fragments: + system_fragments = [*template_obj.system_fragments, *system_fragments] if template_obj.schema_object: schema = template_obj.schema_object input_ = "" @@ -528,6 +617,12 @@ def prompt( prompt = read_prompt() response = None + try: + fragments = resolve_fragments(db, fragments) + system_fragments = resolve_fragments(db, system_fragments) + except FragmentNotFound as ex: + raise click.ClickException(str(ex)) + prompt_method = model.prompt if conversation: prompt_method = conversation.prompt @@ -542,6 +637,8 @@ def prompt( attachments=resolved_attachments, system=system, schema=schema, + fragments=fragments, + system_fragments=system_fragments, **kwargs, ) async for chunk in response: @@ -551,9 +648,11 @@ def prompt( else: response = prompt_method( prompt, + fragments=fragments, attachments=resolved_attachments, - system=system, schema=schema, + system=system, + system_fragments=system_fragments, **kwargs, ) text = await response.text() @@ -568,9 +667,11 @@ def prompt( else: response = prompt_method( prompt, + fragments=fragments, attachments=resolved_attachments, system=system, schema=schema, + system_fragments=system_fragments, **kwargs, ) if should_stream: @@ -1008,6 +1109,13 @@ order by prompt_attachments."order" ) @click.option("-m", "--model", help="Filter by model or model alias") @click.option("-q", "--query", help="Search for logs matching this string") +@click.option( + "fragments", + "--fragment", + "-f", + help="Filter for prompts using these fragments", + multiple=True, +) @schema_option @click.option( "--schema-multi", @@ -1063,6 +1171,7 @@ def logs_list( database, model, query, + fragments, schema_input, schema_multi, data, @@ -1150,6 +1259,13 @@ def logs_list( "extra_where": "", } where_bits = [] + sql_params = { + "model": model_id, + "query": query, + "conversation_id": conversation_id, + "id_gt": id_gt, + "id_gte": id_gte, + } if model_id: where_bits.append("responses.model = :model") if conversation_id: @@ -1158,29 +1274,38 @@ def logs_list( where_bits.append("responses.id > :id_gt") if id_gte: where_bits.append("responses.id >= :id_gte") + if fragments: + frags = ", ".join(f":f{i}" for i in range(len(fragments))) + response_ids_sql = f""" + select response_id from prompt_fragments + where fragment_id in ( + select fragments.id from fragments + where hash in ({frags}) + or fragments.id in (select fragment_id from fragment_aliases where alias in ({frags})) + ) + union + select response_id from system_fragments + where fragment_id in ( + select fragments.id from fragments + where hash in ({frags}) + or fragments.id in (select fragment_id from fragment_aliases where alias in ({frags})) + ) + """ + where_bits.append(f"responses.id in ({response_ids_sql})") + for i, fragment in enumerate(fragments): + sql_params["f{}".format(i)] = fragment schema_id = None if schema: schema_id = make_schema_id(schema)[0] where_bits.append("responses.schema_id = :schema_id") + sql_params["schema_id"] = schema_id if where_bits: where_ = " and " if query else " where " sql_format["extra_where"] = where_ + " and ".join(where_bits) final_sql = sql.format(**sql_format) - rows = list( - db.query( - final_sql, - { - "model": model_id, - "query": query, - "conversation_id": conversation_id, - "schema_id": schema_id, - "id_gt": id_gt, - "id_gte": id_gte, - }, - ) - ) + rows = list(db.query(final_sql, sql_params)) # Reverse the order - we do this because we 'order by id desc limit 3' to get the # 3 most recent results, but we still want to display them in chronological order @@ -1195,6 +1320,36 @@ def logs_list( for attachment in attachments: attachments_by_id.setdefault(attachment["response_id"], []).append(attachment) + FRAGMENTS_SQL = """ + select + {table}.response_id, + fragments.hash, + fragments.id as fragment_id, + fragments.content, + ( + select json_group_array(fragment_aliases.alias) + from fragment_aliases + where fragment_aliases.fragment_id = fragments.id + ) as aliases + from {table} + join fragments on {table}.fragment_id = fragments.id + where {table}.response_id in ({placeholders}) + order by {table}."order" + """ + + # Fetch any prompt or system prompt fragments + prompt_fragments_by_id = {} + system_fragments_by_id = {} + for table, dictionary in ( + ("prompt_fragments", prompt_fragments_by_id), + ("system_fragments", system_fragments_by_id), + ): + for fragment in db.query( + FRAGMENTS_SQL.format(placeholders=",".join("?" * len(ids)), table=table), + ids, + ): + dictionary.setdefault(fragment["response_id"], []).append(fragment) + if data or data_array or data_key or data_ids: # Special case for --data to output valid JSON to_output = [] @@ -1226,6 +1381,20 @@ def logs_list( if truncate: row["prompt"] = truncate_string(row["prompt"] or "") row["response"] = truncate_string(row["response"] or "") + # Add prompt and system fragments + for key in ("prompt_fragments", "system_fragments"): + row[key] = [ + { + "hash": fragment["hash"], + "content": truncate_string(fragment["content"]), + "aliases": json.loads(fragment["aliases"]), + } + for fragment in ( + prompt_fragments_by_id.get(row["id"], []) + if key == "prompt_fragments" + else system_fragments_by_id.get(row["id"], []) + ) + ] # Either decode or remove all JSON keys keys = list(row.keys()) for key in keys: @@ -1290,6 +1459,8 @@ def logs_list( details["url"] = attachment["url"] items.append(details) obj["attachments"] = items + for key in ("prompt_fragments", "system_fragments"): + obj[key] = [fragment["hash"] for fragment in row[key]] if usage and (row["input_tokens"] or row["output_tokens"]): usage_details = { "input": row["input_tokens"], @@ -1300,6 +1471,7 @@ def logs_list( obj["usage"] = usage_details click.echo(yaml.dump([obj], sort_keys=False).strip()) continue + # Not short, output Markdown click.echo( "# {}{}\n{}".format( row["datetime_utc"].split(".")[0], @@ -1321,10 +1493,32 @@ def logs_list( if conversation_id: should_show_conversation = False click.echo("## Prompt\n\n{}".format(row["prompt"] or "-- none --")) + if row["prompt_fragments"]: + click.echo( + "\n### Prompt fragments\n\n{}".format( + "\n".join( + [ + "- {}".format(fragment["hash"]) + for fragment in row["prompt_fragments"] + ] + ) + ) + ) if row["system"] != current_system: if row["system"] is not None: click.echo("\n## System\n\n{}".format(row["system"])) current_system = row["system"] + if row["system_fragments"]: + click.echo( + "\n### System fragments\n\n{}".format( + "\n".join( + [ + "- {}".format(fragment["hash"]) + for fragment in row["system_fragments"] + ] + ) + ) + ) if row["schema_json"]: click.echo( "\n## Schema\n\n```json\n{}\n```".format( @@ -1819,6 +2013,155 @@ def aliases_path(): click.echo(user_dir() / "aliases.json") +@cli.group( + cls=DefaultGroup, + default="list", + default_if_no_args=True, +) +def fragments(): + """ + Manage fragments that are stored in the database + + Fragments are reusable snippets of text that are shared across multiple prompts. + """ + + +@fragments.command(name="list") +@click.option( + "queries", + "-q", + "--query", + multiple=True, + help="Search for fragments matching these strings", +) +@click.option("json_", "--json", is_flag=True, help="Output as JSON") +def fragments_list(queries, json_): + "List current fragments" + db = sqlite_utils.Database(logs_db_path()) + migrate(db) + params = {} + param_count = 0 + where_bits = [] + for q in queries: + param_count += 1 + p = f"p{param_count}" + params[p] = q + where_bits.append( + f""" + (fragments.hash = :{p} or fragment_aliases.alias = :{p} + or fragments.source like '%' || :{p} || '%' + or fragments.content like '%' || :{p} || '%') + """ + ) + where = "\n and\n ".join(where_bits) + if where: + where = " where " + where + sql = """ + select + fragments.hash, + json_group_array(fragment_aliases.alias) filter ( + where + fragment_aliases.alias is not null + ) as aliases, + fragments.datetime_utc, + fragments.source, + fragments.content + from + fragments + left join + fragment_aliases on fragment_aliases.fragment_id = fragments.id + {where} + group by + fragments.id, fragments.hash, fragments.content, fragments.datetime_utc, fragments.source; + """.format( + where=where + ) + results = list(db.query(sql, params)) + for result in results: + result["aliases"] = json.loads(result["aliases"]) + if json_: + click.echo(json.dumps(results, indent=4)) + else: + yaml.add_representer( + str, + lambda dumper, data: dumper.represent_scalar( + "tag:yaml.org,2002:str", data, style="|" if "\n" in data else None + ), + ) + for result in results: + result["content"] = truncate_string(result["content"]) + click.echo(yaml.dump([result], sort_keys=False, width=sys.maxsize).strip()) + + +@fragments.command(name="set") +@click.argument("alias", callback=validate_fragment_alias) +@click.argument("fragment") +def fragments_set(alias, fragment): + """ + Set an alias for a fragment + + Accepts an alias and a file path, URL, hash or '-' for stdin + + Example usage: + + \b + llm fragments set mydocs ./docs.md + """ + db = sqlite_utils.Database(logs_db_path()) + migrate(db) + try: + resolved = resolve_fragments(db, [fragment])[0] + except FragmentNotFound as ex: + raise click.ClickException(str(ex)) + migrate(db) + alias_sql = """ + insert into fragment_aliases (alias, fragment_id) + values (:alias, :fragment_id) + on conflict(alias) do update set + fragment_id = excluded.fragment_id; + """ + with db.conn: + fragment_id = ensure_fragment(db, resolved) + db.conn.execute(alias_sql, {"alias": alias, "fragment_id": fragment_id}) + + +@fragments.command(name="show") +@click.argument("alias_or_hash") +def fragments_show(alias_or_hash): + """ + Display the fragment stored under an alias or hash + + \b + llm fragments show mydocs + """ + db = sqlite_utils.Database(logs_db_path()) + migrate(db) + try: + resolved = resolve_fragments(db, [alias_or_hash])[0] + except FragmentNotFound as ex: + raise click.ClickException(str(ex)) + click.echo(resolved) + + +@fragments.command(name="remove") +@click.argument("alias", callback=validate_fragment_alias) +def fragments_remove(alias): + """ + Remove a fragment alias + + Example usage: + + \b + llm fragments remove docs + """ + db = sqlite_utils.Database(logs_db_path()) + migrate(db) + with db.conn: + db.conn.execute( + "delete from fragment_aliases where alias = :alias", {"alias": alias} + ) + + @cli.command(name="plugins") @click.option("--all", help="Include built-in default plugins", is_flag=True) def plugins_list(all): diff --git a/llm/migrations.py b/llm/migrations.py index 0b93188..dfbe446 100644 --- a/llm/migrations.py +++ b/llm/migrations.py @@ -255,3 +255,50 @@ def m014_schemas(db): db["responses"].enable_fts( ["prompt", "response"], create_triggers=True, replace=True ) + + +@migration +def m015_fragments_tables(db): + db["fragments"].create( + { + "id": int, + "hash": str, + "content": str, + "datetime_utc": str, + "source": str, + }, + pk="id", + ) + db["fragments"].create_index(["hash"], unique=True) + db["fragment_aliases"].create( + { + "alias": str, + "fragment_id": int, + }, + foreign_keys=(("fragment_id", "fragments", "id"),), + pk="alias", + ) + db["prompt_fragments"].create( + { + "response_id": str, + "fragment_id": int, + "order": int, + }, + foreign_keys=( + ("response_id", "responses", "id"), + ("fragment_id", "fragments", "id"), + ), + pk=("response_id", "fragment_id"), + ) + db["system_fragments"].create( + { + "response_id": str, + "fragment_id": int, + "order": int, + }, + foreign_keys=( + ("response_id", "responses", "id"), + ("fragment_id", "fragments", "id"), + ), + pk=("response_id", "fragment_id"), + ) diff --git a/llm/models.py b/llm/models.py index a008a3e..5822bc9 100644 --- a/llm/models.py +++ b/llm/models.py @@ -1,5 +1,6 @@ import asyncio import base64 +from condense_json import condense_json from dataclasses import dataclass, field import datetime from .errors import NeedsKeyException @@ -21,10 +22,11 @@ from typing import ( Union, ) from .utils import ( + ensure_fragment, + make_schema_id, mimetype_from_path, mimetype_from_string, token_usage_string, - make_schema_id, ) from abc import ABC, abstractmethod import json @@ -103,10 +105,12 @@ class Attachment: @dataclass class Prompt: - prompt: Optional[str] + _prompt: Optional[str] model: "Model" + fragments: Optional[List[str]] attachments: Optional[List[Attachment]] - system: Optional[str] + _system: Optional[str] + system_fragments: Optional[List[str]] prompt_json: Optional[str] schema: Optional[Union[Dict, type[BaseModel]]] options: "Options" @@ -116,22 +120,39 @@ class Prompt: prompt, model, *, + fragments=None, attachments=None, system=None, + system_fragments=None, prompt_json=None, options=None, schema=None, ): - self.prompt = prompt + self._prompt = prompt self.model = model self.attachments = list(attachments or []) - self.system = system + self.fragments = fragments or [] + self._system = system + self.system_fragments = system_fragments or [] self.prompt_json = prompt_json if schema and not isinstance(schema, dict) and issubclass(schema, BaseModel): schema = schema.model_json_schema() self.schema = schema self.options = options or {} + @property + def prompt(self): + return "\n".join(self.fragments + ([self._prompt] if self._prompt else [])) + + @property + def system(self): + bits = [ + bit.strip() + for bit in (self.system_fragments + [self._system or ""]) + if bit.strip() + ] + return "\n\n".join(bits) + @dataclass class _BaseConversation: @@ -152,9 +173,11 @@ class Conversation(_BaseConversation): self, prompt: Optional[str] = None, *, + fragments: Optional[List[str]] = None, attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, schema: Optional[Union[dict, type[BaseModel]]] = None, + system_fragments: Optional[List[str]] = None, stream: bool = True, key: Optional[str] = None, **options, @@ -163,9 +186,11 @@ class Conversation(_BaseConversation): Prompt( prompt, model=self.model, + fragments=fragments, attachments=attachments, system=system, schema=schema, + system_fragments=system_fragments, options=self.model.Options(**options), ), self.model, @@ -196,9 +221,11 @@ class AsyncConversation(_BaseConversation): self, prompt: Optional[str] = None, *, + fragments: Optional[List[str]] = None, attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, schema: Optional[Union[dict, type[BaseModel]]] = None, + system_fragments: Optional[List[str]] = None, stream: bool = True, key: Optional[str] = None, **options, @@ -207,9 +234,11 @@ class AsyncConversation(_BaseConversation): Prompt( prompt, model=self.model, + fragments=fragments, attachments=attachments, system=system, schema=schema, + system_fragments=system_fragments, options=self.model.Options(**options), ), self.model, @@ -234,6 +263,26 @@ class AsyncConversation(_BaseConversation): return f"<{self.__class__.__name__}: {self.id} - {count} response{s}" +FRAGMENT_SQL = """ +select + 'prompt' as fragment_type, + fragments.content, + pf."order" as ord +from prompt_fragments pf +join fragments on pf.fragment_id = fragments.id +where pf.response_id = :response_id +union all +select + 'system' as fragment_type, + fragments.content, + sf."order" as ord +from system_fragments sf +join fragments on sf.fragment_id = fragments.id +where sf.response_id = :response_id +order by fragment_type desc, ord asc; +""" + + class _BaseResponse: """Base response class shared between sync and async responses""" @@ -296,27 +345,37 @@ class _BaseResponse: if row["schema_id"]: schema = json.loads(db["schemas"].get(row["schema_id"])["content"]) + all_fragments = list(db.query(FRAGMENT_SQL, {"response_id": row["id"]})) + fragments = [ + row["content"] for row in all_fragments if row["fragment_type"] == "prompt" + ] + system_fragments = [ + row["content"] for row in all_fragments if row["fragment_type"] == "system" + ] response = cls( model=model, prompt=Prompt( prompt=row["prompt"], model=model, + fragments=fragments, attachments=[], system=row["system"], schema=schema, + system_fragments=system_fragments, options=model.Options(**json.loads(row["options_json"])), ), stream=False, ) + prompt_json = json.loads(row["prompt_json"] or "null") response.id = row["id"] - response._prompt_json = json.loads(row["prompt_json"] or "null") + response._prompt_json = prompt_json response.response_json = json.loads(row["response_json"] or "null") response._done = True response._chunks = [row["response"]] # Attachments response.attachments = [ - Attachment.from_row(arow) - for arow in db.query( + Attachment.from_row(attachment_row) + for attachment_row in db.query( """ select attachments.* from attachments join prompt_attachments on attachments.id = prompt_attachments.attachment_id @@ -353,19 +412,55 @@ class _BaseResponse: db["schemas"].insert({"id": schema_id, "content": schema_json}, ignore=True) response_id = str(ULID()).lower() + replacements = {} + # Include replacements from previous responses + for previous_response in conversation.responses[:-1]: + for fragment in (previous_response.prompt.fragments or []) + ( + previous_response.prompt.system_fragments or [] + ): + fragment_id = ensure_fragment(db, fragment) + replacements[f"f:{fragment_id}"] = fragment + replacements[f"r:{previous_response.id}"] = ( + previous_response.text_or_raise() + ) + + for i, fragment in enumerate(self.prompt.fragments): + fragment_id = ensure_fragment(db, fragment) + replacements[f"f{fragment_id}"] = fragment + db["prompt_fragments"].insert( + { + "response_id": response_id, + "fragment_id": fragment_id, + "order": i, + }, + ) + for i, fragment in enumerate(self.prompt.system_fragments): + fragment_id = ensure_fragment(db, fragment) + replacements[f"f{fragment_id}"] = fragment + db["system_fragments"].insert( + { + "response_id": response_id, + "fragment_id": fragment_id, + "order": i, + }, + ) + + response_text = self.text_or_raise() + replacements[f"r:{response_id}"] = response_text + json_data = self.json() response = { "id": response_id, "model": self.model.model_id, - "prompt": self.prompt.prompt, - "system": self.prompt.system, - "prompt_json": self._prompt_json, + "prompt": self.prompt._prompt, + "system": self.prompt._system, + "prompt_json": condense_json(self._prompt_json, replacements), "options_json": { key: value for key, value in dict(self.prompt.options).items() if value is not None }, - "response": self.text_or_raise(), - "response_json": self.json(), + "response": response_text, + "response_json": condense_json(json_data, replacements), "conversation_id": conversation.id, "duration_ms": self.duration_ms(), "datetime_utc": self.datetime_utc(), @@ -377,6 +472,7 @@ class _BaseResponse: "schema_id": schema_id, } db["responses"].insert(response) + # Persist any attachments - loop through with index for index, attachment in enumerate(self.prompt.attachments): attachment_id = attachment.id() @@ -728,8 +824,10 @@ class _Model(_BaseModel): self, prompt: Optional[str] = None, *, + fragments: Optional[List[str]] = None, attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, + system_fragments: Optional[List[str]] = None, stream: bool = True, schema: Optional[Union[dict, type[BaseModel]]] = None, **options, @@ -739,9 +837,11 @@ class _Model(_BaseModel): return Response( Prompt( prompt, + fragments=fragments, attachments=attachments, system=system, schema=schema, + system_fragments=system_fragments, model=self, options=self.Options(**options), ), @@ -784,9 +884,11 @@ class _AsyncModel(_BaseModel): self, prompt: Optional[str] = None, *, + fragments: Optional[List[str]] = None, attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, schema: Optional[Union[dict, type[BaseModel]]] = None, + system_fragments: Optional[List[str]] = None, stream: bool = True, **options, ) -> AsyncResponse: @@ -795,9 +897,11 @@ class _AsyncModel(_BaseModel): return AsyncResponse( Prompt( prompt, + fragments=fragments, attachments=attachments, system=system, schema=schema, + system_fragments=system_fragments, model=self, options=self.Options(**options), ), diff --git a/llm/templates.py b/llm/templates.py index 9d3495e..0544408 100644 --- a/llm/templates.py +++ b/llm/templates.py @@ -10,10 +10,11 @@ class Template(BaseModel): model: Optional[str] = None defaults: Optional[Dict[str, Any]] = None options: Optional[Dict[str, Any]] = None - # Should a fenced code block be extracted? - extract: Optional[bool] = None + extract: Optional[bool] = None # For extracting fenced code blocks extract_last: Optional[bool] = None schema_object: Optional[dict] = None + fragments: Optional[List[str]] = None + system_fragments: Optional[List[str]] = None model_config = ConfigDict(extra="forbid") diff --git a/llm/utils.py b/llm/utils.py index 686890b..18280d2 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -14,6 +14,22 @@ MIME_TYPE_FIXES = { } +class FragmentString(str): + def __new__(cls, content, source): + # We need to use __new__ since str is immutable + instance = super().__new__(cls, content) + return instance + + def __init__(self, content, source): + self.source = source + + def __str__(self): + return super().__str__() + + def __repr__(self): + return super().__repr__() + + def mimetype_from_string(content) -> Optional[str]: try: type_ = puremagic.from_string(content, mime=True) @@ -436,3 +452,20 @@ def truncate_string( else: # Fall back to simple truncation for very small max_length return text[: max_length - 3] + "..." + + +def ensure_fragment(db, content): + sql = """ + insert into fragments (hash, content, datetime_utc, source) + values (:hash, :content, datetime('now'), :source) + on conflict(hash) do nothing + """ + hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + source = None + if isinstance(content, FragmentString): + source = content.source + with db.conn: + db.execute(sql, {"hash": hash, "content": content, "source": source}) + return list( + db.query("select id from fragments where hash = :hash", {"hash": hash}) + )[0]["id"] diff --git a/setup.py b/setup.py index d5cdad1..d5d7ea8 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ setup( """, install_requires=[ "click", + "condense-json>=0.1.2", "openai>=1.55.3", "click-default-group>=1.2.3", "sqlite-utils>=3.37", diff --git a/tests/test_fragments_cli.py b/tests/test_fragments_cli.py new file mode 100644 index 0000000..a7d81dc --- /dev/null +++ b/tests/test_fragments_cli.py @@ -0,0 +1,51 @@ +from click.testing import CliRunner +from llm.cli import cli +import yaml + + +def test_fragments_set_show_remove(user_path): + runner = CliRunner() + with runner.isolated_filesystem(): + open("fragment1.txt", "w").write("Hello fragment 1") + assert ( + runner.invoke(cli, ["fragments", "set", "f1", "fragment1.txt"]).exit_code + == 0 + ) + result1 = runner.invoke(cli, ["fragments", "show", "f1"]) + assert result1.exit_code == 0 + assert result1.output == "Hello fragment 1\n" + + # Should be in the list now + def get_list(): + result2 = runner.invoke(cli, ["fragments", "list"]) + assert result2.exit_code == 0 + return yaml.safe_load(result2.output) + + loaded1 = get_list() + assert set(loaded1[0].keys()) == { + "aliases", + "content", + "datetime_utc", + "source", + "hash", + } + assert loaded1[0]["content"] == "Hello fragment 1" + assert loaded1[0]["aliases"] == ["f1"] + + # Show should work against both alias and hash + for key in ("f1", loaded1[0]["hash"]): + result3 = runner.invoke(cli, ["fragments", "show", key]) + assert result3.exit_code == 0 + assert result3.output == "Hello fragment 1\n" + + # But not for an invalid alias + result4 = runner.invoke(cli, ["fragments", "show", "badalias"]) + assert result4.exit_code == 1 + assert "Fragment 'badalias' not found" in result4.output + + # Remove that alias + assert runner.invoke(cli, ["fragments", "remove", "f1"]).exit_code == 0 + # Should still be in list but no alias + loaded2 = get_list() + assert loaded2[0]["aliases"] == [] + assert loaded2[0]["content"] == "Hello fragment 1" diff --git a/tests/test_llm.py b/tests/test_llm.py index 1208f72..0a66f20 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -117,8 +117,8 @@ def test_llm_default_prompt( "messages": [{"role": "user", "content": "three names \nfor a pet pelican"}] } assert json.loads(row["response_json"]) == { + "choices": [{"message": {"content": {"$": f"r:{row['id']}"}}}], "model": "gpt-4o-mini", - "choices": [{"message": {"content": "Bob, Alice, Eve"}}], } # Test "llm logs" @@ -143,7 +143,7 @@ def test_llm_default_prompt( "response": "Bob, Alice, Eve", "response_json": { "model": "gpt-4o-mini", - "choices": [{"message": {"content": "Bob, Alice, Eve"}}], + "choices": [{"message": {"content": {"$": f"r:{row['id']}"}}}], }, # This doesn't have the \n after three names: "conversation_name": "three names for a pet pelican", @@ -295,12 +295,10 @@ def test_openai_completion_system_prompt_error(): "--system", "system prompts not allowed", ], - catch_exceptions=False, ) assert result.exit_code == 1 assert ( - result.output - == "Error: System prompts are not supported for OpenAI completion models\n" + "System prompts are not supported for OpenAI completion models" in result.output ) @@ -328,7 +326,7 @@ def test_openai_completion_logprobs_stream( assert len(rows) == 1 row = rows[0] assert json.loads(row["response_json"]) == { - "content": "\n\nHi.", + "content": {"$": f'r:{row["id"]}'}, "logprobs": [ {"text": "\n\n", "top_logprobs": [{"\n\n": -0.6, "\n": -1.9}]}, {"text": "Hi", "top_logprobs": [{"Hi": -1.1, "Hello": -0.7}]}, @@ -381,7 +379,7 @@ def test_openai_completion_logprobs_nostream( {"!": -1.1, ".": -0.9}, ], }, - "text": "\n\nHi.", + "text": {"$": f"r:{row['id']}"}, } ], "created": 1695097747, diff --git a/tests/test_llm_logs.py b/tests/test_llm_logs.py index 2ad2b3b..59d530c 100644 --- a/tests/test_llm_logs.py +++ b/tests/test_llm_logs.py @@ -9,6 +9,7 @@ import re import sqlite_utils import sys import time +import yaml SINGLE_ID = "5843577700ba729bb14c327b30441885" @@ -219,17 +220,23 @@ def test_logs_short(log_path, arg, usage): " datetime: 'YYYY-MM-DDTHH:MM:SS'\n" " conversation: abc123\n" " system: system\n" - f" prompt: prompt\n{expected_usage}" + " prompt: prompt\n" + " prompt_fragments: []\n" + f" system_fragments: []\n{expected_usage}" "- model: davinci\n" " datetime: 'YYYY-MM-DDTHH:MM:SS'\n" " conversation: abc123\n" " system: system\n" - f" prompt: prompt\n{expected_usage}" + " prompt: prompt\n" + " prompt_fragments: []\n" + f" system_fragments: []\n{expected_usage}" "- model: davinci\n" " datetime: 'YYYY-MM-DDTHH:MM:SS'\n" " conversation: abc123\n" " system: system\n" - f" prompt: prompt\n{expected_usage}" + " prompt: prompt\n" + " prompt_fragments: []\n" + f" system_fragments: []\n{expected_usage}" ) assert output == expected @@ -418,3 +425,347 @@ def test_logs_schema_data_ids(schema_log_path): } for row in rows: assert set(row.keys()) == {"conversation_id", "response_id", "name"} + + +@pytest.fixture +def fragments_fixture(user_path): + log_path = str(user_path / "logs_fragments.db") + db = sqlite_utils.Database(log_path) + migrate(db) + start = datetime.datetime.now(datetime.timezone.utc) + # Replace everything from here on + + # Create fragments + for i in range(1, 5): + db["fragments"].insert( + { + "id": i, + "hash": f"hash{i}", + "content": f"This is fragment {i}", + "datetime_utc": start.isoformat(), + } + ) + + # Create some fragment aliases + db["fragment_aliases"].insert({"alias": "alias_1", "fragment_id": 3}) + db["fragment_aliases"].insert({"alias": "alias_3", "fragment_id": 4}) + + def make_response(name, prompt_fragment_ids=None, system_fragment_ids=None): + time.sleep(0.05) # To ensure ULIDs order predictably + response_id = str(ULID.from_timestamp(time.time())).lower() + db["responses"].insert( + { + "id": response_id, + "system": f"system: {name}", + "prompt": f"prompt: {name}", + "response": f"response: {name}", + "model": "davinci", + "datetime_utc": start.isoformat(), + "conversation_id": "abc123", + "input_tokens": 2, + "output_tokens": 5, + } + ) + # Link fragments to this response + for fragment_id in prompt_fragment_ids or []: + db["prompt_fragments"].insert( + {"response_id": response_id, "fragment_id": fragment_id} + ) + for fragment_id in system_fragment_ids or []: + db["system_fragments"].insert( + {"response_id": response_id, "fragment_id": fragment_id} + ) + return {name: response_id} + + collected = {} + collected.update(make_response("no_fragments")) + collected.update( + single_prompt_fragment_id=make_response("single_prompt_fragment", [1]) + ) + collected.update( + single_system_fragment_id=make_response("single_system_fragment", None, [2]) + ) + collected.update( + multi_prompt_fragment_id=make_response("multi_prompt_fragment", [1, 2]) + ) + collected.update( + multi_system_fragment_id=make_response("multi_system_fragment", None, [1, 2]) + ) + collected.update(both_fragments_id=make_response("both_fragments", [1, 2], [3, 4])) + collected.update( + single_prompt_fragment_with_alias_id=make_response( + "single_prompt_fragment_with_alias", [3], None + ) + ) + collected.update( + single_system_fragment_with_alias_id=make_response( + "single_system_fragment_with_alias", None, [4] + ) + ) + return {"path": log_path, "collected": collected} + + +@pytest.mark.parametrize( + "fragment_refs,expected", + ( + ( + ["hash1"], + [ + { + "name": "single_prompt_fragment", + "prompt_fragments": ["hash1"], + "system_fragments": [], + }, + { + "name": "multi_prompt_fragment", + "prompt_fragments": ["hash1", "hash2"], + "system_fragments": [], + }, + { + "name": "multi_system_fragment", + "prompt_fragments": [], + "system_fragments": ["hash1", "hash2"], + }, + { + "name": "both_fragments", + "prompt_fragments": ["hash1", "hash2"], + "system_fragments": ["hash3", "hash4"], + }, + ], + ), + ( + ["alias_3"], + [ + { + "name": "both_fragments", + "prompt_fragments": ["hash1", "hash2"], + "system_fragments": ["hash3", "hash4"], + }, + { + "name": "single_system_fragment_with_alias", + "prompt_fragments": [], + "system_fragments": ["hash4"], + }, + ], + ), + ), +) +def test_logs_fragments(fragments_fixture, fragment_refs, expected): + fragments_log_path = fragments_fixture["path"] + # fragments = fragments_fixture["collected"] + runner = CliRunner() + args = ["logs", "-d", fragments_log_path, "-n", "0"] + for ref in fragment_refs: + args.extend(["-f", ref]) + result = runner.invoke(cli, args + ["--json"], catch_exceptions=False) + assert result.exit_code == 0 + output = result.output + responses = json.loads(output) + # Re-shape that to same shape as expected + reshaped = [ + { + "name": response["prompt"].replace("prompt: ", ""), + "prompt_fragments": [ + fragment["hash"] for fragment in response["prompt_fragments"] + ], + "system_fragments": [ + fragment["hash"] for fragment in response["system_fragments"] + ], + } + for response in responses + ] + assert reshaped == expected + # Now test the `-s/--short` option: + result2 = runner.invoke(cli, args + ["-s"], catch_exceptions=False) + assert result2.exit_code == 0 + output2 = result2.output + loaded = yaml.safe_load(output2) + reshaped2 = [ + { + "name": item["prompt"].replace("prompt: ", ""), + "system_fragments": item["system_fragments"], + "prompt_fragments": item["prompt_fragments"], + } + for item in loaded + ] + assert reshaped2 == expected + + +def test_logs_fragments_markdown(fragments_fixture): + fragments_log_path = fragments_fixture["path"] + runner = CliRunner() + args = ["logs", "-d", fragments_log_path, "-n", "0"] + result = runner.invoke(cli, args, catch_exceptions=False) + assert result.exit_code == 0 + output = result.output + # Replace dates and IDs + output = datetime_re.sub("YYYY-MM-DDTHH:MM:SS", output) + output = id_re.sub("id: xxx", output) + assert ( + output.strip() + == """ +# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx + +Model: **davinci** + +## Prompt + +prompt: no_fragments + +## System + +system: no_fragments + +## Response + +response: no_fragments + +# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx + +Model: **davinci** + +## Prompt + +prompt: single_prompt_fragment + +### Prompt fragments + +- hash1 + +## System + +system: single_prompt_fragment + +## Response + +response: single_prompt_fragment + +# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx + +Model: **davinci** + +## Prompt + +prompt: single_system_fragment + +## System + +system: single_system_fragment + +### System fragments + +- hash2 + +## Response + +response: single_system_fragment + +# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx + +Model: **davinci** + +## Prompt + +prompt: multi_prompt_fragment + +### Prompt fragments + +- hash1 +- hash2 + +## System + +system: multi_prompt_fragment + +## Response + +response: multi_prompt_fragment + +# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx + +Model: **davinci** + +## Prompt + +prompt: multi_system_fragment + +## System + +system: multi_system_fragment + +### System fragments + +- hash1 +- hash2 + +## Response + +response: multi_system_fragment + +# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx + +Model: **davinci** + +## Prompt + +prompt: both_fragments + +### Prompt fragments + +- hash1 +- hash2 + +## System + +system: both_fragments + +### System fragments + +- hash3 +- hash4 + +## Response + +response: both_fragments + +# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx + +Model: **davinci** + +## Prompt + +prompt: single_prompt_fragment_with_alias + +### Prompt fragments + +- hash3 + +## System + +system: single_prompt_fragment_with_alias + +## Response + +response: single_prompt_fragment_with_alias + +# YYYY-MM-DDTHH:MM:SS conversation: abc123 id: xxx + +Model: **davinci** + +## Prompt + +prompt: single_system_fragment_with_alias + +## System + +system: single_system_fragment_with_alias + +### System fragments + +- hash4 + +## Response + +response: single_system_fragment_with_alias + """.strip() + ) diff --git a/tests/test_templates.py b/tests/test_templates.py index 1bc98e4..f16c102 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -117,6 +117,15 @@ def test_templates_list(templates_path, args): }, None, ), + # And fragments and system_fragments + ( + ["--fragment", "f1.txt", "--system-fragment", "https://example.com/f2.txt"], + { + "fragments": ["f1.txt"], + "system_fragments": ["https://example.com/f2.txt"], + }, + None, + ), ), ) def test_templates_prompt_save(templates_path, args, expected_prompt, expected_error):